From d0aceac7c15f927dfdbdd4b6678b6fdf5bbd9cfd Mon Sep 17 00:00:00 2001 From: longjin Date: Tue, 10 Feb 2026 17:02:55 +0800 Subject: [PATCH] feat(net): enhance network interface binding and loopback handling - Add smoltcp features for increased address and route limits - Improve error handling for address binding with specific error messages - Enhance subnet-directed broadcast and loopback interface matching logic - Add broadcast delivery support for UDP sockets - Refactor netlink socket to return original message length for TRUNC flag - Fix route addition to include local routes and proper rollback on failure - Add batch polling for network interfaces to improve performance Signed-off-by: longjin --- kernel/Cargo.toml | 2 + kernel/src/net/socket/inet/common/mod.rs | 77 ++++++++++- kernel/src/net/socket/inet/datagram/mod.rs | 122 +++++++++++------- .../net/socket/inet/datagram/udp_bindings.rs | 23 ++++ kernel/src/net/socket/netlink/common/mod.rs | 80 ++++++++---- .../src/net/socket/netlink/kobject/bound.rs | 7 +- kernel/src/net/socket/netlink/route/bound.rs | 20 ++- .../src/net/socket/netlink/route/kern/addr.rs | 77 ++++++++++- .../src/net/socket/utils/datagram_common.rs | 4 +- ...ket_ipv6_udp_unbound_loopback_netlink_test | 1 + user/apps/tests/syscall/gvisor/whitelist.txt | 2 + 11 files changed, 314 insertions(+), 101 deletions(-) create mode 100644 user/apps/tests/syscall/gvisor/blocklists/socket_ipv6_udp_unbound_loopback_netlink_test diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 460a695ea..ccc73dd66 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -74,6 +74,8 @@ smoltcp = { version = "=0.12.0", git = "https://git.mirrors.dragonos.org.cn/Drag "medium-ip", "log", "multicast", + "iface-max-addr-count-8", + "iface-max-route-count-8", "iface-max-multicast-group-count-8" ] } syscall_table_macros = { path = "crates/syscall_table_macros" } diff --git a/kernel/src/net/socket/inet/common/mod.rs b/kernel/src/net/socket/inet/common/mod.rs index 5dc08e631..21dfd04b4 100644 --- a/kernel/src/net/socket/inet/common/mod.rs +++ b/kernel/src/net/socket/inet/common/mod.rs @@ -54,7 +54,8 @@ impl BoundInner { netns, }); } else { - let iface = get_iface_to_bind(address, netns.clone()).ok_or(SystemError::ENODEV)?; + let iface = get_iface_to_bind(address, netns.clone()) + .ok_or_else(|| bind_addr_not_found_error(address, &netns))?; // log::debug!( // "BoundInner::bind: binding to iface {} for address {:?}", // iface.iface_name(), @@ -169,19 +170,68 @@ pub fn get_iface_to_bind( ip_addr: &smoltcp::wire::IpAddress, netns: Arc, ) -> Option> { - // For multicast or broadcast addresses, use the default interface or first available - // Linux allows binding to these addresses for filtering purposes + let device_list = netns.device_list(); + + // Subnet-directed broadcast should prefer the iface whose configured subnet matches. + if let smoltcp::wire::IpAddress::Ipv4(target_broadcast) = ip_addr { + if target_broadcast.is_broadcast() { + if let Some(iface) = device_list.iter().find_map(|(_, iface)| { + iface_matches_directed_broadcast(iface, *target_broadcast).then(|| iface.clone()) + }) { + return Some(iface); + } + } + } + + // For multicast/broadcast fallback, use default or first iface. if ip_addr.is_multicast() || ip_addr.is_broadcast() { return netns .default_iface() - .or_else(|| netns.device_list().values().next().cloned()); + .or_else(|| device_list.values().next().cloned()); } - netns - .device_list() + if let Some(iface) = device_list .iter() .find(|(_, iface)| iface.smol_iface().lock().has_ip_addr(*ip_addr)) .map(|(_, iface)| iface.clone()) + { + return Some(iface); + } + + // Linux-like loopback behavior for IPv4: lo considers the whole configured subnet local. + if let smoltcp::wire::IpAddress::Ipv4(v4_addr) = ip_addr { + return device_list.iter().find_map(|(_, iface)| { + loopback_iface_contains_v4(iface, *v4_addr).then(|| iface.clone()) + }); + } + + None +} + +#[inline] +fn iface_matches_directed_broadcast( + iface: &Arc, + target_broadcast: smoltcp::wire::Ipv4Address, +) -> bool { + let smol_iface = iface.smol_iface().lock(); + smol_iface.ip_addrs().iter().any(|cidr| match cidr { + smoltcp::wire::IpCidr::Ipv4(v4_cidr) => { + v4_cidr.broadcast().is_some_and(|b| b == target_broadcast) + } + _ => false, + }) +} + +#[inline] +fn loopback_iface_contains_v4(iface: &Arc, v4_addr: smoltcp::wire::Ipv4Address) -> bool { + if !iface.flags().contains(InterfaceFlags::LOOPBACK) { + return false; + } + let smol_iface = iface.smol_iface().lock(); + smol_iface.ip_addrs().iter().any(|cidr| match cidr { + smoltcp::wire::IpCidr::Ipv4(v4_cidr) => v4_cidr.contains_addr(&v4_addr), + _ => false, + }) } /// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface. @@ -270,6 +320,21 @@ fn is_loopback_destination(remote_ip_addr: &smoltcp::wire::IpAddress) -> bool { } } +fn bind_addr_not_found_error( + addr: &smoltcp::wire::IpAddress, + netns: &Arc, +) -> SystemError { + if netns.device_list().is_empty() { + return SystemError::ENODEV; + } + + match addr { + smoltcp::wire::IpAddress::Ipv4(_) | smoltcp::wire::IpAddress::Ipv6(_) => { + SystemError::EADDRNOTAVAIL + } + } +} + /// Select a suitable network interface for binding to an unspecified address. /// /// Selection logic (in priority order): diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index 5e8eefc05..4f7ef0d25 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -35,6 +35,7 @@ pub mod multicast_loopback; mod udp_bindings; type EP = crate::filesystem::epoll::EPollEventType; +const IFACE_POLL_BATCH_ROUNDS: usize = 128; #[repr(C)] #[derive(Clone, Copy, Debug, Default)] @@ -228,6 +229,30 @@ impl UdpSocket { self.nonblock.load(core::sync::atomic::Ordering::Relaxed) } + #[inline] + fn poll_iface_until_quiescent(iface: &dyn crate::net::Iface) { + loop { + let mut progressed = false; + for i in 0..IFACE_POLL_BATCH_ROUNDS { + if !iface.poll() { + return; + } + progressed = true; + if (i & 0x7) == 0x7 { + let pcb = ProcessManager::current_pcb(); + if pcb.has_pending_signal_fast() && pcb.has_pending_not_masked_signal() { + return; + } + } + } + if progressed { + crate::sched::sched_yield(); + } else { + return; + } + } + } + fn recv_timeout(&self) -> Option { let us = self .recv_timeout_us @@ -239,6 +264,18 @@ impl UdpSocket { } } + #[inline] + fn loopback_send_len_result( + payload_len: usize, + max_payload: usize, + ) -> Result { + if payload_len > max_payload || payload_len > u16::MAX as usize { + Err(SystemError::EMSGSIZE) + } else { + Ok(payload_len) + } + } + fn loopback_accepts_with_preconnect( &self, pkt: &LoopbackPacket, @@ -749,6 +786,7 @@ impl UdpSocket { result, send_iface, dest, + dest_is_broadcast, loopback_send, send_iface_is_loopback, mcast_ifindex, @@ -841,55 +879,34 @@ impl UdpSocket { Ipv4(v4) => v4.is_loopback(), Ipv6(v6) => v6.is_loopback(), }; - let is_broadcast = matches!(dest.addr, Ipv4(v4) if v4.is_broadcast()); - if is_loopback { + let loopback_broadcast = self + .netns + .loopback_iface() + .map(|lo| lo.smol_iface().lock().inner.is_broadcast(&dest.addr)) + .unwrap_or(false); + let is_broadcast = loopback_broadcast + || (send_iface_is_loopback + && bound_iface + .smol_iface() + .lock() + .inner + .is_broadcast(&dest.addr)); + let should_loopback_send = is_loopback + || ((is_multicast || is_broadcast) + && (send_iface_is_loopback || loopback_broadcast)); + if should_loopback_send { let max_payload = bound.with_socket(|socket| socket.payload_send_capacity()); - if buf.len() > max_payload || buf.len() > u16::MAX as usize { - ( - Err(SystemError::EMSGSIZE), - bound_iface, - Some(dest), - true, - send_iface_is_loopback, - mcast_ifindex, - None, - ) - } else { - ( - Ok(buf.len()), - bound_iface, - Some(dest), - true, - send_iface_is_loopback, - mcast_ifindex, - None, - ) - } - } else if (is_multicast || is_broadcast) && send_iface_is_loopback { - let max_payload = - bound.with_socket(|socket| socket.payload_send_capacity()); - if buf.len() > max_payload || buf.len() > u16::MAX as usize { - ( - Err(SystemError::EMSGSIZE), - bound_iface, - Some(dest), - true, - send_iface_is_loopback, - mcast_ifindex, - None, - ) - } else { - ( - Ok(buf.len()), - bound_iface, - Some(dest), - true, - send_iface_is_loopback, - mcast_ifindex, - None, - ) - } + ( + Self::loopback_send_len_result(buf.len(), max_payload), + bound_iface, + Some(dest), + is_broadcast, + true, + send_iface_is_loopback, + mcast_ifindex, + None, + ) } else { let mut send_iface = bound_iface.clone(); let mut restore_iface = None; @@ -913,6 +930,7 @@ impl UdpSocket { ret, send_iface, Some(dest), + is_broadcast, false, send_iface_is_loopback, mcast_ifindex, @@ -958,6 +976,14 @@ impl UdpSocket { ); } } + } else if dest_is_broadcast { + udp_bindings::deliver_broadcast_all( + &self.netns, + dest, + src_endpoint, + ifindex, + buf, + ); } else { udp_bindings::deliver_unicast_loopback( &self.netns, @@ -1016,7 +1042,7 @@ impl UdpSocket { // Poll AFTER releasing the lock to avoid deadlock // when socket sends to itself on loopback - send_iface.poll(); + Self::poll_iface_until_quiescent(send_iface.as_ref()); if let Some(orig_iface) = restore_iface { let mut inner_guard = self.inner.write(); diff --git a/kernel/src/net/socket/inet/datagram/udp_bindings.rs b/kernel/src/net/socket/inet/datagram/udp_bindings.rs index f14a5389d..9b32879ab 100644 --- a/kernel/src/net/socket/inet/datagram/udp_bindings.rs +++ b/kernel/src/net/socket/inet/datagram/udp_bindings.rs @@ -127,6 +127,29 @@ pub fn deliver_multicast_all( delivered } +pub fn deliver_broadcast_all( + netns: &Arc, + dest: IpEndpoint, + src: IpEndpoint, + ifindex: i32, + payload: &[u8], +) -> usize { + let candidates = match_udp_bindings(netns, dest.addr, dest.port); + if candidates.is_empty() { + return 0; + } + let mut delivered = 0; + for cand in candidates { + if cand + .socket + .inject_loopback_packet(src, dest.addr, dest.port, ifindex, payload) + { + delivered += 1; + } + } + delivered +} + fn match_udp_bindings( netns: &Arc, dest_addr: IpAddress, diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index c258e792e..6ae739c77 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -91,15 +91,52 @@ where &self, buf: &mut [u8], flags: crate::net::socket::PMSG, - ) -> Result<(usize, Endpoint), SystemError> { - let (recv_bytes, endpoint) = self - .inner - .read() - .try_recv(buf, flags) - .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; + ) -> Result<(usize, usize, Endpoint), SystemError> { + let (recv_bytes, orig_len, endpoint) = self.inner.read().try_recv(buf, flags).map( + |(recv_bytes, orig_len, remote_endpoint)| { + (recv_bytes, orig_len, remote_endpoint.into()) + }, + )?; // todo self.pollee.invalidate(); - Ok((recv_bytes, endpoint)) + Ok((recv_bytes, orig_len, endpoint)) + } + + #[inline] + fn recv_return_len(copy_len: usize, orig_len: usize, flags: PMSG) -> usize { + if flags.contains(PMSG::TRUNC) { + orig_len + } else { + copy_len + } + } + + fn recv_from_inner( + &self, + buffer: &mut [u8], + flags: crate::net::socket::PMSG, + address: Option, + ) -> Result<(usize, usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> + { + if let Some(addr) = address { + let endpoint = addr.try_into()?; + self.inner + .write() + .connect(&endpoint, self.wait_queue.clone(), self.netns())?; + } + + if self.is_nonblocking() || flags.contains(PMSG::DONTWAIT) { + self.try_recv(buffer, flags) + } else { + loop { + match self.try_recv(buffer, flags) { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); + } + result => break result, + } + } + } } /// 判断当前的netlink是否可以接收数据 @@ -171,24 +208,8 @@ where flags: crate::net::socket::PMSG, address: Option, ) -> Result<(usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> { - // log::info!("NetlinkSocket recv_from called"); - if let Some(addr) = address { - self.connect(addr)?; - } - - return if self.is_nonblocking() || flags.contains(PMSG::DONTWAIT) { - self.try_recv(buffer, flags) - } else { - loop { - match self.try_recv(buffer, flags) { - Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { - let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); - } - result => break result, - } - } - }; - // self.try_recv(buffer, flags) + let (copy_len, orig_len, endpoint) = self.recv_from_inner(buffer, flags, address)?; + Ok((Self::recv_return_len(copy_len, orig_len, flags), endpoint)) } fn check_io_event(&self) -> crate::filesystem::epoll::EPollEventType { @@ -220,8 +241,8 @@ where let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? }; let mut buf = iovs.new_buf(true); - let (recv_size, endpoint) = self.recv_from(&mut buf, flags, None)?; - iovs.scatter(&buf[..recv_size])?; + let (copy_len, orig_len, endpoint) = self.recv_from_inner(&mut buf, flags, None)?; + iovs.scatter(&buf[..copy_len])?; if !msg.msg_name.is_null() { let actual_len = endpoint.write_to_user_msghdr(msg.msg_name, msg.msg_namelen)?; @@ -232,7 +253,10 @@ where msg.msg_controllen = 0; msg.msg_flags = 0; - Ok(recv_size) + if orig_len > copy_len { + msg.msg_flags |= PMSG::TRUNC.bits() as i32; + } + Ok(Self::recv_return_len(copy_len, orig_len, flags)) } fn send(&self, buffer: &[u8], flags: PMSG) -> Result { diff --git a/kernel/src/net/socket/netlink/kobject/bound.rs b/kernel/src/net/socket/netlink/kobject/bound.rs index c720d3aaf..1bb128769 100644 --- a/kernel/src/net/socket/netlink/kobject/bound.rs +++ b/kernel/src/net/socket/netlink/kobject/bound.rs @@ -64,13 +64,14 @@ impl datagram_common::Bound for BoundNetlink { &self, writer: &mut [u8], flags: PMSG, - ) -> Result<(usize, Self::Endpoint), SystemError> { + ) -> Result<(usize, usize, Self::Endpoint), SystemError> { let mut receive_queue = self.receive_queue.0.lock(); let Some(message) = receive_queue.front() else { return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); }; - let copied = writer.len().min(message.as_bytes().len()); + let orig_len = message.as_bytes().len(); + let copied = writer.len().min(orig_len); if copied > 0 { writer[..copied].copy_from_slice(&message.as_bytes()[..copied]); } @@ -79,7 +80,7 @@ impl datagram_common::Bound for BoundNetlink { receive_queue.pop_front(); } - Ok((copied, NetlinkSocketAddr::new_unspecified())) + Ok((copied, orig_len, NetlinkSocketAddr::new_unspecified())) } fn check_io_events(&self) -> EPollEventType { diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index ff4e746bd..12ef3591d 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -90,20 +90,26 @@ impl datagram_common::Bound for BoundNetlink { &self, writer: &mut [u8], flags: crate::net::socket::PMSG, - ) -> Result<(usize, Self::Endpoint), SystemError> { + ) -> Result<(usize, usize, Self::Endpoint), SystemError> { let mut receive_queue = self.receive_queue.0.lock(); let Some(res) = receive_queue.front() else { return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); }; - let len = { - let max = writer.len(); - res.total_len().min(max) + let orig_len = res.total_len(); + let copied = if writer.len() >= orig_len { + res.write_to(writer)? + } else { + let mut full = alloc::vec![0u8; orig_len]; + let written = res.write_to(&mut full)?; + let copy_len = written.min(writer.len()); + if copy_len > 0 { + writer[..copy_len].copy_from_slice(&full[..copy_len]); + } + copy_len }; - let _copied = res.write_to(writer)?; - if !flags.contains(PMSG::PEEK) { receive_queue.pop_front(); } @@ -111,7 +117,7 @@ impl datagram_common::Bound for BoundNetlink { // todo 目前这个信息只能来自内核 let remote = NetlinkSocketAddr::new_unspecified(); - Ok((len, remote)) + Ok((copied, orig_len, remote)) } fn check_io_events(&self) -> EPollEventType { diff --git a/kernel/src/net/socket/netlink/route/kern/addr.rs b/kernel/src/net/socket/netlink/route/kern/addr.rs index dc84066eb..038b04074 100644 --- a/kernel/src/net/socket/netlink/route/kern/addr.rs +++ b/kernel/src/net/socket/netlink/route/kern/addr.rs @@ -76,7 +76,7 @@ fn add_addr(request_segment: &AddrSegment, netns: Arc) -> Result<( let flags = NewRequestFlags::from_bits_truncate(request_segment.header().flags); let mut exists = false; - let mut pushed = true; + let mut pushed = false; iface.smol_iface().lock().update_ip_addrs(|ip_addrs| { exists = ip_addrs.contains(&cidr); @@ -89,8 +89,8 @@ fn add_addr(request_segment: &AddrSegment, netns: Arc) -> Result<( IpAddress::Ipv6(_) => ip_addrs.len(), }; - if ip_addrs.insert(insert_index, cidr).is_err() { - pushed = false; + if ip_addrs.insert(insert_index, cidr).is_ok() { + pushed = true; } } }); @@ -102,14 +102,15 @@ fn add_addr(request_segment: &AddrSegment, netns: Arc) -> Result<( return Err(SystemError::EEXIST); } - if flags.contains(NewRequestFlags::REPLACE) { - return Err(SystemError::ENOENT); - } - if !pushed { return Err(SystemError::ENOSPC); } + if let Err(err) = add_local_route(&iface, cidr) { + rollback_added_addr(&iface, cidr); + return Err(err); + } + sync_router_ip_addrs(&iface); Ok(()) @@ -132,6 +133,7 @@ fn del_addr(request_segment: &AddrSegment, netns: Arc) -> Result<( return Err(SystemError::EADDRNOTAVAIL); } + remove_local_route(&iface, cidr); sync_router_ip_addrs(&iface); Ok(()) @@ -249,3 +251,64 @@ fn sync_router_ip_addrs(iface: &Arc) { router_ip_addrs.clear(); router_ip_addrs.extend_from_slice(&smol_ip_addrs); } + +fn add_local_route(iface: &Arc, cidr: IpCidr) -> Result<(), SystemError> { + let mut pushed = false; + let via_router = cidr.address(); + + iface.smol_iface().lock().routes_mut().update(|routes| { + let exists = routes + .iter() + .any(|route| is_same_local_route(route, cidr, via_router)); + if exists { + pushed = true; + return; + } + + pushed = routes + .push(smoltcp::iface::Route { + cidr, + via_router, + preferred_until: None, + expires_at: None, + }) + .is_ok(); + }); + + if !pushed { + log::warn!( + "netlink add_addr: route table full while adding local route {} via {}", + cidr, + via_router + ); + return Err(SystemError::ENOSPC); + } + + Ok(()) +} + +fn remove_local_route(iface: &Arc, cidr: IpCidr) { + let via_router = cidr.address(); + iface.smol_iface().lock().routes_mut().update(|routes| { + if let Some(index) = routes + .iter() + .position(|route| is_same_local_route(route, cidr, via_router)) + { + routes.remove(index); + } + }); +} + +fn rollback_added_addr(iface: &Arc, cidr: IpCidr) { + iface.smol_iface().lock().update_ip_addrs(|ip_addrs| { + if let Some(index) = ip_addrs.iter().position(|configured| *configured == cidr) { + ip_addrs.remove(index); + } + }); + sync_router_ip_addrs(iface); +} + +#[inline] +fn is_same_local_route(route: &smoltcp::iface::Route, cidr: IpCidr, via_router: IpAddress) -> bool { + route.cidr == cidr && route.via_router == via_router +} diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs index 70710ca90..4078bd699 100644 --- a/kernel/src/net/socket/utils/datagram_common.rs +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -51,7 +51,7 @@ pub trait Bound { &self, writer: &mut [u8], flags: PMSG, - ) -> Result<(usize, Self::Endpoint), SystemError>; + ) -> Result<(usize, usize, Self::Endpoint), SystemError>; fn try_send(&self, buf: &[u8], to: &Self::Endpoint, flags: PMSG) -> Result; @@ -151,7 +151,7 @@ where &self, writer: &mut [u8], flags: PMSG, - ) -> Result<(usize, UnboundSocket::Endpoint), SystemError> { + ) -> Result<(usize, usize, UnboundSocket::Endpoint), SystemError> { match self { Inner::Unbound(_) => Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Inner::Bound(bound) => bound.try_recv(writer, flags), diff --git a/user/apps/tests/syscall/gvisor/blocklists/socket_ipv6_udp_unbound_loopback_netlink_test b/user/apps/tests/syscall/gvisor/blocklists/socket_ipv6_udp_unbound_loopback_netlink_test new file mode 100644 index 000000000..0e7a8715a --- /dev/null +++ b/user/apps/tests/syscall/gvisor/blocklists/socket_ipv6_udp_unbound_loopback_netlink_test @@ -0,0 +1 @@ +IPv6UDPSockets/IPv6UDPUnboundSocketNetlinkTest.JoinSubnet/1 diff --git a/user/apps/tests/syscall/gvisor/whitelist.txt b/user/apps/tests/syscall/gvisor/whitelist.txt index fd1b59929..595b65d78 100644 --- a/user/apps/tests/syscall/gvisor/whitelist.txt +++ b/user/apps/tests/syscall/gvisor/whitelist.txt @@ -110,6 +110,8 @@ socket_ipv4_udp_unbound_loopback_test socket_ipv4_udp_unbound_loopback_nogotsan_test socket_netlink_test socket_ip_unbound_netlink_test +socket_ipv4_udp_unbound_loopback_netlink_test +socket_ipv6_udp_unbound_loopback_netlink_test # 信号处理测试 sigaction_test