feat(net): enhance network interface binding and loopback handling

- Add smoltcp features for increased address and route limits
- Improve error handling for address binding with specific error messages
- Enhance subnet-directed broadcast and loopback interface matching logic
- Add broadcast delivery support for UDP sockets
- Refactor netlink socket to return original message length for TRUNC flag
- Fix route addition to include local routes and proper rollback on failure
- Add batch polling for network interfaces to improve performance

Signed-off-by: longjin <longjin@DragonOS.org>
This commit is contained in:
longjin 2026-02-10 17:02:55 +08:00
parent c0c8892710
commit d0aceac7c1
11 changed files with 314 additions and 101 deletions

View File

@ -74,6 +74,8 @@ smoltcp = { version = "=0.12.0", git = "https://git.mirrors.dragonos.org.cn/Drag
"medium-ip",
"log",
"multicast",
"iface-max-addr-count-8",
"iface-max-route-count-8",
"iface-max-multicast-group-count-8"
] }
syscall_table_macros = { path = "crates/syscall_table_macros" }

View File

@ -54,7 +54,8 @@ impl BoundInner {
netns,
});
} else {
let iface = get_iface_to_bind(address, netns.clone()).ok_or(SystemError::ENODEV)?;
let iface = get_iface_to_bind(address, netns.clone())
.ok_or_else(|| bind_addr_not_found_error(address, &netns))?;
// log::debug!(
// "BoundInner::bind: binding to iface {} for address {:?}",
// iface.iface_name(),
@ -169,19 +170,68 @@ pub fn get_iface_to_bind(
ip_addr: &smoltcp::wire::IpAddress,
netns: Arc<NetNamespace>,
) -> Option<Arc<dyn Iface>> {
// For multicast or broadcast addresses, use the default interface or first available
// Linux allows binding to these addresses for filtering purposes
let device_list = netns.device_list();
// Subnet-directed broadcast should prefer the iface whose configured subnet matches.
if let smoltcp::wire::IpAddress::Ipv4(target_broadcast) = ip_addr {
if target_broadcast.is_broadcast() {
if let Some(iface) = device_list.iter().find_map(|(_, iface)| {
iface_matches_directed_broadcast(iface, *target_broadcast).then(|| iface.clone())
}) {
return Some(iface);
}
}
}
// For multicast/broadcast fallback, use default or first iface.
if ip_addr.is_multicast() || ip_addr.is_broadcast() {
return netns
.default_iface()
.or_else(|| netns.device_list().values().next().cloned());
.or_else(|| device_list.values().next().cloned());
}
netns
.device_list()
if let Some(iface) = device_list
.iter()
.find(|(_, iface)| iface.smol_iface().lock().has_ip_addr(*ip_addr))
.map(|(_, iface)| iface.clone())
{
return Some(iface);
}
// Linux-like loopback behavior for IPv4: lo considers the whole configured subnet local.
if let smoltcp::wire::IpAddress::Ipv4(v4_addr) = ip_addr {
return device_list.iter().find_map(|(_, iface)| {
loopback_iface_contains_v4(iface, *v4_addr).then(|| iface.clone())
});
}
None
}
#[inline]
fn iface_matches_directed_broadcast(
iface: &Arc<dyn Iface>,
target_broadcast: smoltcp::wire::Ipv4Address,
) -> bool {
let smol_iface = iface.smol_iface().lock();
smol_iface.ip_addrs().iter().any(|cidr| match cidr {
smoltcp::wire::IpCidr::Ipv4(v4_cidr) => {
v4_cidr.broadcast().is_some_and(|b| b == target_broadcast)
}
_ => false,
})
}
#[inline]
fn loopback_iface_contains_v4(iface: &Arc<dyn Iface>, v4_addr: smoltcp::wire::Ipv4Address) -> bool {
if !iface.flags().contains(InterfaceFlags::LOOPBACK) {
return false;
}
let smol_iface = iface.smol_iface().lock();
smol_iface.ip_addrs().iter().any(|cidr| match cidr {
smoltcp::wire::IpCidr::Ipv4(v4_cidr) => v4_cidr.contains_addr(&v4_addr),
_ => false,
})
}
/// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface.
@ -270,6 +320,21 @@ fn is_loopback_destination(remote_ip_addr: &smoltcp::wire::IpAddress) -> bool {
}
}
fn bind_addr_not_found_error(
addr: &smoltcp::wire::IpAddress,
netns: &Arc<NetNamespace>,
) -> SystemError {
if netns.device_list().is_empty() {
return SystemError::ENODEV;
}
match addr {
smoltcp::wire::IpAddress::Ipv4(_) | smoltcp::wire::IpAddress::Ipv6(_) => {
SystemError::EADDRNOTAVAIL
}
}
}
/// Select a suitable network interface for binding to an unspecified address.
///
/// Selection logic (in priority order):

