diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index 05a425547..c7f296fec 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -15,7 +15,7 @@ use crate::{ SockShutdownCmd, SocketAddr, }, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::{Pauser, Pollee, Poller}, }; pub(super) struct Listener { @@ -105,6 +105,8 @@ impl Listener { impl Drop for Listener { fn drop(&mut self) { + self.backlog.shutdown(); + unregister_backlog(&self.backlog.addr().to_key()) } } @@ -147,37 +149,17 @@ impl BacklogTable { self.backlog_sockets.read().get(addr).cloned() } - fn push_incoming( - &self, - server_key: &UnixSocketAddrKey, - init: Init, - ) -> core::result::Result { - let backlog = match self.get_backlog(server_key) { - Some(backlog) => backlog, - None => { - return Err(( - Error::with_message( - Errno::ECONNREFUSED, - "no socket is listening at the remote address", - ), - init, - )) - } - }; - - backlog.push_incoming(init) - } - fn remove_backlog(&self, addr_key: &UnixSocketAddrKey) { self.backlog_sockets.write().remove(addr_key); } } -struct Backlog { +pub(super) struct Backlog { addr: UnixSocketAddrBound, pollee: Pollee, backlog: AtomicUsize, incoming_conns: Mutex>>, + pauser: Arc, } impl Backlog { @@ -193,6 +175,7 @@ impl Backlog { pollee, backlog: AtomicUsize::new(backlog), incoming_conns: Mutex::new(incoming_sockets), + pauser: Pauser::new(), } } @@ -200,7 +183,73 @@ impl Backlog { &self.addr } - fn push_incoming(&self, init: Init) -> core::result::Result { + fn pop_incoming(&self) -> Result { + let mut locked_incoming_conns = self.incoming_conns.lock(); + + let Some(incoming_conns) = &mut *locked_incoming_conns else { + return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading"); + }; + + let conn = incoming_conns.pop_front(); + if incoming_conns.is_empty() { + self.pollee.del_events(IoEvents::IN); + } + + drop(locked_incoming_conns); + + if conn.is_some() { + self.pauser.resume_one(); + } + + conn.ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection is available")) + } + + fn set_backlog(&self, backlog: usize) { + let old_backlog = self.backlog.swap(backlog, Ordering::Relaxed); + + if old_backlog < backlog { + self.pauser.resume_all(); + } + } + + fn shutdown(&self) { + let mut incoming_conns = self.incoming_conns.lock(); + + *incoming_conns = None; + self.pollee.add_events(IoEvents::HUP); + self.pollee.del_events(IoEvents::IN); + + drop(incoming_conns); + + self.pauser.resume_all(); + } + + fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { + self.pollee.poll(mask, poller) + } + + fn register_observer( + &self, + observer: Weak>, + mask: IoEvents, + ) -> Result<()> { + self.pollee.register_observer(observer, mask); + Ok(()) + } + + fn unregister_observer( + &self, + observer: &Weak>, + ) -> Option>> { + self.pollee.unregister_observer(observer) + } +} + +impl Backlog { + pub(super) fn push_incoming( + &self, + init: Init, + ) -> core::result::Result { let mut locked_incoming_conns = self.incoming_conns.lock(); let Some(incoming_conns) = &mut *locked_incoming_conns else { @@ -231,51 +280,14 @@ impl Backlog { Ok(client_conn) } - fn pop_incoming(&self) -> Result { - let mut locked_incoming_conns = self.incoming_conns.lock(); - - let Some(incoming_conns) = &mut *locked_incoming_conns else { - return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading"); - }; - - let conn = incoming_conns.pop_front(); - if incoming_conns.is_empty() { - self.pollee.del_events(IoEvents::IN); - } - - conn.ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection is available")) - } - - fn set_backlog(&self, backlog: usize) { - self.backlog.store(backlog, Ordering::Relaxed); - } - - fn shutdown(&self) { - let mut incoming_conns = self.incoming_conns.lock(); - - *incoming_conns = None; - self.pollee.add_events(IoEvents::HUP); - self.pollee.del_events(IoEvents::IN); - } - - fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.pollee.poll(mask, poller) - } - - fn register_observer( - &self, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - self.pollee.register_observer(observer, mask); - Ok(()) - } - - fn unregister_observer( - &self, - observer: &Weak>, - ) -> Option>> { - self.pollee.unregister_observer(observer) + pub(super) fn pause_until(&self, mut cond: F) -> Result<()> + where + F: FnMut() -> Result<()>, + { + self.pauser.pause_until(|| match cond() { + Err(err) if err.error() == Errno::EAGAIN => None, + result => Some(result), + })? } } @@ -283,9 +295,11 @@ fn unregister_backlog(addr: &UnixSocketAddrKey) { BACKLOG_TABLE.remove_backlog(addr); } -pub(super) fn push_incoming( - server_key: &UnixSocketAddrKey, - init: Init, -) -> core::result::Result { - BACKLOG_TABLE.push_incoming(server_key, init) +pub(super) fn get_backlog(server_key: &UnixSocketAddrKey) -> Result> { + BACKLOG_TABLE.get_backlog(server_key).ok_or_else(|| { + Error::with_message( + Errno::ECONNREFUSED, + "no socket is listening at the remote address", + ) + }) } diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index 34aeddf13..c47d16891 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -8,13 +8,13 @@ use takeable::Takeable; use super::{ connected::Connected, init::Init, - listener::{push_incoming, Listener}, + listener::{get_backlog, Backlog, Listener}, }; use crate::{ events::{IoEvents, Observer}, fs::{file_handle::FileLike, utils::StatusFlags}, net::socket::{ - unix::{addr::UnixSocketAddrKey, UnixSocketAddr}, + unix::UnixSocketAddr, util::{ copy_message_from_user, copy_message_to_user, create_message_buffer, send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader, @@ -23,7 +23,6 @@ use crate::{ }, prelude::*, process::signal::{Pollable, Poller}, - thread::Thread, util::IoVec, }; @@ -101,7 +100,7 @@ impl UnixStreamSocket { } } - fn try_connect(&self, remote_addr: &UnixSocketAddrKey) -> Result<()> { + fn try_connect(&self, backlog: &Arc) -> Result<()> { let mut state = self.state.write(); state.borrow_result(|owned_state| { @@ -127,7 +126,7 @@ impl UnixStreamSocket { } }; - let connected = match push_incoming(remote_addr, init) { + let connected = match backlog.push_incoming(init) { Ok(connected) => connected, Err((err, init)) => return (State::Init(init), Err(err)), }; @@ -239,29 +238,23 @@ impl Socket for UnixStreamSocket { fn connect(&self, socket_addr: SocketAddr) -> Result<()> { let remote_addr = UnixSocketAddr::try_from(socket_addr)?.connect()?; + let backlog = get_backlog(&remote_addr)?; - // Note that the Linux kernel implementation locks the remote socket and checks to see if - // it is listening first. This is different from our implementation, which locks the local - // socket and checks the state of the local socket first. - // - // The difference may result in different error codes, but it's doubtful that this will - // ever lead to real problems. - // - // See also . - - loop { - let res = self.try_connect(&remote_addr); - - if !res.is_err_and(|err| err.error() == Errno::EAGAIN) { - return res; - } - - // FIXME: Add `Pauser` in `Backlog` and use it to avoid this `Thread::yield_now`. - Thread::yield_now(); + if self.is_nonblocking() { + self.try_connect(&backlog) + } else { + backlog.pause_until(|| self.try_connect(&backlog)) } } fn listen(&self, backlog: usize) -> Result<()> { + const SOMAXCONN: usize = 4096; + + // Linux allows a maximum of `backlog + 1` sockets in the backlog queue. Although this + // seems to be mostly an implementation detail, we follow the exact Linux behavior to + // ensure that our regression tests pass with the Linux kernel. + let backlog = backlog.saturating_add(1).min(SOMAXCONN); + let mut state = self.state.write(); state.borrow_result(|owned_state| { diff --git a/test/apps/network/unix_err.c b/test/apps/network/unix_err.c index c25b1c978..80035b374 100644 --- a/test/apps/network/unix_err.c +++ b/test/apps/network/unix_err.c @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -339,6 +340,67 @@ FN_TEST(recv) } END_TEST() +FN_TEST(blocking_connect) +{ + int i; + int sk, sks[4]; + int pid; + + // Setup + + sk = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0)); + TEST_SUCC( + bind(sk, (struct sockaddr *)&UNIX_ADDR("\0"), PATH_OFFSET + 1)); + TEST_SUCC(listen(sk, 2)); + + for (i = 0; i < 3; ++i) { + sks[i] = TEST_SUCC( + socket(PF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0)); + TEST_SUCC(connect(sks[i], (struct sockaddr *)&UNIX_ADDR("\0"), + PATH_OFFSET + 1)); + } + +#define MAKE_TEST(child, parent, errno) \ + sks[i] = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0)); \ + TEST_ERRNO(connect(sks[i], (struct sockaddr *)&UNIX_ADDR("\0"), \ + PATH_OFFSET + 1), \ + EAGAIN); \ + TEST_SUCC(close(sks[i])); \ + \ + pid = TEST_SUCC(fork()); \ + if (pid == 0) { \ + usleep(300 * 1000); \ + CHECK(child); \ + exit(0); \ + } \ + TEST_SUCC(parent); \ + \ + sks[i] = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0)); \ + TEST_ERRNO(connect(sks[i], (struct sockaddr *)&UNIX_ADDR("\0"), \ + PATH_OFFSET + 1), \ + errno); \ + \ + TEST_SUCC(close(sks[i])); \ + TEST_SUCC(wait(NULL)); + + // Test 1: Accepting a connection resumes the blocked connection request + MAKE_TEST(accept(sk, NULL, NULL), 0, 0); + + // Test 2: Resetting the backlog resumes the blocked connection request + MAKE_TEST(listen(sk, 3), 0, 0); + + // Test 3: Closing the listener resumes the blocked connection request + MAKE_TEST(close(sk), close(sk), ECONNREFUSED); + +#undef MAKE_TEST + + // Clean up + + for (i = 0; i < 3; ++i) + TEST_SUCC(close(sks[i])); +} +END_TEST() + FN_TEST(ns_path) { int fd;