asterinas/kernel/aster-nix/src/fs/utils/channel.rs

567 lines
15 KiB
Rust

// SPDX-License-Identifier: MPL-2.0
#![allow(dead_code)]
use core::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use aster_rights::{Read, ReadOp, TRights, Write, WriteOp};
use aster_rights_proc::require;
use ringbuf::{HeapConsumer as HeapRbConsumer, HeapProducer as HeapRbProducer, HeapRb};
use super::StatusFlags;
use crate::{
events::{IoEvents, Observer},
prelude::*,
process::signal::{Pollee, Poller},
};
/// A unidirectional communication channel, intended to implement IPC, e.g., pipe,
/// unix domain sockets, etc.
///
/// Items written through the [`Producer`] end are buffered in a bounded ring buffer
/// that both ends share, and are read back through the [`Consumer`] end.
pub struct Channel<T> {
    /// The writing end.
    producer: Producer<T>,
    /// The reading end.
    consumer: Consumer<T>,
}

impl<T> Channel<T> {
    /// Creates a channel that can buffer `capacity` items, with no status flags set.
    ///
    /// # Errors
    ///
    /// Returns `EINVAL` if `capacity` is zero.
    pub fn with_capacity(capacity: usize) -> Result<Self> {
        Self::with_capacity_and_flags(capacity, StatusFlags::empty())
    }

    /// Creates a channel that can buffer `capacity` items; both ends start with `flags`.
    ///
    /// # Errors
    ///
    /// Returns `EINVAL` if `flags` contains anything other than `O_NONBLOCK`
    /// (`O_DIRECT` is recognized but rejected), or if `capacity` is zero.
    pub fn with_capacity_and_flags(capacity: usize, flags: StatusFlags) -> Result<Self> {
        let shared = Arc::new(Common::with_capacity_and_flags(capacity, flags)?);
        Ok(Self {
            producer: Producer(EndPoint::new(Arc::clone(&shared), WriteOp::new())),
            consumer: Consumer(EndPoint::new(shared, ReadOp::new())),
        })
    }

    /// Consumes the channel and returns its writing and reading ends.
    pub fn split(self) -> (Producer<T>, Consumer<T>) {
        (self.producer, self.consumer)
    }

    /// Borrows the writing end.
    pub fn producer(&self) -> &Producer<T> {
        &self.producer
    }

    /// Borrows the reading end.
    pub fn consumer(&self) -> &Consumer<T> {
        &self.consumer
    }

    /// Returns the number of items the underlying ring buffer can hold.
    pub fn capacity(&self) -> usize {
        self.producer.0.common.capacity()
    }
}
/// The writing end of a [`Channel`]; its endpoint carries only write rights (`WriteOp`).
pub struct Producer<T>(EndPoint<T, WriteOp>);
/// The reading end of a [`Channel`]; its endpoint carries only read rights (`ReadOp`).
pub struct Consumer<T>(EndPoint<T, ReadOp>);
/// Expands to the public methods that `Producer` and `Consumer` have in common.
///
/// The `impl` block invoking this macro must define `this_end()` and `peer_end()`,
/// returning the `EndPointInner` of this end and of the opposite end respectively.
macro_rules! impl_common_methods_for_channel {
    () => {
        /// Marks this end as shut down.
        pub fn shutdown(&self) {
            self.this_end().shutdown()
        }

        /// Returns whether this end has been shut down.
        pub fn is_shutdown(&self) -> bool {
            self.this_end().is_shutdown()
        }

        /// Returns whether the opposite end has been shut down.
        pub fn is_peer_shutdown(&self) -> bool {
            self.peer_end().is_shutdown()
        }

        /// Returns the status flags currently set on this end.
        pub fn status_flags(&self) -> StatusFlags {
            self.this_end().status_flags()
        }

        /// Replaces the status flags of this end.
        ///
        /// Fails with `EINVAL`, leaving the flags unchanged, if `new_flags` contains
        /// anything other than `O_NONBLOCK`.
        pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
            self.this_end().set_status_flags(new_flags)
        }

        /// Returns whether `O_NONBLOCK` is set on this end.
        pub fn is_nonblocking(&self) -> bool {
            self.status_flags().contains(StatusFlags::O_NONBLOCK)
        }

        /// Polls this end's pollee for the events in `mask`, passing `poller` through
        /// to `Pollee::poll`.
        pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
            self.this_end().pollee.poll(mask, poller)
        }

        /// Registers `observer` for the events in `mask` on this end; always returns `Ok(())`.
        pub fn register_observer(
            &self,
            observer: Weak<dyn Observer<IoEvents>>,
            mask: IoEvents,
        ) -> Result<()> {
            self.this_end().pollee.register_observer(observer, mask);
            Ok(())
        }

        /// Unregisters `observer` from this end, forwarding the result of
        /// `Pollee::unregister_observer` (presumably the removed observer, if any).
        pub fn unregister_observer(
            &self,
            observer: &Weak<dyn Observer<IoEvents>>,
        ) -> Option<Weak<dyn Observer<IoEvents>>> {
            self.this_end().pollee.unregister_observer(observer)
        }
    };
}
impl<T> Producer<T> {
    /// Returns the shared state of this (writing) end.
    fn this_end(&self) -> &EndPointInner<HeapRbProducer<T>> {
        &self.0.common.producer
    }

