diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index a7f920222..2e67973ad 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -126,10 +126,18 @@ impl Loopback { let buffer = self.queue.pop_front(); match buffer { Some(buffer) => { - // debug!("lo receive:{:?}", buffer); + // log::debug!( + // "lo receive: {} bytes, remaining_queue_len={}, self_ptr={:p}", + // buffer.len(), + // self.queue.len(), + // self + // ); return buffer; } None => { + if !self.queue.is_empty() { + log::warn!("lo receive: queue not empty but pop_front returned None!"); + } return Vec::new(); } } @@ -141,8 +149,13 @@ impl Loopback { /// - &mut self:自身可变引用 /// - buffer:需要发送的数据包 pub fn loopback_transmit(&mut self, buffer: Vec) { - // debug!("lo transmit:{:?}", buffer); - self.queue.push_back(buffer) + // log::debug!( + // "lo transmit: {} bytes, queue_len={}, self_ptr={:p}", + // buffer.len(), + // self.queue.len(), + // self + // ); + self.queue.push_back(buffer); } } @@ -240,10 +253,9 @@ impl phy::Device for LoopbackDriver { let buffer = self.inner.lock().loopback_receive(); //receive队列为为空,返回NONE值以通知上层没有可以receive的包 if buffer.is_empty() { - // log::debug!("lo receive none!"); return Option::None; } - // log::debug!("lo receive!"); + // log::debug!("LoopbackDriver::receive() -> packet {} bytes", buffer.len()); let rx = LoopbackRxToken { buffer }; let tx = LoopbackTxToken { driver: self.clone(), @@ -260,7 +272,6 @@ impl phy::Device for LoopbackDriver { /// ## 返回值 /// - 返回一个 `Some`,其中包含一个发送令牌,该令牌包含一个对自身的克隆引用 fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { - // log::debug!("lo transmit!"); Some(LoopbackTxToken { driver: self.clone(), }) diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 469f6514b..820fe3432 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -266,11 +266,10 @@ impl IfaceCommon { let mut interface = self.smol_iface.lock_irqsave(); let (has_events, poll_at) = { + let poll_result = interface.poll(timestamp, device, &mut sockets); + ( - matches!( - interface.poll(timestamp, device, &mut sockets), - smoltcp::iface::PollResult::SocketStateChanged - ), + matches!(poll_result, smoltcp::iface::PollResult::SocketStateChanged), loop { let poll_at = interface.poll_at(timestamp, &sockets); let Some(instant) = poll_at else { diff --git a/kernel/src/filesystem/vfs/syscall/sys_readv.rs b/kernel/src/filesystem/vfs/syscall/sys_readv.rs index 942a87de4..584908b23 100644 --- a/kernel/src/filesystem/vfs/syscall/sys_readv.rs +++ b/kernel/src/filesystem/vfs/syscall/sys_readv.rs @@ -34,6 +34,16 @@ impl Syscall for SysReadVHandle { // IoVecs 会进行用户态检验(包含 len==0 的 iov_base 校验)。 let iovecs = unsafe { IoVecs::from_user(iov, count, true) }?; + // TODO: Here work around, not suppose to read entire buf once + use crate::process::ProcessManager; + if let Ok(_socket_inode) = ProcessManager::current_pcb().get_socket_inode(fd) { + // Socket: read entire message then scatter to iovecs + let mut buf = iovecs.new_buf(true); + let nread = do_read(fd, &mut buf)?; + iovecs.scatter(&buf[..nread])?; + return Ok(nread); + } + // Linux: limit per readv() to MAX_RW_COUNT = INT_MAX & ~(PAGE_SIZE-1) let max_rw_count = (i32::MAX as usize) & !(MMArch::PAGE_SIZE - 1); diff --git a/kernel/src/net/posix.rs b/kernel/src/net/posix.rs index f74346764..91d52f5c9 100644 --- a/kernel/src/net/posix.rs +++ b/kernel/src/net/posix.rs @@ -217,6 +217,12 @@ impl From for SockAddr { impl From for SockAddr { fn from(value: Endpoint) -> Self { match value { + Endpoint::Unspecified => Self { + addr_ph: SockAddrPlaceholder { + family: 0, // AF_UNSPEC + data: [0; 14], + }, + }, Endpoint::LinkLayer(link_layer_endpoint) => Self::from(link_layer_endpoint), Endpoint::Ip(endpoint) => Self::from(endpoint), Endpoint::Unix(unix_endpoint) => Self::from(unix_endpoint), @@ -257,11 +263,11 @@ impl SockAddr { AddressFamily::INet => { // 下限检查:至少需要包含完整的 sockaddr_in 结构体 if len < size_of::() as u32 { - log::error!( - "len {} < sizeof(sockaddr_in) {}", - len, - size_of::() - ); + // log::error!( + // "len {} < sizeof(sockaddr_in) {}", + // len, + // size_of::() + // ); return Err(SystemError::EINVAL); } @@ -275,6 +281,12 @@ impl SockAddr { return Ok(Endpoint::Ip(wire::IpEndpoint::new(ip, port))); } + + AddressFamily::Unspecified => { + // AF_UNSPEC is used to disconnect sockets + Ok(Endpoint::Unspecified) + } + AddressFamily::INet6 => { // 下限检查:至少需要包含完整的 sockaddr_in6 结构体 if len < size_of::() as u32 { diff --git a/kernel/src/net/socket/base.rs b/kernel/src/net/socket/base.rs index 97dfc382e..92b13d813 100644 --- a/kernel/src/net/socket/base.rs +++ b/kernel/src/net/socket/base.rs @@ -44,12 +44,27 @@ pub trait Socket: PollableInode + IndexNode { fn send_buffer_size(&self) -> usize; fn recv_buffer_size(&self) -> usize; + + /// # `recv_bytes_available` + /// Get the number of bytes currently available to read from the socket. + /// Returns 0 by default for socket types that don't track this. + fn recv_bytes_available(&self) -> Result { + Err(SystemError::ENOTTY) + } + + /// # `send_bytes_available` + /// Get the number of bytes currently available to write to the socket. + /// Returns 0 by default for socket types that don't track this. + fn send_bytes_available(&self) -> Result { + Err(SystemError::ENOTTY) + } + /// # `accept` /// 接受连接,仅用于listening stream socket /// ## Block /// 如果没有连接到来,会阻塞 fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { - Err(SystemError::ENOSYS) + Err(SystemError::EOPNOTSUPP_OR_ENOTSUP) } /// # `bind` @@ -96,7 +111,7 @@ pub trait Socket: PollableInode + IndexNode { /// # `listen` /// 监听socket,仅用于stream socket fn listen(&self, _backlog: usize) -> Result<(), SystemError> { - Err(SystemError::ENOSYS) + Err(SystemError::EOPNOTSUPP_OR_ENOTSUP) } // poll diff --git a/kernel/src/net/socket/common/shutdown.rs b/kernel/src/net/socket/common/shutdown.rs index b52eb7e14..400df11e2 100644 --- a/kernel/src/net/socket/common/shutdown.rs +++ b/kernel/src/net/socket/common/shutdown.rs @@ -31,14 +31,9 @@ impl TryFrom for ShutdownBit { // Linux/POSIX shutdown(2): // 0 = SHUT_RD, 1 = SHUT_WR, 2 = SHUT_RDWR match value { - 0 => Ok(ShutdownBit { - bit: Self::RCV_SHUTDOWN, - }), - 1 => Ok(ShutdownBit { - bit: Self::SEND_SHUTDOWN, - }), - 2 => Ok(ShutdownBit { - bit: Self::RCV_SHUTDOWN | Self::SEND_SHUTDOWN, + // SHUT_RD = 0, SHUT_WR = 1, SHUT_RDWR = 2 + 0..=2 => Ok(ShutdownBit { + bit: value as u8 + 1, }), _ => Err(Self::Error::EINVAL), } diff --git a/kernel/src/net/socket/endpoint.rs b/kernel/src/net/socket/endpoint.rs index dcbe699b7..c4fe3567b 100644 --- a/kernel/src/net/socket/endpoint.rs +++ b/kernel/src/net/socket/endpoint.rs @@ -10,6 +10,8 @@ pub use smoltcp::wire::IpEndpoint; #[derive(Debug, Clone)] pub enum Endpoint { + /// 未指定端点 (AF_UNSPEC) - 用于UDP断开连接 + Unspecified, /// 链路层端点 LinkLayer(LinkLayerEndpoint), /// 网络层端点 @@ -95,6 +97,7 @@ impl From for Endpoint { impl Endpoint { fn sockaddr_len(&self) -> Result { match self { + Endpoint::Unspecified => Ok(SockAddr::from(self.clone()).len()?), Endpoint::LinkLayer(_) => Ok(SockAddr::from(self.clone()).len()?), Endpoint::Ip(_) => Ok(SockAddr::from(self.clone()).len()?), Endpoint::Netlink(_) => Ok(SockAddr::from(self.clone()).len()?), diff --git a/kernel/src/net/socket/inet/common/mod.rs b/kernel/src/net/socket/inet/common/mod.rs index 9b174a855..5a9bbf2cb 100644 --- a/kernel/src/net/socket/inet/common/mod.rs +++ b/kernel/src/net/socket/inet/common/mod.rs @@ -66,6 +66,11 @@ impl BoundInner { }); } else { let iface = get_iface_to_bind(address, netns.clone()).ok_or(SystemError::ENODEV)?; + // log::debug!( + // "BoundInner::bind: binding to iface {} for address {:?}", + // iface.iface_name(), + // address + // ); let handle = iface.sockets().lock().add(socket); return Ok(Self { handle, @@ -154,15 +159,26 @@ pub fn get_iface_to_bind( ) -> Option> { // log::debug!("get_iface_to_bind: {:?}", ip_addr); // if ip_addr.is_unspecified() - netns + let result = netns .device_list() .iter() .find(|(_, iface)| { let guard = iface.smol_iface().lock(); - // log::debug!("iface name: {}, ip: {:?}", iface.iface_name(), guard.ip_addrs()); + // log::debug!( + // " checking iface: {}, ip: {:?}, has_addr={}", + // iface.iface_name(), + // guard.ip_addrs(), + // guard.has_ip_addr(*ip_addr) + // ); return guard.has_ip_addr(*ip_addr); }) - .map(|(_, iface)| iface.clone()) + .map(|(_, iface)| iface.clone()); + + // log::debug!( + // "get_iface_to_bind: returning iface {:?}", + // result.as_ref().map(|i| i.iface_name()) + // ); + result } /// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface. diff --git a/kernel/src/net/socket/inet/datagram/inner.rs b/kernel/src/net/socket/inet/datagram/inner.rs index 1902a6991..2c88a8268 100644 --- a/kernel/src/net/socket/inet/datagram/inner.rs +++ b/kernel/src/net/socket/inet/datagram/inner.rs @@ -12,8 +12,12 @@ use crate::{ pub type SmolUdpSocket = smoltcp::socket::udp::Socket<'static>; pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; -pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; -pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; +// UDP maximum datagram size is 65507 bytes (65535 - 8 byte UDP header - 20 byte IP header) +// Set buffer sizes to accommodate this plus some overhead +pub const DEFAULT_RX_BUF_SIZE: usize = 128 * 1024; // 128 KB +pub const DEFAULT_TX_BUF_SIZE: usize = 128 * 1024; // 128 KB + // Minimum buffer size (Linux uses 256 bytes minimum) +pub const MIN_BUF_SIZE: usize = 256; #[derive(Debug)] pub struct UnboundUdp { @@ -22,13 +26,53 @@ pub struct UnboundUdp { impl UnboundUdp { pub fn new() -> Self { + Self::new_with_buf_size(0, 0) + } + + pub fn new_with_buf_size(rx_size: usize, tx_size: usize) -> Self { + // Buffer sizing strategy: + // - setsockopt(SO_RCVBUF, X) stores X + // - getsockopt(SO_RCVBUF) returns 2*X (Linux convention) + // - Actual buffer allocation: 2*X + // + // This is a straightforward 2x design that matches the getsockopt return value. + // + // Note: smoltcp's PacketBuffer has separate metadata_ring and payload_ring. + // Unlike Linux where sk_buff metadata shares the same buffer space as payload, + // smoltcp allocates them independently. This means: + // - We allocate 2*X bytes purely for payload (no metadata overhead) + // - This may accept more packets than Linux in some edge cases + // + // Differences from Linux behavior: + // - Linux: Buffer space shared between metadata + payload, so effective payload < 2*X + // - DragonOS: Full 2*X available for payload (metadata stored separately) + + let rx_buf_size = if rx_size > 0 { + rx_size * 2 // Simple 2x allocation + } else { + DEFAULT_RX_BUF_SIZE + }; + let tx_buf_size = if tx_size > 0 { + tx_size * 2 // Simple 2x allocation + } else { + DEFAULT_TX_BUF_SIZE + }; + + // log::debug!( + // "new_with_buf_size: requested rx={}, tx={} -> allocating rx={}, tx={} (2x)", + // rx_size, + // tx_size, + // rx_buf_size, + // tx_buf_size + // ); + let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( vec![smoltcp::socket::udp::PacketMetadata::EMPTY; DEFAULT_METADATA_BUF_SIZE], - vec![0; DEFAULT_RX_BUF_SIZE], + vec![0; rx_buf_size], ); let tx_buffer = smoltcp::socket::udp::PacketBuffer::new( vec![smoltcp::socket::udp::PacketMetadata::EMPTY; DEFAULT_METADATA_BUF_SIZE], - vec![0; DEFAULT_TX_BUF_SIZE], + vec![0; tx_buf_size], ); let socket = SmolUdpSocket::new(rx_buffer, tx_buffer); @@ -43,11 +87,17 @@ impl UnboundUdp { let inner = BoundInner::bind(self.socket, &local_endpoint.addr, netns)?; let bind_addr = local_endpoint.addr; let bind_port = if local_endpoint.port == 0 { - inner.port_manager().bind_ephemeral_port(InetTypes::Udp)? + let port = inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?; + // log::debug!("UnboundUdp::bind: allocated ephemeral port {}", port); + port } else { inner .port_manager() .bind_port(InetTypes::Udp, local_endpoint.port)?; + // log::debug!( + // "UnboundUdp::bind: explicit bind to port {}", + // local_endpoint.port + // ); local_endpoint.port }; @@ -69,6 +119,8 @@ impl UnboundUdp { Ok(BoundUdp { inner, remote: SpinLock::new(None), + explicitly_bound: true, + has_preconnect_data: SpinLock::new(false), }) } @@ -77,13 +129,36 @@ impl UnboundUdp { remote: smoltcp::wire::IpAddress, netns: Arc, ) -> Result { - // let (addr, port) = (remote.addr, remote.port); - let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote, netns)?; + let (inner, local_addr) = BoundInner::bind_ephemeral(self.socket, remote, netns)?; let bound_port = inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?; - let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port); + // log::debug!( + // "UnboundUdp::bind_ephemeral: allocated ephemeral port {} for remote {:?}", + // bound_port, + // remote + // ); + + // Bind the smoltcp socket to the local endpoint + if local_addr.is_unspecified() { + if inner + .with_mut::(|socket| socket.bind(bound_port)) + .is_err() + { + return Err(SystemError::EINVAL); + } + } else if inner + .with_mut::(|socket| { + socket.bind(smoltcp::wire::IpEndpoint::new(local_addr, bound_port)) + }) + .is_err() + { + return Err(SystemError::EINVAL); + } + Ok(BoundUdp { inner, - remote: SpinLock::new(Some(endpoint)), + remote: SpinLock::new(None), + explicitly_bound: false, + has_preconnect_data: SpinLock::new(false), }) } } @@ -92,6 +167,13 @@ impl UnboundUdp { pub struct BoundUdp { inner: BoundInner, remote: SpinLock>, + /// True if socket was explicitly bound by user, false if implicitly bound by connect + explicitly_bound: bool, + /// Whether there were buffered packets at connect time - if true, allow next recv without filtering + /// 这是用来模拟 Linux UDP 在应用filter前的行为。在smoltcp下,当有包到来时总是会推送到 + /// udp socket queue 中,而不是先针对connect进行filter操作。这里做workaround, 当connect是检查是否有包 + /// 在缓冲区,如果有,第一个包我们走非connect而不是connect的recv方法(即接受第一个非connect对端对应的包) + has_preconnect_data: SpinLock, } impl BoundUdp { @@ -123,21 +205,141 @@ impl BoundUdp { } pub fn connect(&self, remote: smoltcp::wire::IpEndpoint) { + // let _local = self.endpoint(); + // log::debug!( + // "BoundUdp::connect: local={:?}, connecting to remote={:?}", + // _local, + // remote + // ); + + // Check if there are buffered packets - if so, allow next recv without filtering + let has_buffered = self.with_socket(|socket| socket.can_recv()); + *self.has_preconnect_data.lock() = has_buffered; + // log::debug!("BoundUdp::connect: has pre-connect data = {}", has_buffered); + self.remote.lock().replace(remote); } + pub fn disconnect(&self) { + self.remote.lock().take(); + } + + /// Returns true if this socket should be unbound on disconnect + pub fn should_unbind_on_disconnect(&self) -> bool { + !self.explicitly_bound + } + #[inline] pub fn try_recv( &self, buf: &mut [u8], + peek: bool, ) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> { + let remote = *self.remote.lock(); + self.with_mut_socket(|socket| { - if socket.can_recv() { - if let Ok((size, metadata)) = socket.recv_slice(buf) { - return Ok((size, metadata.endpoint)); - } + // If connected, filter packets by source address (except pre-connect packets) + + let mut has_preconnect_guard = self.has_preconnect_data.lock(); + let has_preconnect = *has_preconnect_guard; + // let has_preconnect = false; + if has_preconnect { + *has_preconnect_guard = false; + } + drop(has_preconnect_guard); + let should_filter = remote.is_some() && !has_preconnect; + if should_filter { + let expected_remote = remote.unwrap(); + // log::debug!("try_recv: connected mode, expected_remote={:?}, buf_len={}, can_recv={}", + // expected_remote, buf.len(), socket.can_recv()); + + // Loop to skip packets from unexpected sources + loop { + if !socket.can_recv() { + // log::debug!("try_recv: can_recv=false, returning EAGAIN"); + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + + // Peek to check source address before receiving + // Note: peek() instead of peek_slice() because peek_slice() returns Truncated + // error when buffer is smaller than packet, but we still want to receive it + match socket.peek() { + Ok((payload, metadata)) => { + // log::debug!("try_recv: peeked {} bytes from {:?}, buf_len={}", payload.len(), metadata.endpoint, buf.len()); + if metadata.endpoint == expected_remote { + // Source matches + + // Special case: zero-length buffer + if buf.is_empty() { + // log::debug!("try_recv: zero-length buffer in connected mode, returning 0 bytes"); + return Ok((0, expected_remote)); + } + + if peek { + // MSG_PEEK: just copy the data we peeked + let copy_len = core::cmp::min(buf.len(), payload.len()); + buf[..copy_len].copy_from_slice(&payload[..copy_len]); + // log::debug!("try_recv: peek succeeded, size={}", copy_len); + return Ok((copy_len, expected_remote)); + } else { + // Receive the packet + let (recv_buf, _metadata) = + socket.recv().map_err(|_| SystemError::ENOBUFS)?; + let length = core::cmp::min(buf.len(), recv_buf.len()); + buf[..length].copy_from_slice(&recv_buf[..length]); + debug_assert_eq!(expected_remote, _metadata.endpoint); + return Ok((length, expected_remote)); + } + } else { + // just drop the packet + let _ = socket.recv(); + continue; + } + } + Err(smoltcp::socket::udp::RecvError::Exhausted) => { + return Err(SystemError::ENOBUFS) + } + Err(_e) => return Err(SystemError::EIO), + } + } + } else { + // log::debug!("try_recv: unconnected mode, buf_len={}, can_recv={}", buf.len(), socket.can_recv()); + // Not connected, receive from any source + + // Special case: if buffer length is 0, just peek to check if data exists + if buf.is_empty() { + if socket.can_recv() { + // Peek to get the source endpoint without consuming data + if let Ok((_payload, metadata)) = socket.peek() { + // log::debug!("try_recv: zero-length buffer with data available, returning 0 bytes from {:?}", metadata.endpoint); + return Ok((0, metadata.endpoint)); + } + } + // log::debug!("try_recv: zero-length buffer with no data, returning EAGAIN"); + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + + if socket.can_recv() { + if peek { + // MSG_PEEK: peek data without consuming + if let Ok((payload, metadata)) = socket.peek() { + let copy_len = core::cmp::min(buf.len(), payload.len()); + buf[..copy_len].copy_from_slice(&payload[..copy_len]); + // log::debug!("try_recv: unconnected peek succeeded, size={}", copy_len); + return Ok((copy_len, metadata.endpoint)); + } + } else { + // Receive the packet // Receive the packet + let (recv_buf, metadata) = + socket.recv().map_err(|_| SystemError::ENOBUFS)?; + let length = core::cmp::min(buf.len(), recv_buf.len()); + buf[..length].copy_from_slice(&recv_buf[..length]); + return Ok((length, metadata.endpoint)); + } + } + // log::debug!("try_recv: unconnected recv failed, returning EAGAIN"); + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); } - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); }) } @@ -146,15 +348,45 @@ impl BoundUdp { buf: &[u8], to: Option, ) -> Result { - let remote = to.or(*self.remote.lock()).ok_or(SystemError::ENOTCONN)?; - let result = self.with_mut_socket(|socket| { - if socket.can_send() && socket.send_slice(buf, remote).is_ok() { - // log::debug!("send {} bytes", buf.len()); - return Ok(buf.len()); + let connected_remote = *self.remote.lock(); + let mut remote = to.or(connected_remote).ok_or(SystemError::ENOTCONN)?; + + // Validate port - sending to port 0 is invalid + if remote.port == 0 { + log::warn!("UDP try_send: attempted to send to port 0"); + return Err(SystemError::EINVAL); + } + + // Linux treats sending to 0.0.0.0 (INADDR_ANY) as sending to localhost + // smoltcp rejects it as "Unaddressable", so we translate it here + if remote.addr.is_unspecified() { + remote.addr = smoltcp::wire::IpAddress::v4(127, 0, 0, 1); + } + + // log::debug!( + // "try_send: sending {} bytes to {:?}, can_send={}", + // buf.len(), + // remote, + // self.with_socket(|socket| socket.can_send()) + // ); + + self.with_mut_socket(|socket| { + if socket.can_send() { + match socket.send_slice(buf, remote) { + Ok(_) => { + // log::debug!("try_send: send successful"); + Ok(buf.len()) + } + Err(_e) => { + // log::debug!("try_send: send failed: {:?}", _e); + Err(SystemError::ENOBUFS) + } + } + } else { + // log::debug!("try_send: can_send=false, returning ENOBUFS"); + Err(SystemError::ENOBUFS) } - return Err(SystemError::ENOBUFS); - }); - return result; + }) } pub fn inner(&self) -> &BoundInner { diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index 21619e950..feddf759a 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -1,17 +1,17 @@ -use inner::{UdpInner, UnboundUdp}; +use inner::{UdpInner, UnboundUdp, DEFAULT_RX_BUF_SIZE, DEFAULT_TX_BUF_SIZE, MIN_BUF_SIZE}; use smoltcp; use system_error::SystemError; use crate::filesystem::epoll::EPollEventType; use crate::filesystem::vfs::{fasync::FAsyncItems, vcore::generate_inode_id, InodeId}; use crate::libs::wait_queue::WaitQueue; -use crate::net::socket::common::EPollItems; -use crate::net::socket::{Socket, PMSG}; +use crate::net::socket::common::{EPollItems, ShutdownBit}; +use crate::net::socket::{Socket, PMSG, PSO, PSOL}; use crate::process::namespace::net_namespace::NetNamespace; use crate::process::ProcessManager; use crate::{libs::rwlock::RwLock, net::socket::endpoint::Endpoint}; use alloc::sync::{Arc, Weak}; -use core::sync::atomic::{AtomicBool, AtomicUsize}; +use core::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}; use super::InetSocket; @@ -25,6 +25,7 @@ type EP = crate::filesystem::epoll::EPollEventType; pub struct UdpSocket { inner: RwLock>, nonblock: AtomicBool, + shutdown: AtomicU8, wait_queue: WaitQueue, inode_id: InodeId, open_files: AtomicUsize, @@ -32,6 +33,24 @@ pub struct UdpSocket { netns: Arc, epoll_items: EPollItems, fasync_items: FAsyncItems, + /// Custom send buffer size (SO_SNDBUF), 0 means use default + send_buf_size: AtomicUsize, + /// Custom receive buffer size (SO_RCVBUF), 0 means use default + recv_buf_size: AtomicUsize, + /// SO_NO_CHECK: disable UDP checksum (0=off, 1=on) + /// + /// NOTE: This is currently a stub implementation. The value can be set/get via + /// setsockopt/getsockopt, but does NOT actually control UDP checksum behavior. + /// + /// Reason: smoltcp 0.12.0 does not support per-socket checksum control. Checksum + /// behavior is controlled globally by DeviceCapabilities.checksum, which is set at + /// the Device/Interface level, not per-socket. + /// + /// To implement this properly would require either: + /// 1. Smoltcp feature that supports per-socket checksum control + /// 2. Patching smoltcp to add this feature + /// 3. Manually parsing/building UDP packets to bypass smoltcp's checksum handling + no_check: AtomicBool, } impl UdpSocket { @@ -40,6 +59,7 @@ impl UdpSocket { Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(UdpInner::Unbound(UnboundUdp::new()))), nonblock: AtomicBool::new(nonblock), + shutdown: AtomicU8::new(0), wait_queue: WaitQueue::default(), inode_id: generate_inode_id(), open_files: AtomicUsize::new(0), @@ -47,6 +67,9 @@ impl UdpSocket { netns, epoll_items: EPollItems::default(), fasync_items: FAsyncItems::default(), + send_buf_size: AtomicUsize::new(0), // 0 means use default + recv_buf_size: AtomicUsize::new(0), // 0 means use default + no_check: AtomicBool::new(false), // checksums enabled by default }) } @@ -56,53 +79,90 @@ impl UdpSocket { pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> { let mut inner = self.inner.write(); - let prev = inner.take().ok_or(SystemError::EINVAL)?; - match prev { - UdpInner::Unbound(unbound) => match unbound.bind(local_endpoint, self.netns()) { - Ok(bound) => { - bound - .inner() - .iface() - .common() - .bind_socket(self.self_ref.upgrade().unwrap()); - *inner = Some(UdpInner::Bound(bound)); - Ok(()) - } - Err(e) => { - // bind 消费了 unbound(move)。失败时恢复到一个新的 Unbound 状态, - // 关键是避免 inner 变成 None 导致后续 check_io_event panic。 - *inner = Some(UdpInner::Unbound(UnboundUdp::new())); - Err(e) - } - }, - other => { - // 非 Unbound 情况下保持原状态 - *inner = Some(other); - Err(SystemError::EINVAL) + + // Check socket state first without taking + match inner.as_ref() { + None => return Err(SystemError::EBADF), + Some(UdpInner::Bound(_)) => return Err(SystemError::EINVAL), // Already bound + Some(UdpInner::Unbound(_)) => {} + } + + // Now safe to take - we know it's Unbound + let _old_unbound = match inner.take() { + Some(UdpInner::Unbound(unbound)) => unbound, + _ => unreachable!(), + }; + + // Check if custom buffer sizes have been set via setsockopt + let rx_size = self.recv_buf_size.load(Ordering::Acquire); + let tx_size = self.send_buf_size.load(Ordering::Acquire); + + // log::debug!( + // "do_bind: rx_size={}, tx_size={}, will use custom buffers={}", + // rx_size, + // tx_size, + // rx_size > 0 || tx_size > 0 + // ); + + // Create new UnboundUdp with custom buffer sizes if they've been set + let unbound = if rx_size > 0 || tx_size > 0 { + // log::debug!( + // "do_bind: creating socket with custom buffer sizes rx={}, tx={}", + // rx_size, + // tx_size + // ); + UnboundUdp::new_with_buf_size(rx_size, tx_size) + } else { + // log::debug!("do_bind: creating socket with default buffer sizes"); + UnboundUdp::new() + }; + + match unbound.bind(local_endpoint, self.netns()) { + Ok(bound) => { + bound + .inner() + .iface() + .common() + .bind_socket(self.self_ref.upgrade().unwrap()); + *inner = Some(UdpInner::Bound(bound)); + Ok(()) + } + Err(e) => { + // Restore unbound state on error + *inner = Some(UdpInner::Unbound(UnboundUdp::new())); + Err(e) } } } - pub fn bind_emphemeral(&self, remote: smoltcp::wire::IpAddress) -> Result<(), SystemError> { + pub fn bind_ephemeral(&self, remote: smoltcp::wire::IpAddress) -> Result<(), SystemError> { let mut inner_guard = self.inner.write(); - let prev = inner_guard.take().ok_or(SystemError::EINVAL)?; - match prev { - UdpInner::Bound(bound) => { - inner_guard.replace(UdpInner::Bound(bound)); - Ok(()) + let inner = inner_guard.take().ok_or(SystemError::EBADF)?; + let bound = match inner { + UdpInner::Bound(inner) => inner, + UdpInner::Unbound(_old_inner) => { + // Check if custom buffer sizes have been set via setsockopt + let rx_size = self.recv_buf_size.load(Ordering::Acquire); + let tx_size = self.send_buf_size.load(Ordering::Acquire); + + // Create new UnboundUdp with custom buffer sizes if they've been set + let inner = if rx_size > 0 || tx_size > 0 { + UnboundUdp::new_with_buf_size(rx_size, tx_size) + } else { + UnboundUdp::new() + }; + + match inner.bind_ephemeral(remote, self.netns()) { + Ok(bound) => bound, + Err(e) => { + inner_guard.replace(UdpInner::Unbound(UnboundUdp::new())); + return Err(e); + } + } } - UdpInner::Unbound(unbound) => match unbound.bind_ephemeral(remote, self.netns()) { - Ok(bound) => { - inner_guard.replace(UdpInner::Bound(bound)); - Ok(()) - } - Err(e) => { - // bind_ephemeral 消费了 unbound(move)。失败则恢复到新的 Unbound 状态。 - inner_guard.replace(UdpInner::Unbound(UnboundUdp::new())); - Err(e) - } - }, - } + }; + inner_guard.replace(UdpInner::Bound(bound)); + Ok(()) } pub fn is_bound(&self) -> bool { @@ -113,6 +173,77 @@ impl UdpSocket { return false; } + /// Recreates the socket with new buffer sizes if it's already bound. + /// This is needed because smoltcp doesn't support resizing socket buffers dynamically. + fn recreate_socket_if_bound(&self) -> Result<(), SystemError> { + use smoltcp::wire::IpListenEndpoint; + + let mut inner_guard = self.inner.write(); + + // Check if socket is bound + let bound = match inner_guard.as_ref() { + Some(UdpInner::Bound(b)) => b, + _ => return Ok(()), // Not bound, nothing to do + }; + + // Save current state before recreating + let local_ep = bound.endpoint(); + let remote_ep = bound.remote_endpoint().ok(); // May be None if not connected + let _explicitly_bound = !bound.should_unbind_on_disconnect(); + + // log::debug!( + // "Recreating UDP socket: local={:?}, remote={:?}, explicit={}", + // local_ep, + // remote_ep, + // explicitly_bound + // ); + + // Get the local address and port + let IpListenEndpoint { addr, port } = local_ep; + let local_addr = addr.unwrap_or_else(|| smoltcp::wire::IpAddress::v4(0, 0, 0, 0)); + + // Unbind the old socket and drop it + if let Some(UdpInner::Bound(b)) = inner_guard.take() { + b.close(); + } + + // Create new UnboundUdp with new buffer sizes + let rx_size = self.recv_buf_size.load(Ordering::Acquire); + let tx_size = self.send_buf_size.load(Ordering::Acquire); + let unbound = if rx_size > 0 || tx_size > 0 { + UnboundUdp::new_with_buf_size(rx_size, tx_size) + } else { + UnboundUdp::new() + }; + + // Rebind to the same endpoint + let new_endpoint = smoltcp::wire::IpEndpoint::new(local_addr, port); + let bound = match unbound.bind(new_endpoint, self.netns()) { + Ok(b) => b, + Err(e) => { + // Restore unbound state on error + *inner_guard = Some(UdpInner::Unbound(UnboundUdp::new())); + return Err(e); + } + }; + + // Restore connection if it existed + if let Some(remote) = remote_ep { + bound.connect(remote); + } + + // Restore the binding in the interface + bound + .inner() + .iface() + .common() + .bind_socket(self.self_ref.upgrade().unwrap()); + + *inner_guard = Some(UdpInner::Bound(bound)); + + Ok(()) + } + pub fn close(&self) { let mut inner = self.inner.write(); if let Some(UdpInner::Bound(bound)) = &mut *inner { @@ -125,27 +256,34 @@ impl UdpSocket { pub fn try_recv( &self, buf: &mut [u8], + peek: bool, ) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> { - let guard = self.inner.read(); - match guard.as_ref() { - Some(UdpInner::Bound(bound)) => { - let ret = bound.try_recv(buf); - bound.inner().iface().poll(); - ret - } - _ => Err(SystemError::ENOTCONN), + match self.inner.read().as_ref().ok_or(SystemError::EBADF)? { + UdpInner::Bound(bound) => bound.try_recv(buf, peek), + // UDP is connectionless - unbound socket just has no data yet + UdpInner::Unbound(_) => Err(SystemError::EAGAIN_OR_EWOULDBLOCK), } } #[inline] pub fn can_recv(&self) -> bool { - self.check_io_event().contains(EP::EPOLLIN) + // Can receive if there's data available OR if read is shutdown + // (shutdown should wake up recv() to return 0/EOF) + let has_data = self.check_io_event().contains(EP::EPOLLIN); + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + let read_shutdown = (shutdown_bits & 0x01) != 0; + has_data || read_shutdown } #[inline] #[allow(dead_code)] pub fn can_send(&self) -> bool { - self.check_io_event().contains(EP::EPOLLOUT) + // Can send if socket is ready OR if write is shutdown + // (shutdown should wake up send() to return EPIPE) + let can_write = self.check_io_event().contains(EP::EPOLLOUT); + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + let write_shutdown = (shutdown_bits & 0x02) != 0; + can_write || write_shutdown } pub fn try_send( @@ -153,43 +291,48 @@ impl UdpSocket { buf: &[u8], to: Option, ) -> Result { - // 先确保 socket 处于 Bound 状态。任何错误路径都必须恢复 inner,避免变成 None。 - { + // Send data and get iface reference, then release lock before polling + let (result, iface) = { let mut inner_guard = self.inner.write(); - let prev = inner_guard.take().ok_or(SystemError::EINVAL)?; - match prev { - UdpInner::Bound(bound) => { - inner_guard.replace(UdpInner::Bound(bound)); - } - UdpInner::Unbound(unbound) => { - let Some(dest) = to.map(|ep| ep.addr) else { - // 必须恢复原状态,避免 inner=None。 - inner_guard.replace(UdpInner::Unbound(unbound)); - return Err(SystemError::EDESTADDRREQ); - }; - match unbound.bind_ephemeral(dest, self.netns()) { - Ok(bound) => { - inner_guard.replace(UdpInner::Bound(bound)); - } - Err(e) => { - // bind_ephemeral 消费了 unbound(move)。失败则恢复到新的 Unbound 状态。 - inner_guard.replace(UdpInner::Unbound(UnboundUdp::new())); - return Err(e); - } + + // Check if socket is closed + let inner = inner_guard.as_ref().ok_or(SystemError::EBADF)?; + + // If unbound, bind to ephemeral port + if let UdpInner::Unbound(_) = inner { + let to_addr = to.ok_or(SystemError::EDESTADDRREQ)?.addr; + let unbound = match inner_guard.take().unwrap() { + UdpInner::Unbound(unbound) => unbound, + _ => unreachable!(), + }; + match unbound.bind_ephemeral(to_addr, self.netns()) { + Ok(bound) => { + inner_guard.replace(UdpInner::Bound(bound)); + } + Err(e) => { + // Restore unbound state on error + inner_guard.replace(UdpInner::Unbound(UnboundUdp::new())); + return Err(e); } } } - } - // Optimize: 拿两次锁的平均效率是否比一次长时间的读锁效率要高? - let result = match self.inner.read().as_ref() { - Some(UdpInner::Bound(bound)) => { - let ret = bound.try_send(buf, to); - bound.inner().iface().poll(); - ret + + // Send data and get iface Arc before releasing lock + match inner_guard.as_ref().ok_or(SystemError::EBADF)? { + UdpInner::Bound(bound) => { + let ret = bound.try_send(buf, to); + let iface = bound.inner().iface().clone(); + (ret, iface) + } + _ => return Err(SystemError::ENOTCONN), } - _ => Err(SystemError::ENOTCONN), - }; - return result; + }; // Lock released here + + // Poll AFTER releasing the lock to avoid deadlock + // when socket sends to itself on loopback + iface.poll(); + + result } pub fn netns(&self) -> Arc { @@ -206,48 +349,162 @@ impl Socket for UdpSocket { &self.wait_queue } + fn set_nonblocking(&self, nonblocking: bool) { + self.nonblock + .store(nonblocking, core::sync::atomic::Ordering::Relaxed); + } + fn bind(&self, local_endpoint: Endpoint) -> Result<(), SystemError> { - if let Endpoint::Ip(local_endpoint) = local_endpoint { - return self.do_bind(local_endpoint); + match local_endpoint { + Endpoint::Ip(local_endpoint) => self.do_bind(local_endpoint), + Endpoint::Unspecified => { + // AF_UNSPEC on bind() is a no-op for AF_INET sockets (Linux compatibility) + // See: https://github.com/torvalds/linux/commit/29c486df6a208432b370bd4be99ae1369ede28d8 + // log::debug!("UDP bind: AF_UNSPEC treated as no-op for compatibility"); + Ok(()) + } + _ => Err(SystemError::EAFNOSUPPORT), } - Err(SystemError::EAFNOSUPPORT) } fn send_buffer_size(&self) -> usize { + // Check if custom buffer size was set via setsockopt + let custom_size = self.send_buf_size.load(Ordering::Acquire); + if custom_size > 0 { + // Linux doubles the value when returning via getsockopt + return custom_size * 2; + } + + // Otherwise return actual buffer capacity match self.inner.read().as_ref() { Some(UdpInner::Bound(bound)) => { bound.with_socket(|socket| socket.payload_send_capacity()) } - _ => inner::DEFAULT_TX_BUF_SIZE, + _ => inner::DEFAULT_TX_BUF_SIZE * 2, // Linux doubles default too } } fn recv_buffer_size(&self) -> usize { - match self.inner.read().as_ref() { + // Check if custom buffer size was set via setsockopt + let custom_size = self.recv_buf_size.load(Ordering::Acquire); + if custom_size > 0 { + // Linux doubles the value when returning via getsockopt + // log::debug!( + // "recv_buffer_size: custom_size={}, returning={}", + // custom_size, + // custom_size * 2 + // ); + return custom_size * 2; + } + + // Otherwise return actual buffer capacity + let size = match self.inner.read().as_ref() { Some(UdpInner::Bound(bound)) => { bound.with_socket(|socket| socket.payload_recv_capacity()) } - _ => inner::DEFAULT_RX_BUF_SIZE, - } + _ => inner::DEFAULT_RX_BUF_SIZE * 2, // Linux doubles default too + }; + // log::debug!("recv_buffer_size: no custom size, returning={}", size); + size + } + + fn recv_bytes_available(&self) -> Result { + Ok(match self.inner.read().as_ref() { + Some(UdpInner::Bound(bound)) => { + // For UDP, FIONREAD should return the size of the first packet, + // not the total bytes in the queue + bound.with_mut_socket(|socket| { + match socket.peek() { + Ok((payload, _)) => payload.len(), + Err(_) => 0, // No packets available + } + }) + } + _ => 0, + }) } fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { - if let Endpoint::Ip(remote) = endpoint { - if !self.is_bound() { - self.bind_emphemeral(remote.addr)?; + match endpoint { + Endpoint::Ip(remote) => { + // Port 0 is treated as disconnect (like AF_UNSPEC) + // This matches Linux behavior where connect() to port 0 succeeds but disconnects the socket + if remote.port == 0 { + // log::debug!("UDP connect: port 0 treated as disconnect"); + // Disconnect logic - same as AF_UNSPEC case + let should_unbind = { + match self.inner.read().as_ref() { + Some(UdpInner::Bound(inner)) => { + inner.disconnect(); + inner.should_unbind_on_disconnect() + } + Some(UdpInner::Unbound(_)) => return Ok(()), // Already disconnected + None => return Err(SystemError::EBADF), + } + }; + + if should_unbind { + // Socket was implicitly bound by connect, unbind it + let mut inner_guard = self.inner.write(); + if let Some(UdpInner::Bound(bound)) = inner_guard.take() { + bound.close(); + inner_guard.replace(UdpInner::Unbound(UnboundUdp::new())); + } + } + return Ok(()); + } + + if !self.is_bound() { + self.bind_ephemeral(remote.addr)?; + } + match self.inner.read().as_ref() { + Some(UdpInner::Bound(inner)) => { + inner.connect(remote); + Ok(()) + } + Some(_) => Err(SystemError::ENOTCONN), + None => Err(SystemError::EBADF), + } } - if let UdpInner::Bound(inner) = self.inner.read().as_ref().expect("UDP Inner disappear") - { - inner.connect(remote); - return Ok(()); - } else { - panic!(""); + Endpoint::Unspecified => { + // AF_UNSPEC: disconnect the UDP socket (clear remote endpoint) + // If socket was implicitly bound (by connect), unbind it + let should_unbind = { + match self.inner.read().as_ref() { + Some(UdpInner::Bound(inner)) => { + inner.disconnect(); + inner.should_unbind_on_disconnect() + } + Some(UdpInner::Unbound(_)) => return Ok(()), // Already disconnected + None => return Err(SystemError::EBADF), + } + }; + + if should_unbind { + // Socket was implicitly bound by connect, unbind it + let mut inner_guard = self.inner.write(); + if let Some(UdpInner::Bound(bound)) = inner_guard.take() { + bound.close(); + inner_guard.replace(UdpInner::Unbound(UnboundUdp::new())); + } + } + Ok(()) } + _ => Err(SystemError::EAFNOSUPPORT), } - return Err(SystemError::EAFNOSUPPORT); } fn send(&self, buffer: &[u8], flags: PMSG) -> Result { + if buffer.is_empty() { + log::debug!("UDP send() called with ZERO-LENGTH buffer"); + } + + // Check if write is shutdown (0x02 = SEND_SHUTDOWN) + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + if shutdown_bits & 0x02 != 0 { + return Err(SystemError::EPIPE); + } + if flags.contains(PMSG::DONTWAIT) { log::warn!("Nonblock send is not implemented yet"); } @@ -256,6 +513,12 @@ impl Socket for UdpSocket { } fn send_to(&self, buffer: &[u8], flags: PMSG, address: Endpoint) -> Result { + // Check if write is shutdown (0x02 = SEND_SHUTDOWN) + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + if shutdown_bits & 0x02 != 0 { + return Err(SystemError::EPIPE); + } + if flags.contains(PMSG::DONTWAIT) { log::warn!("Nonblock send is not implemented yet"); } @@ -268,46 +531,87 @@ impl Socket for UdpSocket { } fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { - return if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { - self.try_recv(buffer) + // Check if read is shutdown + // Linux allows reading buffered data even after SHUT_RD, only returns EOF when buffer is empty + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + let is_recv_shutdown = shutdown_bits & 0x01 != 0; + + let peek = flags.contains(PMSG::PEEK); + + if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { + let result = self.try_recv(buffer, peek); + // If shutdown and no data available, return EOF instead of EWOULDBLOCK + if is_recv_shutdown && matches!(result, Err(SystemError::EAGAIN_OR_EWOULDBLOCK)) { + return Ok(0); + } + return result.map(|(len, _)| len); } else { loop { - match self.try_recv(buffer) { + // Re-check shutdown state inside the loop + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + let is_recv_shutdown = shutdown_bits & 0x01 != 0; + + match self.try_recv(buffer, peek) { + Ok((len, _endpoint)) => { + return Ok(len); + } Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + // If shutdown and no data available, return EOF + if is_recv_shutdown { + return Ok(0); + } wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; } - result => break result, + Err(e) => return Err(e), } } } - .map(|(len, _)| len); } fn recv_from( &self, buffer: &mut [u8], flags: PMSG, - address: Option, + _address: Option, ) -> Result<(usize, Endpoint), SystemError> { - // could block io - if let Some(endpoint) = address { - self.connect(endpoint)?; - } + // Linux allows reading buffered data even after SHUT_RD + // For blocking mode, check shutdown state in the loop + + let peek = flags.contains(PMSG::PEEK); return if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { - self.try_recv(buffer) + let result = self.try_recv(buffer, peek); + // For non-blocking sockets, always return EAGAIN when no data + // Even after shutdown, don't convert to EOF + result.map(|(len, endpoint)| (len, Endpoint::Ip(endpoint))) } else { loop { - match self.try_recv(buffer) { + // Re-check shutdown state inside the loop + let shutdown_bits = self.shutdown.load(Ordering::Acquire); + let is_recv_shutdown = shutdown_bits & 0x01 != 0; + + match self.try_recv(buffer, peek) { + Ok((len, endpoint)) => { + return Ok((len, Endpoint::Ip(endpoint))); + } Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + // If shutdown and no data available, return EOF + if is_recv_shutdown { + // If connected, return EOF with remote endpoint + if let Some(UdpInner::Bound(bound)) = self.inner.read().as_ref() { + if let Ok(remote) = bound.remote_endpoint() { + return Ok((0, Endpoint::Ip(remote))); + } + } + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; // log::debug!("UdpSocket::recv_from: wake up"); } - result => break result, + Err(e) => return Err(e), } } - } - .map(|(len, remote)| (len, Endpoint::Ip(remote))); + }; } fn do_close(&self) -> Result<(), SystemError> { @@ -315,11 +619,186 @@ impl Socket for UdpSocket { Ok(()) } + fn shutdown(&self, how: ShutdownBit) -> Result<(), SystemError> { + // For UDP, shutdown requires the socket to be connected (both SHUT_RD and SHUT_WR) + // Check if socket is connected + match self.inner.read().as_ref() { + Some(UdpInner::Bound(bound)) => { + if bound.remote_endpoint().is_err() { + return Err(SystemError::ENOTCONN); + } + } + Some(UdpInner::Unbound(_)) => { + return Err(SystemError::ENOTCONN); + } + None => return Err(SystemError::EBADF), + } + + // Set the shutdown bits atomically + // Use fetch_or to set the bits we want + let _old = self.shutdown.fetch_or( + (if how.is_recv_shutdown() { 0x01 } else { 0 }) + | (if how.is_send_shutdown() { 0x02 } else { 0 }), + Ordering::Release, + ); + + // log::debug!( + // "UDP shutdown: old={:#x}, recv={}, send={}", + // _old, + // how.is_recv_shutdown(), + // how.is_send_shutdown() + // ); + + // Wake up any threads blocked in recv() or send() so they can check the shutdown state + self.wait_queue.wakeup_all(None); + + Ok(()) + } + + fn set_option(&self, level: PSOL, name: usize, val: &[u8]) -> Result<(), SystemError> { + if level == PSOL::SOCKET { + let opt = PSO::try_from(name as u32).map_err(|_| SystemError::ENOPROTOOPT)?; + match opt { + PSO::SNDBUF => { + // Set send buffer size + if val.len() < core::mem::size_of::() { + return Err(SystemError::EINVAL); + } + let requested = u32::from_ne_bytes([val[0], val[1], val[2], val[3]]) as usize; + // Enforce minimum buffer size + let size = if requested < MIN_BUF_SIZE { + MIN_BUF_SIZE + } else { + requested + }; + self.send_buf_size.store(size, Ordering::Release); + // log::debug!( + // "UDP setsockopt SO_SNDBUF: requested={}, actual={}", + // requested, + // size + // ); + + // If socket is already bound, we need to recreate it with new buffer size + self.recreate_socket_if_bound()?; + return Ok(()); + } + PSO::RCVBUF => { + // Set receive buffer size + if val.len() < core::mem::size_of::() { + return Err(SystemError::EINVAL); + } + let requested = u32::from_ne_bytes([val[0], val[1], val[2], val[3]]) as usize; + // Enforce minimum buffer size + let size = if requested < MIN_BUF_SIZE { + MIN_BUF_SIZE + } else { + requested + }; + self.recv_buf_size.store(size, Ordering::Release); + // log::debug!( + // "UDP setsockopt SO_RCVBUF: requested={}, actual={}", + // requested, + // size + // ); + + // If socket is already bound, we need to recreate it with new buffer size + self.recreate_socket_if_bound()?; + return Ok(()); + } + PSO::NO_CHECK => { + // Set SO_NO_CHECK: disable/enable UDP checksum verification + // NOTE: This is a stub implementation - see field comment for details. + // The value is stored but does not affect actual checksum behavior. + if val.len() < core::mem::size_of::() { + return Err(SystemError::EINVAL); + } + let value = i32::from_ne_bytes([val[0], val[1], val[2], val[3]]); + self.no_check.store(value != 0, Ordering::Release); + // log::debug!( + // "UDP setsockopt SO_NO_CHECK: {} (stub - no actual effect)", + // value != 0 + // ); + return Ok(()); + } + _ => { + return Err(SystemError::ENOPROTOOPT); + } + } + } + Err(SystemError::ENOPROTOOPT) + } + + fn option(&self, level: PSOL, name: usize, value: &mut [u8]) -> Result { + // log::debug!( + // "UDP getsockopt called: level={:?}, name={}, value_len={}", + // level, + // name, + // value.len() + // ); + if level == PSOL::SOCKET { + let opt = PSO::try_from(name as u32).map_err(|_| SystemError::ENOPROTOOPT)?; + // log::debug!("UDP getsockopt: parsed option {:?}", opt); + match opt { + PSO::SNDBUF => { + if value.len() < core::mem::size_of::() { + return Err(SystemError::EINVAL); + } + let size = self.send_buf_size.load(Ordering::Acquire); + // Linux doubles the value when returning it + // If 0 (not set), return default size + let actual_size = if size == 0 { + DEFAULT_TX_BUF_SIZE * 2 + } else { + size * 2 + }; + let bytes = (actual_size as u32).to_ne_bytes(); + value[0..4].copy_from_slice(&bytes); + return Ok(core::mem::size_of::()); + } + PSO::RCVBUF => { + if value.len() < core::mem::size_of::() { + return Err(SystemError::EINVAL); + } + let size = self.recv_buf_size.load(Ordering::Acquire); + // Linux doubles the value when returning it + // If 0 (not set), return default size + let actual_size = if size == 0 { + DEFAULT_RX_BUF_SIZE * 2 + } else { + size * 2 + }; + // log::debug!( + // "UDP getsockopt SO_RCVBUF: size={}, returning={}", + // size, + // actual_size + // ); + let bytes = (actual_size as u32).to_ne_bytes(); + value[0..4].copy_from_slice(&bytes); + return Ok(core::mem::size_of::()); + } + PSO::NO_CHECK => { + if value.len() < core::mem::size_of::() { + return Err(SystemError::EINVAL); + } + let no_check = self.no_check.load(Ordering::Acquire); + let val = if no_check { 1i32 } else { 0i32 }; + let bytes = val.to_ne_bytes(); + value[0..4].copy_from_slice(&bytes); + return Ok(core::mem::size_of::()); + } + _ => { + return Err(SystemError::ENOPROTOOPT); + } + } + } + Err(SystemError::ENOPROTOOPT) + } + fn remote_endpoint(&self) -> Result { match self.inner.read().as_ref() { Some(UdpInner::Bound(bound)) => Ok(Endpoint::Ip(bound.remote_endpoint()?)), - // TODO: IPv6 support - _ => Err(SystemError::ENOTCONN), + Some(_) => Err(SystemError::ENOTCONN), + None => Err(SystemError::EBADF), } } @@ -328,30 +807,123 @@ impl Socket for UdpSocket { match self.inner.read().as_ref() { Some(UdpInner::Bound(bound)) => { let IpListenEndpoint { addr, port } = bound.endpoint(); - Ok(Endpoint::Ip(IpEndpoint::new( - addr.unwrap_or(Ipv4([0, 0, 0, 0].into())), - port, - ))) + + // If bound to "any" address (0.0.0.0 or ::), but connected to a specific address, + // return the actual local address that would be used for the connection + let local_addr = if let Some(addr) = addr { + addr + } else { + // Socket is bound to ANY - check if connected + if let Ok(remote) = bound.remote_endpoint() { + // Connected: return the local address for the interface that can reach the remote + // For loopback, return loopback address; otherwise get from interface + match remote.addr { + Ipv4(addr) if addr.is_loopback() => Ipv4(addr), + Ipv6(addr) if addr.is_loopback() => Ipv6(addr), + _ => { + // Get the first IP address from the interface + let iface_guard = bound.inner().iface().smol_iface().lock(); + if let Some(cidr) = iface_guard.ip_addrs().first() { + cidr.address() + } else { + Ipv4([0, 0, 0, 0].into()) + } + } + } + } else { + // Not connected, return "any" + Ipv4([0, 0, 0, 0].into()) + } + }; + + Ok(Endpoint::Ip(IpEndpoint::new(local_addr, port))) } - // TODO: IPv6 support - _ => Ok(Endpoint::Ip(IpEndpoint::new(Ipv4([0, 0, 0, 0].into()), 0))), + Some(_) => Ok(Endpoint::Ip(IpEndpoint::new(Ipv4([0, 0, 0, 0].into()), 0))), + None => Err(SystemError::EBADF), } } fn recv_msg( &self, - _msg: &mut crate::net::posix::MsgHdr, - _flags: PMSG, + msg: &mut crate::net::posix::MsgHdr, + flags: PMSG, ) -> Result { - todo!() + use crate::filesystem::vfs::iov::IoVecs; + + // log::debug!( + // "recv_msg: msg_name={:?}, msg_namelen={}, flags={:?}", + // msg.msg_name, + // msg.msg_namelen, + // flags + // ); + + // Check for MSG_ERRQUEUE - we don't support error queues yet + if flags.contains(PMSG::ERRQUEUE) { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + + // Validate and create iovecs + let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? }; + let mut buf = iovs.new_buf(true); + + // log::debug!("recv_msg: created buffer of {} bytes", buf.len()); + + // Receive data from socket + let (recv_size, src_endpoint) = self.recv_from(&mut buf, flags, None)?; + + // log::debug!( + // "recv_msg: received {} bytes from {:?}", + // recv_size, + // src_endpoint + // ); + + // Scatter received data to user iovecs + iovs.scatter(&buf[..recv_size])?; + + // Write source address if requested + if !msg.msg_name.is_null() { + let src_addr = msg.msg_name; + // log::debug!( + // "recv_msg: writing endpoint to user, msg_namelen={}", + // msg.msg_namelen + // ); + let actual_len = src_endpoint.write_to_user_msghdr(src_addr, msg.msg_namelen)?; + msg.msg_namelen = actual_len; + // log::debug!( + // "recv_msg: endpoint written, updated msg_namelen={}", + // msg.msg_namelen + // ); + } else { + // log::debug!("recv_msg: msg_name is NULL, skipping endpoint write"); + msg.msg_namelen = 0; + } + + // No control messages for now + msg.msg_controllen = 0; + msg.msg_flags = 0; + + // log::debug!("recv_msg: returning {} bytes", recv_size); + Ok(recv_size) } - fn send_msg( - &self, - _msg: &crate::net::posix::MsgHdr, - _flags: PMSG, - ) -> Result { - todo!() + fn send_msg(&self, msg: &crate::net::posix::MsgHdr, flags: PMSG) -> Result { + use crate::filesystem::vfs::iov::IoVecs; + use crate::net::posix::SockAddr; + + // Validate and gather iovecs + // TODO: Actual iovecs sends + let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, false)? }; + let data = iovs.gather()?; + + // Check if destination address is provided + if !msg.msg_name.is_null() && msg.msg_namelen > 0 { + // Send to specific address + let endpoint = SockAddr::to_endpoint(msg.msg_name as *const SockAddr, msg.msg_namelen)?; + self.send_to(&data, flags, endpoint) + } else { + // Send using connected endpoint + self.send(&data, flags) + } } fn epoll_items(&self) -> &crate::net::socket::common::EPollItems { @@ -365,7 +937,7 @@ impl Socket for UdpSocket { fn check_io_event(&self) -> EPollEventType { let mut event = EPollEventType::empty(); match self.inner.read().as_ref() { - None | Some(UdpInner::Unbound(_)) => { + Some(UdpInner::Unbound(_)) => { event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); } Some(UdpInner::Bound(bound)) => { @@ -380,18 +952,37 @@ impl Socket for UdpSocket { event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); } } + None => { + // Socket is closed + event.insert(EP::EPOLLERR | EP::EPOLLHUP); + } } - return event; + event } fn socket_inode_id(&self) -> InodeId { self.inode_id } + + fn send_bytes_available(&self) -> Result { + Ok(match self.inner.read().as_ref() { + Some(UdpInner::Bound(bound)) => { + bound.with_socket(|socket| socket.payload_send_capacity() - socket.send_queue()) + } + _ => 0, + }) + } } impl InetSocket for UdpSocket { fn on_iface_events(&self) { - return; + // Wake up any threads waiting on this socket + self.wait_queue.wakeup_all(None); + + // Notify epoll/poll watchers about socket state changes + let pollflag = self.check_io_event(); + use crate::filesystem::epoll::event_poll::EventPoll; + let _ = EventPoll::wakeup_epoll(self.epoll_items().as_ref(), pollflag); } } diff --git a/kernel/src/net/socket/inet/posix/option.rs b/kernel/src/net/socket/inet/posix/option.rs deleted file mode 100644 index 5c5947dd0..000000000 --- a/kernel/src/net/socket/inet/posix/option.rs +++ /dev/null @@ -1,68 +0,0 @@ - -bitflags! { - pub struct IpOptions: u32 { - const IP_TOS = 1; // Type of service - const IP_TTL = 2; // Time to live - const IP_HDRINCL = 3; // Header compression - const IP_OPTIONS = 4; // IP options - const IP_ROUTER_ALERT = 5; // Router alert - const IP_RECVOPTS = 6; // Receive options - const IP_RETOPTS = 7; // Return options - const IP_PKTINFO = 8; // Packet information - const IP_PKTOPTIONS = 9; // Packet options - const IP_MTU_DISCOVER = 10; // MTU discovery - const IP_RECVERR = 11; // Receive errors - const IP_RECVTTL = 12; // Receive time to live - const IP_RECVTOS = 13; // Receive type of service - const IP_MTU = 14; // MTU - const IP_FREEBIND = 15; // Freebind - const IP_IPSEC_POLICY = 16; // IPsec policy - const IP_XFRM_POLICY = 17; // IPipsec transform policy - const IP_PASSSEC = 18; // Pass security - const IP_TRANSPARENT = 19; // Transparent - - const IP_RECVRETOPTS = 20; // Receive return options (deprecated) - - const IP_ORIGDSTADDR = 21; // Originate destination address (used by TProxy) - const IP_RECVORIGDSTADDR = 21; // Receive originate destination address - - const IP_MINTTL = 22; // Minimum time to live - const IP_NODEFRAG = 23; // Don't fragment (used by TProxy) - const IP_CHECKSUM = 24; // Checksum offload (used by TProxy) - const IP_BIND_ADDRESS_NO_PORT = 25; // Bind to address without port (used by TProxy) - const IP_RECVFRAGSIZE = 26; // Receive fragment size - const IP_RECVERR_RFC4884 = 27; // Receive ICMPv6 error notifications - - const IP_PMTUDISC_DONT = 28; // Don't send DF frames - const IP_PMTUDISC_DO = 29; // Always DF - const IP_PMTUDISC_PROBE = 30; // Ignore dst pmtu - const IP_PMTUDISC_INTERFACE = 31; // Always use interface mtu (ignores dst pmtu) - const IP_PMTUDISC_OMIT = 32; // Weaker version of IP_PMTUDISC_INTERFACE - - const IP_MULTICAST_IF = 33; // Multicast interface - const IP_MULTICAST_TTL = 34; // Multicast time to live - const IP_MULTICAST_LOOP = 35; // Multicast loopback - const IP_ADD_MEMBERSHIP = 36; // Add multicast group membership - const IP_DROP_MEMBERSHIP = 37; // Drop multicast group membership - const IP_UNBLOCK_SOURCE = 38; // Unblock source - const IP_BLOCK_SOURCE = 39; // Block source - const IP_ADD_SOURCE_MEMBERSHIP = 40; // Add source multicast group membership - const IP_DROP_SOURCE_MEMBERSHIP = 41; // Drop source multicast group membership - const IP_MSFILTER = 42; // Multicast source filter - - const MCAST_JOIN_GROUP = 43; // Join a multicast group - const MCAST_BLOCK_SOURCE = 44; // Block a multicast source - const MCAST_UNBLOCK_SOURCE = 45; // Unblock a multicast source - const MCAST_LEAVE_GROUP = 46; // Leave a multicast group - const MCAST_JOIN_SOURCE_GROUP = 47; // Join a multicast source group - const MCAST_LEAVE_SOURCE_GROUP = 48; // Leave a multicast source group - const MCAST_MSFILTER = 49; // Multicast source filter - - const IP_MULTICAST_ALL = 50; // Multicast all - const IP_UNICAST_IF = 51; // Unicast interface - const IP_LOCAL_PORT_RANGE = 52; // Local port range - const IP_PROTOCOL = 53; // Protocol - - // ... other flags ... - } -} \ No newline at end of file diff --git a/kernel/src/net/socket/inet/posix/proto.rs b/kernel/src/net/socket/inet/posix/proto.rs deleted file mode 100644 index 39818f658..000000000 --- a/kernel/src/net/socket/inet/posix/proto.rs +++ /dev/null @@ -1,76 +0,0 @@ -pub const SOL_SOCKET: u16 = 1; - -#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)] -pub enum IPProtocol { - /// Dummy protocol for TCP. - IP = 0, - /// Internet Control Message Protocol. - ICMP = 1, - /// Internet Group Management Protocol. - IGMP = 2, - /// IPIP tunnels (older KA9Q tunnels use 94). - IPIP = 4, - /// Transmission Control Protocol. - TCP = 6, - /// Exterior Gateway Protocol. - EGP = 8, - /// PUP protocol. - PUP = 12, - /// User Datagram Protocol. - UDP = 17, - /// XNS IDP protocol. - IDP = 22, - /// SO Transport Protocol Class 4. - TP = 29, - /// Datagram Congestion Control Protocol. - DCCP = 33, - /// IPv6-in-IPv4 tunnelling. - IPv6 = 41, - /// RSVP Protocol. - RSVP = 46, - /// Generic Routing Encapsulation. (Cisco GRE) (rfc 1701, 1702) - GRE = 47, - /// Encapsulation Security Payload protocol - ESP = 50, - /// Authentication Header protocol - AH = 51, - /// Multicast Transport Protocol. - MTP = 92, - /// IP option pseudo header for BEET - BEETPH = 94, - /// Encapsulation Header. - ENCAP = 98, - /// Protocol Independent Multicast. - PIM = 103, - /// Compression Header Protocol. - COMP = 108, - /// Stream Control Transport Protocol - SCTP = 132, - /// UDP-Lite protocol (RFC 3828) - UDPLITE = 136, - /// MPLS in IP (RFC 4023) - MPLSINIP = 137, - /// Ethernet-within-IPv6 Encapsulation - ETHERNET = 143, - /// Raw IP packets - RAW = 255, - /// Multipath TCP connection - MPTCP = 262, -} - -impl TryFrom for IPProtocol { - type Error = system_error::SystemError; - - fn try_from(value: u16) -> Result { - match ::from_u16(value) { - Some(p) => Ok(p), - None => Err(system_error::SystemError::EPROTONOSUPPORT), - } - } -} - -impl From for u16 { - fn from(value: IPProtocol) -> Self { - ::to_u16(&value).unwrap() - } -} diff --git a/kernel/src/net/socket/inet/raw/ops.rs b/kernel/src/net/socket/inet/raw/ops.rs index d2f90c198..dad526d9d 100644 --- a/kernel/src/net/socket/inet/raw/ops.rs +++ b/kernel/src/net/socket/inet/raw/ops.rs @@ -237,6 +237,38 @@ impl crate::net::socket::Socket for RawSocket { _ => Err(SystemError::ENOPROTOOPT), } } + + fn recv_bytes_available(&self) -> Result { + let guard = self.inner.read(); + Ok(match *guard { + Some(RawInner::Wildcard(ref bound)) => { + bound.with_mut_socket(|socket| match socket.peek() { + Ok(payload) => payload.len(), + Err(_) => 0, + }) + } + Some(RawInner::Bound(ref bound)) => { + bound.with_mut_socket(|socket| match socket.peek() { + Ok(payload) => payload.len(), + Err(_) => 0, + }) + } + _ => 0, + }) + } + + fn send_bytes_available(&self) -> Result { + let guard = self.inner.read(); + Ok(match *guard { + Some(RawInner::Wildcard(ref bound)) => { + bound.with_socket(|socket| socket.payload_send_capacity() - socket.send_queue()) + } + Some(RawInner::Bound(ref bound)) => { + bound.with_socket(|socket| socket.payload_send_capacity() - socket.send_queue()) + } + _ => 0, + }) + } } impl InetSocket for RawSocket { diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs index bc832981a..843593587 100644 --- a/kernel/src/net/socket/inet/stream/inner.rs +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -21,13 +21,18 @@ fn new_smoltcp_socket() -> smoltcp::socket::tcp::Socket<'static> { smoltcp::socket::tcp::Socket::new(rx_buffer, tx_buffer) } -fn new_listen_smoltcp_socket(local_endpoint: T) -> smoltcp::socket::tcp::Socket<'static> +fn new_listen_smoltcp_socket( + local_endpoint: T, +) -> Result, SystemError> where T: Into, { let mut socket = new_smoltcp_socket(); - socket.listen(local_endpoint).unwrap(); - socket + socket.listen(local_endpoint).map_err(|e| match e { + tcp::ListenError::InvalidState => SystemError::EINVAL, // TODO: Check is right impl + tcp::ListenError::Unaddressable => SystemError::EADDRINUSE, + })?; + Ok(socket) } #[derive(Debug)] @@ -64,11 +69,20 @@ impl Init { match self { Init::Unbound((socket, _)) => { let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr, netns)?; - bound - .port_manager() - .bind_port(Types::Tcp, local_endpoint.port)?; - // bound.iface().common().bind_socket() - Ok(Init::Bound((bound, local_endpoint))) + + // Handle ephemeral port assignment (port 0) + let bind_port = if local_endpoint.port == 0 { + bound.port_manager().bind_ephemeral_port(Types::Tcp)? + } else { + bound + .port_manager() + .bind_port(Types::Tcp, local_endpoint.port)?; + local_endpoint.port + }; + + // Create endpoint with actual assigned port + let final_endpoint = smoltcp::wire::IpEndpoint::new(local_endpoint.addr, bind_port); + Ok(Init::Bound((bound, final_endpoint))) } Init::Bound(_) => { log::debug!("Already Bound"); @@ -138,13 +152,26 @@ impl Init { } else { smoltcp::wire::IpListenEndpoint::from(local) }; - log::debug!("listen at {:?}", listen_addr); + if listen_addr.port == 0 { + // Invalid port number + return Err((Init::Bound((inner, local)), SystemError::EINVAL)); + } + // log::debug!("listen at {:?}, backlog {}", listen_addr, backlog); + if backlog == 0 || backlog > u16::MAX as usize { + // Invalid backlog value + return Err((Init::Bound((inner, local)), SystemError::EINVAL)); + } + + // FIXME: need refactor backlog mechanism for large number of backlog + let backlog = if backlog > 8 { 8 } else { backlog }; + let mut inners = Vec::new(); if let Err(err) = || -> Result<(), SystemError> { - for _ in 0..(backlog - 1) { + for _i in 0..(backlog - 1) { // -1 because the first one is already bound + // log::debug!("loop {:?}", _i); let new_listen = socket::inet::BoundInner::bind( - new_listen_smoltcp_socket(listen_addr), + new_listen_smoltcp_socket(listen_addr)?, listen_addr .addr .as_ref() @@ -155,6 +182,7 @@ impl Init { )?; inners.push(new_listen); } + // log::debug!("finished listen"); Ok(()) }() { return Err((Init::Bound((inner, local)), err)); @@ -320,7 +348,7 @@ impl Listening { // log::debug!("local at {:?}", local_endpoint); let mut new_listen = socket::inet::BoundInner::bind( - new_listen_smoltcp_socket(self.listen_addr), + new_listen_smoltcp_socket(self.listen_addr)?, self.listen_addr .addr .as_ref() diff --git a/kernel/src/net/socket/inode.rs b/kernel/src/net/socket/inode.rs index f7ffe2cb9..e3c7a6062 100644 --- a/kernel/src/net/socket/inode.rs +++ b/kernel/src/net/socket/inode.rs @@ -19,6 +19,8 @@ use crate::net::socket::IFNAMSIZ; // Socket ioctl commands const SIOCGIFCONF: u32 = 0x8912; // Get interface list const SIOCGIFINDEX: u32 = 0x8933; // name -> if_index mapping +const FIONREAD: u32 = 0x541B; // Get number of bytes available to read +const TIOCOUTQ: u32 = 0x5411; // Get output queue size /// ## ifreq - Interface request structure /// Used for socket ioctls. Must match C struct layout. @@ -325,6 +327,13 @@ impl IndexNode for T { buf: &[u8], data: SpinLockGuard, ) -> Result { + if buf.is_empty() { + log::debug!( + "Socket write_at: ZERO-LENGTH write, buf.len()={}, _len={}", + buf.len(), + _len + ); + } drop(data); self.write(buf) } @@ -356,17 +365,47 @@ impl IndexNode for T { Ok(md) } - // TODO: implement ioctl for socket + /// 这里应该实现 通用 Socket 作为 IndexNode 的 ioctl 选项 + /// 对于协议特定的 ioctl 选项实现,请在各个 Socket impl trait 内实现 + /// + /// ## 层级结构 + /// + /// `dyn IndexNode::ioctl` -> `impl IndexNode for T: Socket` -> `dyn Socket::ioctl` + /// + /// Socket trait 的 ioctl 覆盖了 IndexNode 这一层的调用,但由于 `impl IndexNode for T: Socket`, + /// 我们先调用在 IndexNode 这一层为 Socket 默认实现的 ioctl,再调用 `Socket` trait 内 + /// 的 ioctl fn ioctl( &self, cmd: u32, data: usize, - private_data: &FilePrivateData, + _private_data: &FilePrivateData, ) -> Result { match cmd { SIOCGIFCONF => handle_siocgifconf(data), SIOCGIFINDEX => handle_siocgifindex(data), - _ => Socket::ioctl(self, cmd, data, private_data), + FIONREAD /* TIOCINQ */ => { + // Get number of bytes available to read + let bytes_available = self.recv_bytes_available()?; + let mut writer = + UserBufferWriter::new(data as *mut u8, core::mem::size_of::(), true)?; + let to_write = core::cmp::min(bytes_available, i32::MAX as usize) as i32; + writer.buffer_protected(0)?.write_one::(0, &to_write)?; + Ok(0) + } + TIOCOUTQ => { + // Get number of bytes available to write + let bytes_available = self.send_bytes_available()?; + let mut writer = + UserBufferWriter::new(data as *mut u8, core::mem::size_of::(), true)?; + let to_write = core::cmp::min(bytes_available, i32::MAX as usize) as i32; + writer.buffer_protected(0)?.write_one::(0, &to_write)?; + Ok(0) + } + _ => { + // 透穿调用子协议栈的ioctl + Socket::ioctl(self, cmd, data, _private_data) + } } } diff --git a/kernel/src/net/socket/unix/datagram/mod.rs b/kernel/src/net/socket/unix/datagram/mod.rs index 64f9dd4f3..ff29b18b6 100644 --- a/kernel/src/net/socket/unix/datagram/mod.rs +++ b/kernel/src/net/socket/unix/datagram/mod.rs @@ -1,9 +1,7 @@ use crate::{ filesystem::epoll::{event_poll::EventPoll, EPollEventType}, filesystem::vfs::iov::IoVecs, - filesystem::vfs::{ - fasync::FAsyncItems, utils::DName, vcore::generate_inode_id, FilePrivateData, InodeId, - }, + filesystem::vfs::{fasync::FAsyncItems, utils::DName, vcore::generate_inode_id, InodeId}, libs::rwlock::RwLock, libs::spinlock::SpinLock, libs::wait_queue::WaitQueue, @@ -30,7 +28,6 @@ use system_error::SystemError; use crate::time::Duration; use crate::process::namespace::net_namespace::NetNamespace; -use crate::syscall::user_access::UserBufferWriter; use crate::{ filesystem::vfs::file::File, net::socket::unix::{current_ucred, nobody_ucred, UCred}, @@ -38,15 +35,6 @@ use crate::{ syscall::user_access::UserBufferReader, }; -// Socket ioctls used by gVisor unix socket tests. -const TIOCOUTQ: u32 = 0x5411; // Get output queue size -const FIONREAD: u32 = 0x541B; // Get input queue size (aka TIOCINQ) -const SIOCGIFINDEX: u32 = 0x8933; // name -> if_index mapping - -fn clamp_usize_to_i32(v: usize) -> i32 { - core::cmp::min(v, i32::MAX as usize) as i32 -} - // Use common ancillary message types from parent module use super::{cmsg_align, CmsgBuffer, Cmsghdr, MSG_CTRUNC, SCM_CREDENTIALS, SCM_RIGHTS, SOL_SOCKET}; @@ -682,38 +670,12 @@ impl Socket for UnixDatagramSocket { &self.open_files } - fn ioctl( - &self, - cmd: u32, - arg: usize, - _private_data: &FilePrivateData, - ) -> Result { - if arg == 0 { - return Err(SystemError::EFAULT); - } + fn recv_bytes_available(&self) -> Result { + Ok(self.ioctl_fionread()) + } - match cmd { - FIONREAD => { - let available = self.ioctl_fionread(); - let mut writer = - UserBufferWriter::new(arg as *mut u8, core::mem::size_of::(), true)?; - writer - .buffer_protected(0)? - .write_one::(0, &clamp_usize_to_i32(available))?; - Ok(0) - } - TIOCOUTQ => { - let queued = self.ioctl_tiocoutq(); - let mut writer = - UserBufferWriter::new(arg as *mut u8, core::mem::size_of::(), true)?; - writer - .buffer_protected(0)? - .write_one::(0, &clamp_usize_to_i32(queued))?; - Ok(0) - } - SIOCGIFINDEX => Err(SystemError::ENODEV), - _ => Err(SystemError::ENOSYS), - } + fn send_bytes_available(&self) -> Result { + Ok(self.ioctl_tiocoutq()) } fn set_nonblocking(&self, nonblocking: bool) { diff --git a/kernel/src/net/socket/unix/stream/mod.rs b/kernel/src/net/socket/unix/stream/mod.rs index 43849341f..2d315536c 100644 --- a/kernel/src/net/socket/unix/stream/mod.rs +++ b/kernel/src/net/socket/unix/stream/mod.rs @@ -1,6 +1,6 @@ use crate::{ filesystem::epoll::{event_poll::EventPoll, EPollEventType}, - filesystem::vfs::{fasync::FAsyncItems, vcore::generate_inode_id, FilePrivateData, InodeId}, + filesystem::vfs::{fasync::FAsyncItems, vcore::generate_inode_id, InodeId}, libs::rwlock::RwLock, net::socket::{self, *}, }; @@ -29,21 +29,12 @@ use crate::filesystem::vfs::iov::IoVecs; use crate::net::socket::unix::{current_ucred, nobody_ucred, UCred}; use crate::process::namespace::net_namespace::NetNamespace; use crate::process::ProcessManager; -use crate::syscall::user_access::{UserBufferReader, UserBufferWriter}; +use crate::syscall::user_access::UserBufferReader; use crate::time::{Duration, Instant}; // Use common ancillary message types from parent module use super::{cmsg_align, CmsgBuffer, Cmsghdr, MSG_CTRUNC, SCM_CREDENTIALS, SCM_RIGHTS, SOL_SOCKET}; -// Socket ioctls used by gVisor unix socket tests. -const TIOCOUTQ: u32 = 0x5411; // Get output queue size -const FIONREAD: u32 = 0x541B; // Get input queue size (aka TIOCINQ) -const SIOCGIFINDEX: u32 = 0x8933; // name -> if_index mapping - -fn clamp_usize_to_i32(v: usize) -> i32 { - core::cmp::min(v, i32::MAX as usize) as i32 -} - #[repr(C)] #[derive(Clone, Copy, Debug, Default)] struct Linger { @@ -516,41 +507,12 @@ impl Socket for UnixStreamSocket { &self.open_files } - fn ioctl( - &self, - cmd: u32, - arg: usize, - _private_data: &FilePrivateData, - ) -> Result { - if arg == 0 { - return Err(SystemError::EFAULT); - } + fn recv_bytes_available(&self) -> Result { + Ok(self.ioctl_fionread()) + } - match cmd { - // Return bytes available for reading. - FIONREAD => { - let available = self.ioctl_fionread(); - let mut writer = - UserBufferWriter::new(arg as *mut u8, core::mem::size_of::(), true)?; - writer - .buffer_protected(0)? - .write_one::(0, &clamp_usize_to_i32(available))?; - Ok(0) - } - // Return bytes queued for transmission. - TIOCOUTQ => { - let queued = self.ioctl_tiocoutq(); - let mut writer = - UserBufferWriter::new(arg as *mut u8, core::mem::size_of::(), true)?; - writer - .buffer_protected(0)? - .write_one::(0, &clamp_usize_to_i32(queued))?; - Ok(0) - } - // Netdevice ioctls on AF_UNIX sockets: gVisor tests accept ENODEV. - SIOCGIFINDEX => Err(SystemError::ENODEV), - _ => Err(SystemError::ENOSYS), - } + fn send_bytes_available(&self) -> Result { + Ok(self.ioctl_tiocoutq()) } fn set_nonblocking(&self, nonblocking: bool) { diff --git a/kernel/src/net/syscall/sys_recvfrom.rs b/kernel/src/net/syscall/sys_recvfrom.rs index 62e753473..ffda74280 100644 --- a/kernel/src/net/syscall/sys_recvfrom.rs +++ b/kernel/src/net/syscall/sys_recvfrom.rs @@ -163,18 +163,16 @@ pub(super) fn do_recvfrom( pmsg_flags.insert(socket::PMSG::DONTWAIT); } - if addr.is_null() { - let (n, _) = socket.recv_from(buf, pmsg_flags, None)?; - return Ok(n); - } - // Linux 语义:recvfrom 的 addr/addrlen 是纯输出参数,内核不得读取 addr 缓冲区内容。 // 用户栈上的 sockaddr 可能是未初始化的;读取它会导致错误解析并返回 EINVAL。 - if addr_len.is_null() { - return Err(SystemError::EFAULT); - } + // recv() passes NULL for both addr and addr_len, which is valid. let (recv_len, endpoint) = socket.recv_from(buf, pmsg_flags, None)?; - endpoint.write_to_user(addr, addr_len)?; + + // Only write the source address if the caller provided addr and addr_len + if !addr.is_null() && !addr_len.is_null() { + endpoint.write_to_user(addr, addr_len)?; + } + Ok(recv_len) } diff --git a/kernel/src/net/syscall/sys_recvmsg.rs b/kernel/src/net/syscall/sys_recvmsg.rs index 0240eb18d..fc87bd326 100644 --- a/kernel/src/net/syscall/sys_recvmsg.rs +++ b/kernel/src/net/syscall/sys_recvmsg.rs @@ -103,6 +103,14 @@ pub(super) fn do_recvmsg( let reader = UserBufferReader::new(msg, core::mem::size_of::(), from_user)?; let mut kmsg = reader.buffer_protected(0)?.read_one::(0)?; + // log::debug!( + // "do_recvmsg: fd={}, msg_iovlen={}, msg_iov={:?}, flags={:#x}", + // fd, + // kmsg.msg_iovlen, + // kmsg.msg_iov, + // flags + // ); + // 检查每个缓冲区地址是否合法,生成iovecs(fallback path needs this). let iovs = unsafe { IoVecs::from_user(kmsg.msg_iov, kmsg.msg_iovlen, true)? }; diff --git a/kernel/src/syscall/user_access.rs b/kernel/src/syscall/user_access.rs index e7ad54ec4..ac4b47b73 100644 --- a/kernel/src/syscall/user_access.rs +++ b/kernel/src/syscall/user_access.rs @@ -490,14 +490,20 @@ impl UserBufferReader<'_> { } fn convert_with_offset(&self, src: &[u8], offset: usize) -> Result<&[T], SystemError> { - if offset >= src.len() { + // offset == src.len is valid, as long as don't try to dereference it in &src[offset..] + if offset > src.len() { return Err(SystemError::EINVAL); } let byte_buffer: &[u8] = &src[offset..]; - if !byte_buffer.len().is_multiple_of(core::mem::size_of::()) || byte_buffer.is_empty() { + if byte_buffer.is_empty() { + // Empty buffer is valid - return empty slice + return Ok(&[]); + } + if !byte_buffer.len().is_multiple_of(core::mem::size_of::()) { return Err(SystemError::EINVAL); } + debug_assert!(offset < src.len()); let chunks = unsafe { from_raw_parts( byte_buffer.as_ptr() as *const T, @@ -776,15 +782,17 @@ impl<'a> UserBufferWriter<'a> { return Err(SystemError::EINVAL); } let byte_buffer: &mut [u8] = &mut src[offset..]; - if !byte_buffer.len().is_multiple_of(core::mem::size_of::()) { - return Err(SystemError::EINVAL); - } let len = byte_buffer.len() / core::mem::size_of::(); if len == 0 { + // Empty buffer is valid - return empty slice return Ok(&mut []); } + if !byte_buffer.len().is_multiple_of(core::mem::size_of::()) { + return Err(SystemError::EINVAL); + } + let chunks = unsafe { from_raw_parts_mut(byte_buffer.as_mut_ptr() as *mut T, len) }; return Ok(chunks); } diff --git a/user/apps/tests/syscall/gvisor/blocklists/udp_socket_test b/user/apps/tests/syscall/gvisor/blocklists/udp_socket_test index 16d46a0a4..533e1bf6c 100644 --- a/user/apps/tests/syscall/gvisor/blocklists/udp_socket_test +++ b/user/apps/tests/syscall/gvisor/blocklists/udp_socket_test @@ -1,81 +1,15 @@ **/1 **/2 UdpInet6SocketTest.ConnectInet4Sockaddr -# **.Creation -# **.Getsockname -# **.Getpeername -**.SendNotConnected/** -**.ConnectBinds/** -**.ReceiveNotBound/** -**.Bind/** -**.BindInUse/** -**.ConnectWriteToInvalidPort/** -**.ConnectSimultaneousWriteToInvalidPort/** -**.ReceiveAfterConnect/** -**.ReceiveAfterDisconnect/** -**.Connect/** -**.ConnectAnyZero/** -**.ConnectAnyWithPort/** -**.DisconnectAfterConnectAny/** -**.DisconnectAfterConnectAnyWithPort/** -**.DisconnectAfterBind/** -**.DisconnectAfterBindToUnspecAndConnect/** -**.DisconnectAfterConnectWithoutBind/** -**.BindToAnyConnnectToLocalhost/** -**.DisconnectAfterBindToAny/** -**.Disconnect/** -**.ConnectBadAddress/** -**.SendToAddressOtherThanConnected/** -**.ConnectAndSendNoReceiver/** -**.RecvErrorConnRefusedOtherAFSockOpt/** -**.RecvErrorConnRefused/** -**.ZerolengthWriteAllowed/** -**.ZerolengthWriteAllowedNonBlockRead/** -**.SendAndReceiveNotConnected/** -**.SendAndReceiveConnected/** -**.ReceiveFromNotConnected/** -**.ReceiveBeforeConnect/** -**.ReceiveFrom/** -**.Listen/** -**.Accept/** -**.ReadShutdownNonblockPendingData/** -**.ReadShutdownSameSocketResetsShutdownState/** -**.ReadShutdown/** -**.ReadShutdownDifferentThread/** -**.WriteShutdown/** -**.SynchronousReceive/** -**.BoundaryPreserved_SendRecv/** -**.BoundaryPreserved_WritevReadv/** -**.BoundaryPreserved_SendMsgRecvMsg/** -**.FIONREADShutdown/** -**.FIONREADWriteShutdown/** -**.Fionread/** -**.FIONREADZeroLengthPacket/** -**.FIONREADZeroLengthWriteShutdown/** -**.SoNoCheckOffByDefault/** -**.SoNoCheck/** -**.ErrorQueue/** -**.SoTimestampOffByDefault/** -**.SoTimestamp/** -**.WriteShutdownNotConnected/** -**.TimestampIoctl/** -**.TimestampIoctlNothingRead/** -**.TimestampIoctlPersistence/** -**.RecvBufLimitsEmptyRcvBuf/** -**.RecvBufLimits/** -**.SetSocketDetachFilter/** -**.SetSocketDetachFilterNoInstalledFilter/** -**.GetSocketDetachFilter/** -**.SendToZeroPort/** -**.ConnectToZeroPortUnbound/** -**.ConnectToZeroPortBound/** -**.ConnectToZeroPortConnected/** -**.SetAndReceiveTOSOrTClass/** -**.SendAndReceiveTOSorTClass/** -**.SetAndReceiveTTLOrHopLimit/** -**.SendAndReceiveTTLOrHopLimit/** -**.SetAndReceivePktInfo/** -**.SendPacketLargerThanSendBufOnNonBlockingSocket/** -**.ReadShutdownOnBoundSocket/** -**.ReconnectDoesNotClearReadShutdown/** -**.ReconnectDoesNotClearWriteShutdown/** \ No newline at end of file +**SoTimestamp** +**TimestampIoctl** +**RecvBufLimits** +**UdpSocketControlMessagesTest** +**SetSocketDetachFilter** +# ICMP +**ConnectWriteToInvalidPort** +**ConnectAndSendNoReceiver** +# Message Queue +**RecvErrorConnRefused** +# AF Option Isolation +**RecvErrorRefusedOtherAFSockOpt**