387 lines
12 KiB
Rust
387 lines
12 KiB
Rust
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
use core::sync::atomic::{AtomicBool, Ordering};
|
|
|
|
use super::{connected::Connected, connecting::Connecting, init::Init, listen::Listen};
|
|
use crate::{
|
|
events::IoEvents,
|
|
fs::{
|
|
file_handle::FileLike,
|
|
utils::{InodeMode, Metadata, StatusFlags},
|
|
},
|
|
net::socket::{
|
|
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
|
MessageHeader, SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
|
|
},
|
|
prelude::*,
|
|
process::signal::{Pollable, Poller},
|
|
util::{MultiRead, MultiWrite},
|
|
};
|
|
|
|
pub struct VsockStreamSocket {
|
|
status: RwLock<Status>,
|
|
is_nonblocking: AtomicBool,
|
|
}
|
|
|
|
pub enum Status {
|
|
Init(Arc<Init>),
|
|
Listen(Arc<Listen>),
|
|
Connected(Arc<Connected>),
|
|
}
|
|
|
|
impl VsockStreamSocket {
|
|
pub fn new(nonblocking: bool) -> Self {
|
|
let init = Arc::new(Init::new());
|
|
Self {
|
|
status: RwLock::new(Status::Init(init)),
|
|
is_nonblocking: AtomicBool::new(nonblocking),
|
|
}
|
|
}
|
|
|
|
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
|
|
Self {
|
|
status: RwLock::new(Status::Connected(connected)),
|
|
is_nonblocking: AtomicBool::new(false),
|
|
}
|
|
}
|
|
|
|
fn is_nonblocking(&self) -> bool {
|
|
self.is_nonblocking.load(Ordering::Relaxed)
|
|
}
|
|
|
|
fn set_nonblocking(&self, nonblocking: bool) {
|
|
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
|
}
|
|
|
|
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
|
let listen = match &*self.status.read() {
|
|
Status::Listen(listen) => listen.clone(),
|
|
Status::Init(_) | Status::Connected(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
|
|
}
|
|
};
|
|
|
|
let connected = listen.try_accept()?;
|
|
listen.update_io_events();
|
|
|
|
let peer_addr = connected.peer_addr();
|
|
|
|
VSOCK_GLOBAL
|
|
.get()
|
|
.unwrap()
|
|
.insert_connected_socket(connected.id(), connected.clone());
|
|
|
|
VSOCK_GLOBAL
|
|
.get()
|
|
.unwrap()
|
|
.response(&connected.get_info())
|
|
.unwrap();
|
|
|
|
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
|
|
Ok((socket, peer_addr.into()))
|
|
}
|
|
|
|
fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result<usize> {
|
|
let inner = self.status.read();
|
|
match &*inner {
|
|
Status::Connected(connected) => connected.send(reader, flags),
|
|
Status::Init(_) | Status::Listen(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
|
}
|
|
}
|
|
}
|
|
|
|
fn try_recv(
|
|
&self,
|
|
writer: &mut dyn MultiWrite,
|
|
_flags: SendRecvFlags,
|
|
) -> Result<(usize, SocketAddr)> {
|
|
let connected = match &*self.status.read() {
|
|
Status::Connected(connected) => connected.clone(),
|
|
Status::Init(_) | Status::Listen(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
|
}
|
|
};
|
|
|
|
let read_size = connected.try_recv(writer)?;
|
|
connected.update_io_events();
|
|
|
|
let peer_addr = self.peer_addr()?;
|
|
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
|
// connection.
|
|
if connected.should_close() {
|
|
if let Err(e) = self.shutdown(SockShutdownCmd::SHUT_RDWR) {
|
|
debug!("The error is {:?}", e);
|
|
}
|
|
}
|
|
Ok((read_size, peer_addr))
|
|
}
|
|
|
|
fn recv(
|
|
&self,
|
|
writer: &mut dyn MultiWrite,
|
|
flags: SendRecvFlags,
|
|
) -> Result<(usize, SocketAddr)> {
|
|
if self.is_nonblocking() {
|
|
self.try_recv(writer, flags)
|
|
} else {
|
|
self.wait_events(IoEvents::IN, None, || self.try_recv(writer, flags))
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Pollable for VsockStreamSocket {
|
|
fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
|
|
match &*self.status.read() {
|
|
Status::Init(init) => init.poll(mask, poller),
|
|
Status::Listen(listen) => listen.poll(mask, poller),
|
|
Status::Connected(connected) => connected.poll(mask, poller),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FileLike for VsockStreamSocket {
|
|
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
|
|
Some(self)
|
|
}
|
|
|
|
fn read(&self, writer: &mut VmWriter) -> Result<usize> {
|
|
// TODO: Set correct flags
|
|
let read_len = self
|
|
.recv(writer, SendRecvFlags::empty())
|
|
.map(|(len, _)| len)?;
|
|
Ok(read_len)
|
|
}
|
|
|
|
fn write(&self, reader: &mut VmReader) -> Result<usize> {
|
|
// TODO: Set correct flags
|
|
self.send(reader, SendRecvFlags::empty())
|
|
}
|
|
|
|
fn status_flags(&self) -> StatusFlags {
|
|
if self.is_nonblocking() {
|
|
StatusFlags::O_NONBLOCK
|
|
} else {
|
|
StatusFlags::empty()
|
|
}
|
|
}
|
|
|
|
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
|
|
if new_flags.contains(StatusFlags::O_NONBLOCK) {
|
|
self.set_nonblocking(true);
|
|
} else {
|
|
self.set_nonblocking(false);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn metadata(&self) -> Metadata {
|
|
// This is a dummy implementation.
|
|
// TODO: Add "SockFS" and link `VsockStreamSocket` to it.
|
|
Metadata::new_socket(
|
|
0,
|
|
InodeMode::from_bits_truncate(0o140777),
|
|
aster_block::BLOCK_SIZE,
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Socket for VsockStreamSocket {
|
|
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
|
|
let addr = VsockSocketAddr::try_from(sockaddr)?;
|
|
let inner = self.status.read();
|
|
match &*inner {
|
|
Status::Init(init) => init.bind(addr),
|
|
Status::Listen(_) | Status::Connected(_) => {
|
|
return_errno_with_message!(
|
|
Errno::EINVAL,
|
|
"cannot bind a listening or connected socket"
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Since blocking mode is supported, there is no need to store the connecting status.
|
|
// TODO: Refactor when nonblocking mode is supported.
|
|
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
|
|
let init = match &*self.status.read() {
|
|
Status::Init(init) => init.clone(),
|
|
Status::Listen(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is listened");
|
|
}
|
|
Status::Connected(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is connected");
|
|
}
|
|
};
|
|
let remote_addr = VsockSocketAddr::try_from(sockaddr)?;
|
|
let local_addr = init.bound_addr();
|
|
if let Some(addr) = local_addr {
|
|
if addr == remote_addr {
|
|
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
|
|
}
|
|
} else {
|
|
init.bind(VsockSocketAddr::any_addr())?;
|
|
}
|
|
|
|
let connecting = Arc::new(Connecting::new(remote_addr, init.bound_addr().unwrap()));
|
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
|
vsockspace.insert_connecting_socket(connecting.local_addr(), connecting.clone());
|
|
|
|
// Send request
|
|
vsockspace.request(&connecting.info()).unwrap();
|
|
// wait for response from driver
|
|
// TODO: Add timeout
|
|
let mut poller = Poller::new();
|
|
if !connecting
|
|
.poll(IoEvents::IN, Some(&mut poller))
|
|
.contains(IoEvents::IN)
|
|
{
|
|
if let Err(e) = poller.wait(None) {
|
|
vsockspace
|
|
.remove_connecting_socket(&connecting.local_addr())
|
|
.unwrap();
|
|
return Err(e);
|
|
}
|
|
}
|
|
|
|
vsockspace
|
|
.remove_connecting_socket(&connecting.local_addr())
|
|
.unwrap();
|
|
let connected = Arc::new(Connected::from_connecting(connecting));
|
|
*self.status.write() = Status::Connected(connected.clone());
|
|
// move connecting socket map to connected sockmap
|
|
vsockspace.insert_connected_socket(connected.id(), connected);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn listen(&self, backlog: usize) -> Result<()> {
|
|
let init = match &*self.status.read() {
|
|
Status::Init(init) => init.clone(),
|
|
Status::Listen(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is already listened");
|
|
}
|
|
Status::Connected(_) => {
|
|
return_errno_with_message!(Errno::EISCONN, "the socket is already connected");
|
|
}
|
|
};
|
|
let addr = init.bound_addr().ok_or(Error::with_message(
|
|
Errno::EINVAL,
|
|
"the socket is not bound",
|
|
))?;
|
|
let listen = Arc::new(Listen::new(addr, backlog));
|
|
*self.status.write() = Status::Listen(listen.clone());
|
|
|
|
// push listen socket into vsockspace
|
|
VSOCK_GLOBAL
|
|
.get()
|
|
.unwrap()
|
|
.insert_listen_socket(listen.addr(), listen);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
|
if self.is_nonblocking() {
|
|
self.try_accept()
|
|
} else {
|
|
self.wait_events(IoEvents::IN, None, || self.try_accept())
|
|
}
|
|
}
|
|
|
|
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
|
match &*self.status.read() {
|
|
Status::Connected(connected) => connected.shutdown(cmd),
|
|
Status::Init(_) | Status::Listen(_) => {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
|
}
|
|
}
|
|
}
|
|
|
|
fn sendmsg(
|
|
&self,
|
|
reader: &mut dyn MultiRead,
|
|
message_header: MessageHeader,
|
|
flags: SendRecvFlags,
|
|
) -> Result<usize> {
|
|
// TODO: Deal with flags
|
|
debug_assert!(flags.is_all_supported());
|
|
|
|
let MessageHeader {
|
|
control_message, ..
|
|
} = message_header;
|
|
|
|
if control_message.is_some() {
|
|
// TODO: Support sending control message
|
|
warn!("sending control message is not supported");
|
|
}
|
|
|
|
self.send(reader, flags)
|
|
}
|
|
|
|
fn recvmsg(
|
|
&self,
|
|
writer: &mut dyn MultiWrite,
|
|
flags: SendRecvFlags,
|
|
) -> Result<(usize, MessageHeader)> {
|
|
// TODO: Deal with flags
|
|
debug_assert!(flags.is_all_supported());
|
|
|
|
let (received_bytes, _) = self.recv(writer, flags)?;
|
|
|
|
// TODO: Receive control message
|
|
|
|
let messsge_header = MessageHeader::new(None, None);
|
|
|
|
Ok((received_bytes, messsge_header))
|
|
}
|
|
|
|
fn addr(&self) -> Result<SocketAddr> {
|
|
let inner = self.status.read();
|
|
let addr = match &*inner {
|
|
Status::Init(init) => init.bound_addr(),
|
|
Status::Listen(listen) => Some(listen.addr()),
|
|
Status::Connected(connected) => Some(connected.local_addr()),
|
|
};
|
|
addr.map(Into::<SocketAddr>::into)
|
|
.ok_or(Error::with_message(
|
|
Errno::EINVAL,
|
|
"The socket does not bind to addr",
|
|
))
|
|
}
|
|
|
|
fn peer_addr(&self) -> Result<SocketAddr> {
|
|
let inner = self.status.read();
|
|
if let Status::Connected(connected) = &*inner {
|
|
Ok(connected.peer_addr().into())
|
|
} else {
|
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for VsockStreamSocket {
|
|
fn drop(&mut self) {
|
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
|
let inner = self.status.read();
|
|
match &*inner {
|
|
Status::Init(init) => {
|
|
if let Some(addr) = init.bound_addr() {
|
|
vsockspace.recycle_port(&addr.port);
|
|
}
|
|
}
|
|
Status::Listen(listen) => {
|
|
vsockspace.recycle_port(&listen.addr().port);
|
|
vsockspace.remove_listen_socket(&listen.addr());
|
|
}
|
|
Status::Connected(connected) => {
|
|
if !connected.is_closed() {
|
|
vsockspace.reset(&connected.get_info()).unwrap();
|
|
}
|
|
vsockspace.remove_connected_socket(&connected.id());
|
|
vsockspace.recycle_port(&connected.local_addr().port);
|
|
}
|
|
}
|
|
}
|
|
}
|