    /// Returns the shared state of the peer (reading) end.
    fn peer_end(&self) -> &EndPointInner<HeapRbConsumer<T>> {
        &self.0.common.consumer
    }

    /// Refreshes the I/O events of both ends after this end touched the ring buffer.
    ///
    /// Clears `OUT` on this end when the buffer is full and raises `IN` on the peer
    /// when the buffer is non-empty.
    fn update_pollee(&self) {
        let this_end = self.this_end();
        let peer_end = self.peer_end();

        // Update the event of pollee in a critical region so that pollee
        // always reflects the _true_ state of the underlying ring buffer
        // regardless of any race conditions.
        //
        // The guard must be bound to a named local: a bare `lock_event();` statement
        // drops the guard at the end of that statement and leaves the region unprotected.
        let _event_guard = self.0.common.lock_event();

        let rb = this_end.rb();
        if rb.is_full() {
            this_end.pollee.del_events(IoEvents::OUT);
        }
        if !rb.is_empty() {
            peer_end.pollee.add_events(IoEvents::IN);
        }
    }

    impl_common_methods_for_channel!();
}
impl<T: Copy> Producer<T> {
    /// Writes items from `buf` into the channel.
    ///
    /// Blocks until at least one item fits unless this end is non-blocking.
    /// Returns the number of items written; `Ok(0)` only when `buf` is empty.
    ///
    /// # Errors
    ///
    /// - `EPIPE` if either end has been shut down.
    /// - `EAGAIN` if this end is non-blocking and the buffer is full.
    /// - Any error returned by `Poller::wait` while blocking.
    pub fn write(&self, buf: &[T]) -> Result<usize> {
        let nonblocking = self.is_nonblocking();

        // Fast path: a single attempt without setting up a poller.
        let first_attempt = self.try_write(buf);
        if should_io_return(&first_attempt, nonblocking) {
            return first_attempt;
        }

        // Slow path: keep retrying, sleeping whenever the channel is not writable.
        let poller = Poller::new();
        loop {
            let attempt = self.try_write(buf);
            if should_io_return(&attempt, nonblocking) {
                return attempt;
            }
            if self.poll(IoEvents::OUT, Some(&poller)).is_empty() {
                // FIXME: should channel deal with timeout?
                poller.wait()?;
            }
        }
    }

    /// Makes a single, non-blocking attempt to write `buf`.
    fn try_write(&self, buf: &[T]) -> Result<usize> {
        if self.is_shutdown() || self.is_peer_shutdown() {
            return_errno!(Errno::EPIPE);
        }
        if buf.is_empty() {
            return Ok(0);
        }

        let written = self.0.write(buf);
        self.update_pollee();

        if written == 0 {
            return_errno_with_message!(Errno::EAGAIN, "try write later");
        }
        Ok(written)
    }
}
impl<T> Producer<T> {
/// Pushes an item into the producer.
///
/// On failure, this method returns `Err` containing
/// the item fails to push.
pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let mut res = self.try_push(item);
if should_io_return(&res, is_nonblocking) {
return res;
}
// Slow path
let mask = IoEvents::OUT;
let poller = Poller::new();
loop {
let (_, item) = res.unwrap_err();
res = self.try_push(item);
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel deal with timeout?
if let Err(err) = poller.wait() {
return Err((err, res.unwrap_err().1));
}
}
}
}
fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> {
if self.is_shutdown() || self.is_peer_shutdown() {
let err = Error::with_message(Errno::EPIPE, "the pipe is shutdown");
return Err((err, item));
}
self.0.push(item).map_err(|item| {
let err = Error::with_message(Errno::EAGAIN, "try push again");
(err, item)
})?;
self.update_pollee();
Ok(())
}
}
impl<T> Drop for Producer<T> {
    /// Shuts down the writing end and raises `HUP` on the reading end.
    fn drop(&mut self) {
        self.shutdown();

        // Publish the event under the event lock so it cannot interleave with a
        // concurrent `update_pollee`. The guard must be bound to a named local;
        // a bare `lock_event();` would release the lock immediately.
        let _event_guard = self.0.common.lock_event();

        // When reading from a channel such as a pipe or a stream socket,
        // POLLHUP merely indicates that the peer closed its end of the channel.
        self.peer_end().pollee.add_events(IoEvents::HUP);
    }
}
impl<T> Consumer<T> {
    /// Returns the shared state of this (reading) end.
    fn this_end(&self) -> &EndPointInner<HeapRbConsumer<T>> {
        &self.0.common.consumer
    }