View File

@ -35,6 +35,7 @@ pub mod multicast_loopback;
mod udp_bindings;
type EP = crate::filesystem::epoll::EPollEventType;
const IFACE_POLL_BATCH_ROUNDS: usize = 128;
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
@ -228,6 +229,30 @@ impl UdpSocket {
self.nonblock.load(core::sync::atomic::Ordering::Relaxed)
}
#[inline]
fn poll_iface_until_quiescent(iface: &dyn crate::net::Iface) {
loop {
let mut progressed = false;
for i in 0..IFACE_POLL_BATCH_ROUNDS {
if !iface.poll() {
return;
}
progressed = true;
if (i & 0x7) == 0x7 {
let pcb = ProcessManager::current_pcb();
if pcb.has_pending_signal_fast() && pcb.has_pending_not_masked_signal() {
return;
}
}
}
if progressed {
crate::sched::sched_yield();
} else {
return;
}
}
}
fn recv_timeout(&self) -> Option<crate::time::Duration> {
let us = self
.recv_timeout_us
@ -239,6 +264,18 @@ impl UdpSocket {
}
}
#[inline]
fn loopback_send_len_result(
payload_len: usize,
max_payload: usize,
) -> Result<usize, SystemError> {
if payload_len > max_payload || payload_len > u16::MAX as usize {
Err(SystemError::EMSGSIZE)
} else {
Ok(payload_len)
}
}
fn loopback_accepts_with_preconnect(
&self,
pkt: &LoopbackPacket,
@ -749,6 +786,7 @@ impl UdpSocket {
result,
send_iface,
dest,
dest_is_broadcast,
loopback_send,
send_iface_is_loopback,
mcast_ifindex,
@ -841,55 +879,34 @@ impl UdpSocket {
Ipv4(v4) => v4.is_loopback(),
Ipv6(v6) => v6.is_loopback(),
};
let is_broadcast = matches!(dest.addr, Ipv4(v4) if v4.is_broadcast());
if is_loopback {
let loopback_broadcast = self
.netns
.loopback_iface()
.map(|lo| lo.smol_iface().lock().inner.is_broadcast(&dest.addr))
.unwrap_or(false);
let is_broadcast = loopback_broadcast
|| (send_iface_is_loopback
&& bound_iface
.smol_iface()
.lock()
.inner
.is_broadcast(&dest.addr));
let should_loopback_send = is_loopback
|| ((is_multicast || is_broadcast)
&& (send_iface_is_loopback || loopback_broadcast));
if should_loopback_send {
let max_payload =
bound.with_socket(|socket| socket.payload_send_capacity());
if buf.len() > max_payload || buf.len() > u16::MAX as usize {
(
Err(SystemError::EMSGSIZE),
bound_iface,
Some(dest),
true,
send_iface_is_loopback,
mcast_ifindex,
None,
)
} else {
(
Ok(buf.len()),
bound_iface,
Some(dest),
true,
send_iface_is_loopback,
mcast_ifindex,
None,
)
}
} else if (is_multicast || is_broadcast) && send_iface_is_loopback {
let max_payload =
bound.with_socket(|socket| socket.payload_send_capacity());
if buf.len() > max_payload || buf.len() > u16::MAX as usize {
(
Err(SystemError::EMSGSIZE),
bound_iface,
Some(dest),
true,
send_iface_is_loopback,
mcast_ifindex,
None,
)
} else {
(
Ok(buf.len()),
bound_iface,
Some(dest),
true,
send_iface_is_loopback,
mcast_ifindex,
None,
)
}
(
Self::loopback_send_len_result(buf.len(), max_payload),
bound_iface,
Some(dest),
is_broadcast,
true,
send_iface_is_loopback,
mcast_ifindex,
None,
)
} else {
let mut send_iface = bound_iface.clone();
let mut restore_iface = None;
@ -913,6 +930,7 @@ impl UdpSocket {
ret,
send_iface,
Some(dest),
is_broadcast,
false,
send_iface_is_loopback,
mcast_ifindex,
@ -958,6 +976,14 @@ impl UdpSocket {
);
}
}
} else if dest_is_broadcast {
udp_bindings::deliver_broadcast_all(
&self.netns,
dest,
src_endpoint,
ifindex,
buf,
);
} else {
udp_bindings::deliver_unicast_loopback(
&self.netns,
@ -1016,7 +1042,7 @@ impl UdpSocket {
// Poll AFTER releasing the lock to avoid deadlock
// when socket sends to itself on loopback
send_iface.poll();
Self::poll_iface_until_quiescent(send_iface.as_ref());
if let Some(orig_iface) = restore_iface {
let mut inner_guard = self.inner.write();

View File

@ -127,6 +127,29 @@ pub fn deliver_multicast_all(
delivered
}
pub fn deliver_broadcast_all(
netns: &Arc<NetNamespace>,
dest: IpEndpoint,
src: IpEndpoint,
ifindex: i32,
payload: &[u8],
) -> usize {
let candidates = match_udp_bindings(netns, dest.addr, dest.port);
if candidates.is_empty() {
return 0;
}
let mut delivered = 0;
for cand in candidates {
if cand
.socket
.inject_loopback_packet(src, dest.addr, dest.port, ifindex, payload)
{
delivered += 1;
}
}
delivered
}
fn match_udp_bindings(
netns: &Arc<NetNamespace>,
dest_addr: IpAddress,

View File

@ -91,15 +91,52 @@ where
&self,
buf: &mut [u8],
flags: crate::net::socket::PMSG,
) -> Result<(usize, Endpoint), SystemError> {
let (recv_bytes, endpoint) = self
.inner
.read()
.try_recv(buf, flags)
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?;
) -> Result<(usize, usize, Endpoint), SystemError> {
let (recv_bytes, orig_len, endpoint) = self.inner.read().try_recv(buf, flags).map(
|(recv_bytes, orig_len, remote_endpoint)| {
(recv_bytes, orig_len, remote_endpoint.into())
},
)?;
// todo self.pollee.invalidate();
Ok((recv_bytes, endpoint))
Ok((recv_bytes, orig_len, endpoint))
}
#[inline]
fn recv_return_len(copy_len: usize, orig_len: usize, flags: PMSG) -> usize {
if flags.contains(PMSG::TRUNC) {
orig_len
} else {
copy_len
}
}
fn recv_from_inner(
&self,
buffer: &mut [u8],
flags: crate::net::socket::PMSG,
address: Option<crate::net::socket::endpoint::Endpoint>,
) -> Result<(usize, usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError>
{
if let Some(addr) = address {
let endpoint = addr.try_into()?;
self.inner
.write()
.connect(&endpoint, self.wait_queue.clone(), self.netns())?;
}
if self.is_nonblocking() || flags.contains(PMSG::DONTWAIT) {
self.try_recv(buffer, flags)
} else {
loop {
match self.try_recv(buffer, flags) {
Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => {
let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {});
}
result => break result,
}
}
}
}
/// 判断当前的netlink是否可以接收数据
@ -171,24 +208,8 @@ where
flags: crate::net::socket::PMSG,
address: Option<crate::net::socket::endpoint::Endpoint>,
) -> Result<(usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> {
// log::info!("NetlinkSocket recv_from called");
if let Some(addr) = address {
self.connect(addr)?;
}
return if self.is_nonblocking() || flags.contains(PMSG::DONTWAIT) {
self.try_recv(buffer, flags)
} else {
loop {
match self.try_recv(buffer, flags) {
Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => {
let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {});
}
result => break result,
}
}
};
// self.try_recv(buffer, flags)
let (copy_len, orig_len, endpoint) = self.recv_from_inner(buffer, flags, address)?;
Ok((Self::recv_return_len(copy_len, orig_len, flags), endpoint))
}
fn check_io_event(&self) -> crate::filesystem::epoll::EPollEventType {
@ -220,8 +241,8 @@ where
let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? };
let mut buf = iovs.new_buf(true);
let (recv_size, endpoint) = self.recv_from(&mut buf, flags, None)?;
iovs.scatter(&buf[..recv_size])?;
let (copy_len, orig_len, endpoint) = self.recv_from_inner(&mut buf, flags, None)?;
iovs.scatter(&buf[..copy_len])?;
if !msg.msg_name.is_null() {
let actual_len = endpoint.write_to_user_msghdr(msg.msg_name, msg.msg_namelen)?;
@ -232,7 +253,10 @@ where
msg.msg_controllen = 0;
msg.msg_flags = 0;
Ok(recv_size)
if orig_len > copy_len {
msg.msg_flags |= PMSG::TRUNC.bits() as i32;
}
Ok(Self::recv_return_len(copy_len, orig_len, flags))
}
fn send(&self, buffer: &[u8], flags: PMSG) -> Result<usize, SystemError> {

View File

@ -64,13 +64,14 @@ impl datagram_common::Bound for BoundNetlink<KobjectUeventMessage> {
&self,
writer: &mut [u8],
flags: PMSG,
) -> Result<(usize, Self::Endpoint), SystemError> {
) -> Result<(usize, usize, Self::Endpoint), SystemError> {
let mut receive_queue = self.receive_queue.0.lock();
let Some(message) = receive_queue.front() else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
};
let copied = writer.len().min(message.as_bytes().len());
let orig_len = message.as_bytes().len();
let copied = writer.len().min(orig_len);
if copied > 0 {
writer[..copied].copy_from_slice(&message.as_bytes()[..copied]);
}
@ -79,7 +80,7 @@ impl datagram_common::Bound for BoundNetlink<KobjectUeventMessage> {
receive_queue.pop_front();
}
Ok((copied, NetlinkSocketAddr::new_unspecified()))
Ok((copied, orig_len, NetlinkSocketAddr::new_unspecified()))
}
fn check_io_events(&self) -> EPollEventType {

View File

@ -90,20 +90,26 @@ impl datagram_common::Bound for BoundNetlink<RouteNlMessage> {
&self,
writer: &mut [u8],
flags: crate::net::socket::PMSG,
) -> Result<(usize, Self::Endpoint), SystemError> {
) -> Result<(usize, usize, Self::Endpoint), SystemError> {
let mut receive_queue = self.receive_queue.0.lock();
let Some(res) = receive_queue.front() else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
};
let len = {
let max = writer.len();
res.total_len().min(max)
let orig_len = res.total_len();
let copied = if writer.len() >= orig_len {
res.write_to(writer)?
} else {
let mut full = alloc::vec![0u8; orig_len];
let written = res.write_to(&mut full)?;
let copy_len = written.min(writer.len());
if copy_len > 0 {
writer[..copy_len].copy_from_slice(&full[..copy_len]);
}
copy_len
};
let _copied = res.write_to(writer)?;
if !flags.contains(PMSG::PEEK) {
receive_queue.pop_front();
}
@ -111,7 +117,7 @@ impl datagram_common::Bound for BoundNetlink<RouteNlMessage> {
// todo 目前这个信息只能来自内核
let remote = NetlinkSocketAddr::new_unspecified();
Ok((len, remote))
Ok((copied, orig_len, remote))
}
fn check_io_events(&self) -> EPollEventType {

View File

@ -76,7 +76,7 @@ fn add_addr(request_segment: &AddrSegment, netns: Arc<NetNamespace>) -> Result<(
let flags = NewRequestFlags::from_bits_truncate(request_segment.header().flags);
let mut exists = false;
let mut pushed = true;
let mut pushed = false;
iface.smol_iface().lock().update_ip_addrs(|ip_addrs| {
exists = ip_addrs.contains(&cidr);
@ -89,8 +89,8 @@ fn add_addr(request_segment: &AddrSegment, netns: Arc<NetNamespace>) -> Result<(
IpAddress::Ipv6(_) => ip_addrs.len(),
};
if ip_addrs.insert(insert_index, cidr).is_err() {
pushed = false;
if ip_addrs.insert(insert_index, cidr).is_ok() {
pushed = true;
}
}
});
@ -102,14 +102,15 @@ fn add_addr(request_segment: &AddrSegment, netns: Arc<NetNamespace>) -> Result<(
return Err(SystemError::EEXIST);
}
if flags.contains(NewRequestFlags::REPLACE) {
return Err(SystemError::ENOENT);
}
if !pushed {
return Err(SystemError::ENOSPC);
}
if let Err(err) = add_local_route(&iface, cidr) {
rollback_added_addr(&iface, cidr);
return Err(err);
}
sync_router_ip_addrs(&iface);
Ok(())
@ -132,6 +133,7 @@ fn del_addr(request_segment: &AddrSegment, netns: Arc<NetNamespace>) -> Result<(
return Err(SystemError::EADDRNOTAVAIL);
}
remove_local_route(&iface, cidr);
sync_router_ip_addrs(&iface);
Ok(())
@ -249,3 +251,64 @@ fn sync_router_ip_addrs(iface: &Arc<dyn Iface>) {
router_ip_addrs.clear();
router_ip_addrs.extend_from_slice(&smol_ip_addrs);
}
fn add_local_route(iface: &Arc<dyn Iface>, cidr: IpCidr) -> Result<(), SystemError> {
let mut pushed = false;
let via_router = cidr.address();
iface.smol_iface().lock().routes_mut().update(|routes| {
let exists = routes
.iter()
.any(|route| is_same_local_route(route, cidr, via_router));
if exists {
pushed = true;
return;
}
pushed = routes
.push(smoltcp::iface::Route {
cidr,
via_router,
preferred_until: None,
expires_at: None,
})
.is_ok();
});
if !pushed {
log::warn!(
"netlink add_addr: route table full while adding local route {} via {}",
cidr,
via_router
);
return Err(SystemError::ENOSPC);
}
Ok(())
}
fn remove_local_route(iface: &Arc<dyn Iface>, cidr: IpCidr) {
let via_router = cidr.address();
iface.smol_iface().lock().routes_mut().update(|routes| {
if let Some(index) = routes
.iter()
.position(|route| is_same_local_route(route, cidr, via_router))
{
routes.remove(index);
}
});
}
fn rollback_added_addr(iface: &Arc<dyn Iface>, cidr: IpCidr) {
iface.smol_iface().lock().update_ip_addrs(|ip_addrs| {
if let Some(index) = ip_addrs.iter().position(|configured| *configured == cidr) {
ip_addrs.remove(index);
}
});
sync_router_ip_addrs(iface);
}
#[inline]
fn is_same_local_route(route: &smoltcp::iface::Route, cidr: IpCidr, via_router: IpAddress) -> bool {
route.cidr == cidr && route.via_router == via_router
}

View File

@ -51,7 +51,7 @@ pub trait Bound {
&self,
writer: &mut [u8],
flags: PMSG,
) -> Result<(usize, Self::Endpoint), SystemError>;
) -> Result<(usize, usize, Self::Endpoint), SystemError>;
fn try_send(&self, buf: &[u8], to: &Self::Endpoint, flags: PMSG) -> Result<usize, SystemError>;
@ -151,7 +151,7 @@ where
&self,
writer: &mut [u8],
flags: PMSG,
) -> Result<(usize, UnboundSocket::Endpoint), SystemError> {
) -> Result<(usize, usize, UnboundSocket::Endpoint), SystemError> {
match self {
Inner::Unbound(_) => Err(SystemError::EAGAIN_OR_EWOULDBLOCK),
Inner::Bound(bound) => bound.try_recv(writer, flags),

View File

@ -0,0 +1 @@
IPv6UDPSockets/IPv6UDPUnboundSocketNetlinkTest.JoinSubnet/1

View File

@ -110,6 +110,8 @@ socket_ipv4_udp_unbound_loopback_test
socket_ipv4_udp_unbound_loopback_nogotsan_test
socket_netlink_test
socket_ip_unbound_netlink_test
socket_ipv4_udp_unbound_loopback_netlink_test
socket_ipv6_udp_unbound_loopback_netlink_test
# 信号处理测试
sigaction_test