    /// Returns the shared state of the peer (writing) end.
    fn peer_end(&self) -> &EndPointInner<HeapRbProducer<T>> {
        &self.0.common.producer
    }

    /// Refreshes the I/O events of both ends after this end touched the ring buffer.
    ///
    /// Clears `IN` on this end when the buffer is empty and raises `OUT` on the peer
    /// when the buffer is not full.
    fn update_pollee(&self) {
        let this_end = self.this_end();
        let peer_end = self.peer_end();

        // Update the event of pollee in a critical region so that pollee
        // always reflects the _true_ state of the underlying ring buffer
        // regardless of any race conditions.
        //
        // The guard must be bound to a named local: a bare `lock_event();` statement
        // drops the guard at the end of that statement and leaves the region unprotected.
        let _event_guard = self.0.common.lock_event();

        let rb = this_end.rb();
        if rb.is_empty() {
            this_end.pollee.del_events(IoEvents::IN);
        }
        if !rb.is_full() {
            peer_end.pollee.add_events(IoEvents::OUT);
        }
    }

    impl_common_methods_for_channel!();
}
impl<T: Copy> Consumer<T> {
    /// Reads items from the channel into `buf`.
    ///
    /// Blocks until at least one item is available unless this end is non-blocking.
    /// Returns the number of items read. `Ok(0)` means either `buf` is empty, or the
    /// producer has shut down and nothing is left in the buffer (end of file).
    ///
    /// # Errors
    ///
    /// - `EPIPE` if this end has been shut down.
    /// - `EAGAIN` if this end is non-blocking and no item is available.
    /// - Any error returned by `Poller::wait` while blocking.
    pub fn read(&self, buf: &mut [T]) -> Result<usize> {
        let is_nonblocking = self.is_nonblocking();
        // Fast path
        let res = self.try_read(buf);
        if should_io_return(&res, is_nonblocking) {
            return res;
        }
        // Slow path
        let mask = IoEvents::IN;
        let poller = Poller::new();
        loop {
            let res = self.try_read(buf);
            if should_io_return(&res, is_nonblocking) {
                return res;
            }
            let events = self.poll(mask, Some(&poller));
            if events.is_empty() {
                // FIXME: should channel have timeout?
                poller.wait()?;
            }
        }
    }

    /// Makes a single, non-blocking attempt to read into `buf`.
    fn try_read(&self, buf: &mut [T]) -> Result<usize> {
        if self.is_shutdown() {
            return_errno!(Errno::EPIPE);
        }
        if buf.is_empty() {
            return Ok(0);
        }

        // Sample the peer's shutdown state *before* reading. The flag is stored with
        // `Release` and loaded with `Acquire`, so any write sequenced before the
        // producer's shutdown is visible to the read below. An empty read after
        // observing the shutdown is therefore a genuine end of file. Checking only
        // after the read would race with a producer that writes and then shuts down
        // in between, reporting EOF while data is still buffered.
        let is_peer_shutdown = self.is_peer_shutdown();

        let read_len = self.0.read(buf);
        self.update_pollee();

        if read_len > 0 {
            Ok(read_len)
        } else if is_peer_shutdown {
            Ok(0)
        } else {
            return_errno_with_message!(Errno::EAGAIN, "try read later");
        }
    }
}
impl<T> Consumer<T> {
    /// Pops an item from the consumer.
    ///
    /// Blocks until an item is available unless this end is non-blocking.
    ///
    /// # Errors
    ///
    /// - `EPIPE` if this end has been shut down, or if the producer has shut down
    ///   and the buffer is empty.
    /// - `EAGAIN` if this end is non-blocking and no item is available.
    /// - Any error returned by `Poller::wait` while blocking.
    pub fn pop(&self) -> Result<T> {
        let is_nonblocking = self.is_nonblocking();
        // Fast path
        let res = self.try_pop();
        if should_io_return(&res, is_nonblocking) {
            return res;
        }
        // Slow path
        let mask = IoEvents::IN;
        let poller = Poller::new();
        loop {
            let res = self.try_pop();
            if should_io_return(&res, is_nonblocking) {
                return res;
            }
            let events = self.poll(mask, Some(&poller));
            if events.is_empty() {
                // FIXME: should channel have timeout?
                poller.wait()?;
            }
        }
    }

    /// Makes a single, non-blocking attempt to pop an item.
    fn try_pop(&self) -> Result<T> {
        if self.is_shutdown() {
            return_errno_with_message!(Errno::EPIPE, "this end is shut down");
        }

        // Sample the peer's shutdown state *before* popping (Release store / Acquire
        // load). Items pushed before the producer shut down are then guaranteed to be
        // visible to the pop below. Checking only after an empty pop would race with
        // a producer that pushes and then shuts down in between, reporting `EPIPE`
        // while an item is still queued.
        let is_peer_shutdown = self.is_peer_shutdown();

        let item = self.0.pop();
        self.update_pollee();

        if let Some(item) = item {
            return Ok(item);
        }
        if is_peer_shutdown {
            return_errno_with_message!(Errno::EPIPE, "remote end is shut down");
        }
        return_errno_with_message!(Errno::EAGAIN, "try pop again")
    }
}
impl<T> Drop for Consumer<T> {
    /// Shuts down the reading end and raises `ERR` on the writing end.
    fn drop(&mut self) {
        self.shutdown();

        // Publish the event under the event lock so it cannot interleave with a
        // concurrent `update_pollee`. The guard must be bound to a named local;
        // a bare `lock_event();` would release the lock immediately.
        let _event_guard = self.0.common.lock_event();

        // POLLERR is also set for a file descriptor referring to the write end of a pipe
        // when the read end has been closed.
        self.peer_end().pollee.add_events(IoEvents::ERR);
    }
}
/// A handle to one end of a channel.
///
/// It pairs the shared state with static access rights (`ReadOp` or `WriteOp`).
/// The `#[require]` bounds on the methods below restrict which ring-buffer
/// operations each end may use.
struct EndPoint<T, R: TRights> {
    /// State shared with the opposite end.
    common: Arc<Common<T>>,
    /// Static access rights of this end.
    rights: R,
}

impl<T, R: TRights> EndPoint<T, R> {
    /// Wraps `common` together with the access rights `rights`.
    pub fn new(common: Arc<Common<T>>, rights: R) -> Self {
        Self { common, rights }
    }
}

impl<T: Copy, R: TRights> EndPoint<T, R> {
    /// Pops items from the ring buffer into `buf`, returning how many were removed.
    #[require(R > Read)]
    pub fn read(&self, buf: &mut [T]) -> usize {
        self.common.consumer.rb().pop_slice(buf)
    }

    /// Pushes items from `buf` into the ring buffer, returning how many were accepted.
    #[require(R > Write)]
    pub fn write(&self, buf: &[T]) -> usize {
        self.common.producer.rb().push_slice(buf)
    }
}

impl<T, R: TRights> EndPoint<T, R> {
    /// Pushes an item into the endpoint.
    ///
    /// If the ring buffer rejects the item, it is handed back inside `Err`.
    #[require(R > Write)]
    pub fn push(&self, item: T) -> core::result::Result<(), T> {
        self.common.producer.rb().push(item)
    }

    /// Pops an item from the endpoint, or returns `None` if the ring buffer is empty.
    #[require(R > Read)]
    pub fn pop(&self) -> Option<T> {
        self.common.consumer.rb().pop()
    }
}
/// State shared by both ends of a channel.
struct Common<T> {
    /// State of the writing end, owning the producer half of the ring buffer.
    producer: EndPointInner<HeapRbProducer<T>>,
    /// State of the reading end, owning the consumer half of the ring buffer.
    consumer: EndPointInner<HeapRbConsumer<T>>,
    /// Serializes pollee event updates made by either end.
    event_lock: Mutex<()>,
}

impl<T> Common<T> {
    /// Validates `flags` and `capacity`, then allocates the ring buffer and splits it
    /// between the two ends.
    ///
    /// The producer starts out reporting `OUT` and the consumer starts with no events.
    fn with_capacity_and_flags(capacity: usize, flags: StatusFlags) -> Result<Self> {
        check_status_flags(flags)?;
        if capacity == 0 {
            return_errno_with_message!(Errno::EINVAL, "capacity cannot be zero");
        }

        let ring: HeapRb<T> = HeapRb::new(capacity);
        let (rb_producer, rb_consumer) = ring.split();
        Ok(Self {
            producer: EndPointInner::new(rb_producer, IoEvents::OUT, flags),
            consumer: EndPointInner::new(rb_consumer, IoEvents::empty(), flags),
            event_lock: Mutex::new(()),
        })
    }

    /// Acquires the event lock.
    ///
    /// The lock is held only while the returned guard is alive, so callers must bind
    /// the guard to a named local for the duration of the critical section.
    pub fn lock_event(&self) -> MutexGuard<()> {
        self.event_lock.lock()
    }

    /// Returns the capacity of the ring buffer (queried through the producer half).
    pub fn capacity(&self) -> usize {
        self.producer.rb().capacity()
    }
}
/// Per-end shared state.
///
/// Holds one half of the ring buffer plus this end's pollee, shutdown flag and
/// status flags.
struct EndPointInner<T> {
    /// This end's half of the ring buffer.
    rb: Mutex<T>,
    /// I/O events reported for this end.
    pollee: Pollee,
    /// Set once this end is shut down; nothing in this file ever clears it.
    is_shutdown: AtomicBool,
    /// Raw bits of this end's `StatusFlags`.
    status_flags: AtomicU32,
}

impl<T> EndPointInner<T> {
    /// Creates the state for one end.
    ///
    /// `init_events` are the initial pollee events and `status_flags` the initial flags.
    pub fn new(rb: T, init_events: IoEvents, status_flags: StatusFlags) -> Self {
        let initial_bits = status_flags.bits();
        Self {
            rb: Mutex::new(rb),
            pollee: Pollee::new(init_events),
            is_shutdown: AtomicBool::new(false),
            status_flags: AtomicU32::new(initial_bits),
        }
    }

    /// Locks and returns this end's half of the ring buffer.
    pub fn rb(&self) -> MutexGuard<T> {
        self.rb.lock()
    }

    /// Returns whether this end has been shut down (`Acquire` load).
    pub fn is_shutdown(&self) -> bool {
        self.is_shutdown.load(Ordering::Acquire)
    }

    /// Marks this end as shut down (`Release` store).
    pub fn shutdown(&self) {
        self.is_shutdown.store(true, Ordering::Release)
    }

    /// Returns the current status flags.
    pub fn status_flags(&self) -> StatusFlags {
        let raw = self.status_flags.load(Ordering::Relaxed);
        // `set_status_flags` validates before storing, and so does
        // `Common::with_capacity_and_flags` before calling `new`.
        StatusFlags::from_bits(raw).unwrap()
    }

    /// Replaces the status flags after validating them with `check_status_flags`.
    pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
        check_status_flags(new_flags)?;
        let raw = new_flags.bits();
        self.status_flags.store(raw, Ordering::Relaxed);
        Ok(())
    }
}
/// Validates status flags for a channel end.
///
/// Only `O_NONBLOCK` is accepted. Any unrecognized flag fails with "invalid flags".
/// `O_DIRECT` is recognized, but fails with its own message because it is not supported.
fn check_status_flags(flags: StatusFlags) -> Result<()> {
    let recognized = StatusFlags::O_NONBLOCK | StatusFlags::O_DIRECT;
    if !recognized.contains(flags) {
        return_errno_with_message!(Errno::EINVAL, "invalid flags");
    } else if flags.contains(StatusFlags::O_DIRECT) {
        return_errno_with_message!(Errno::EINVAL, "O_DIRECT is not supported");
    }
    Ok(())
}
fn should_io_return<T, E: AsRef<Error>>(
res: &core::result::Result<T, E>,
is_nonblocking: bool,
) -> bool {
if is_nonblocking {
return true;
}
match res {
Ok(_) => true,
Err(e) if e.as_ref().error() == Errno::EAGAIN => false,
Err(_) => true,
}
}
/// Exposes the error half of an `(error, item)` pair.
///
/// `push`/`try_push` return such pairs, and this impl lets `should_io_return`
/// inspect them like a plain `Error`.
impl<T> AsRef<Error> for (Error, T) {
    fn as_ref(&self) -> &Error {
        let (error, _) = self;
        error
    }
}
#[cfg(ktest)]
mod test {
    use alloc::sync::Arc;

    use crate::fs::utils::Channel;

    /// Items that are not `Copy` must round-trip through `push`/`pop` unchanged.
    #[ktest]
    fn test_non_copy() {
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct NonCopy(Arc<usize>);

        let (producer, consumer) = Channel::with_capacity(16).unwrap().split();
        let data = NonCopy(Arc::new(99));
        let expected_data = data.clone();

        for _ in 0..3 {
            producer.push(data.clone()).unwrap();
        }
        for _ in 0..3 {
            assert_eq!(consumer.pop().unwrap(), expected_data);
        }
    }
}