Report `ENOBUFS` if netlink messages overrun

This commit is contained in:
Ruihan Li 2025-09-03 22:59:35 +08:00 committed by Tate, Hongliang Tian
parent b57c94d05d
commit c289f96d23
9 changed files with 210 additions and 97 deletions

View File

@ -50,6 +50,9 @@ impl<Message: 'static> BoundNetlink<Message> {
if !receive_queue.is_empty() {
events |= IoEvents::IN;
}
if receive_queue.has_errors() {
events |= IoEvents::ERR;
}
events
}

View File

@ -69,24 +69,15 @@ impl datagram_common::Bound for BoundNetlinkUevent {
let mut receive_queue = self.receive_queue.lock();
let Some(response) = receive_queue.peek() else {
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty");
};
receive_queue.dequeue_if(|response, response_len| {
let len = response_len.min(writer.sum_lens());
response.write_to(writer)?;
let len = {
let max_len = writer.sum_lens();
response.total_len().min(max_len)
};
let remote = *response.src_addr();
response.write_to(writer)?;
let remote = *response.src_addr();
if !flags.contains(SendRecvFlags::MSG_PEEK) {
receive_queue.dequeue().unwrap();
}
Ok((len, remote))
let should_dequeue = !flags.contains(SendRecvFlags::MSG_PEEK);
Ok((should_dequeue, (len, remote)))
})
}
fn check_io_events(&self) -> IoEvents {

View File

@ -5,7 +5,9 @@
use uevent::Uevent;
use crate::{
net::socket::netlink::{table::MulticastMessage, NetlinkSocketAddr},
net::socket::netlink::{
receiver::QueueableMessage, table::MulticastMessage, NetlinkSocketAddr,
},
prelude::*,
util::MultiWrite,
};
@ -39,11 +41,6 @@ impl UeventMessage {
&self.src_addr
}
/// Returns the total length of the uevent.
pub(super) fn total_len(&self) -> usize {
self.uevent.len()
}
/// Writes the uevent to the given `writer`.
pub(super) fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> {
let _nbytes = writer.write(&mut VmReader::from(self.uevent.as_bytes()))?;
@ -53,4 +50,10 @@ impl UeventMessage {
}
}
/// Lets uevent messages be stored in a netlink message queue.
impl QueueableMessage for UeventMessage {
/// Returns the total length of the uevent, as reported by `self.uevent.len()`.
fn total_len(&self) -> usize {
self.uevent.len()
}
}
impl MulticastMessage for UeventMessage {}

View File

@ -16,6 +16,7 @@ pub(super) use segment::{
CSegmentType, SegmentBody,
};
use super::receiver::QueueableMessage;
use crate::{
prelude::*,
util::{MultiRead, MultiWrite},
@ -26,11 +27,11 @@ use crate::{
/// A netlink message can be transmitted to and from user space using a single send/receive syscall.
/// It consists of one or more [`ProtocolSegment`]s.
#[derive(Debug)]
pub struct Message<T: ProtocolSegment> {
pub struct Message<T> {
segments: Vec<T>,
}
impl<T: ProtocolSegment> Message<T> {
impl<T> Message<T> {
pub(super) const fn new(segments: Vec<T>) -> Self {
Self { segments }
}
@ -42,7 +43,9 @@ impl<T: ProtocolSegment> Message<T> {
pub(super) fn segments_mut(&mut self) -> &mut [T] {
&mut self.segments
}
}
impl<T: ProtocolSegment> Message<T> {
pub(super) fn read_from(reader: &mut dyn MultiRead) -> Result<Self> {
// FIXME: Does a request contain only one segment? We need to investigate further.
let segments = {
@ -60,8 +63,10 @@ impl<T: ProtocolSegment> Message<T> {
Ok(())
}
}
pub(super) fn total_len(&self) -> usize {
impl<T: ProtocolSegment> QueueableMessage for Message<T> {
fn total_len(&self) -> usize {
self.segments
.iter()
.map(|segment| segment.header().len as usize)

View File

@ -7,11 +7,20 @@ pub struct MessageReceiver<Message> {
pollee: Pollee,
}
pub(super) struct MessageQueue<Message>(VecDeque<Message>);
/// A queue of received messages together with its bookkeeping state.
pub(super) struct MessageQueue<Message> {
/// The queued messages; new messages are pushed to the back and handed out from the front.
messages: VecDeque<Message>,
/// The sum of the lengths of all messages currently held in `messages`.
total_length: usize,
/// An error waiting to be reported, set when a message could not be enqueued.
error: Option<Error>,
}
impl<Message> MessageQueue<Message> {
/// Creates a pair of a [`MessageQueue`] and a [`MessageReceiver`].
pub(super) fn new_pair(pollee: Pollee) -> (Arc<Mutex<Self>>, MessageReceiver<Message>) {
let queue = Arc::new(Mutex::new(Self(VecDeque::new())));
let queue = Arc::new(Mutex::new(Self {
messages: VecDeque::new(),
total_length: 0,
error: None,
}));
let receiver = MessageReceiver {
message_queue: queue.clone(),
pollee,
@ -19,31 +28,88 @@ impl<Message> MessageQueue<Message> {
(queue, receiver)
}
/// Returns whether the message queue is empty.
pub(super) fn is_empty(&self) -> bool {
self.0.is_empty()
self.messages.is_empty()
}
pub(super) fn peek(&self) -> Option<&Message> {
self.0.front()
}
pub(super) fn dequeue(&mut self) -> Option<Message> {
self.0.pop_front()
}
pub(self) fn enqueue(&mut self, message: Message) -> Result<()> {
// FIXME: We should verify the socket buffer length to ensure
// that adding the message doesn't exceed the buffer capacity.
self.0.push_back(message);
Ok(())
/// Returns whether the message queue holds an error that has not been reported yet.
///
/// Currently, an error is recorded only when the queue is full but the kernel still tries to
/// enqueue a new message. The error stays pending until it is taken out and returned by
/// `dequeue_if`.
pub(super) fn has_errors(&self) -> bool {
self.error.is_some()
}
}
impl<Message> MessageReceiver<Message> {
pub(super) fn enqueue_message(&self, message: Message) -> Result<()> {
self.message_queue.lock().enqueue(message)?;
self.pollee.notify(IoEvents::IN);
/// Messages that can be stored in a [`MessageQueue`].
///
/// The queue uses the reported length to track how much of its capacity is in use.
pub trait QueueableMessage {
/// Counts and returns the length of the message.
fn total_len(&self) -> usize;
}
Ok(())
impl<Message: QueueableMessage> MessageQueue<Message> {
    /// Hands the front message to `f` and removes it from the queue when `f` asks for that.
    ///
    /// `f` receives a reference to the front message and that message's length. On success it
    /// returns a `(bool, R)` pair: the `bool` decides whether the message leaves the queue, and
    /// the `R` is passed back to the caller.
    ///
    /// # Errors
    ///
    /// - A pending queue error (see [`Self::has_errors`]) is cleared and returned; `f` is not
    ///   executed in that case.
    /// - `EAGAIN` is returned if the queue holds no message.
    /// - Any error produced by `f` is propagated, and the message stays in the queue.
    pub(super) fn dequeue_if<F, R>(&mut self, f: F) -> Result<R>
    where
        F: FnOnce(&Message, usize) -> Result<(bool, R)>,
    {
        if let Some(pending) = self.error.take() {
            return Err(pending);
        }

        let Some(front) = self.messages.front() else {
            return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty");
        };

        let front_len = front.total_len();
        let (consume, output) = f(front, front_len)?;
        if !consume {
            return Ok(output);
        }

        // `front()` returned `Some` above, so there is a message to pop.
        self.messages.pop_front().unwrap();
        self.total_length -= front_len;
        Ok(output)
    }

    /// Tries to append `message` to the queue.
    ///
    /// Returns `true` if the message was queued. Returns `false` if the message does not fit
    /// into the remaining buffer space; the message is dropped and an `ENOBUFS` error is
    /// recorded in the queue.
    #[must_use]
    pub(self) fn enqueue(&mut self, message: Message) -> bool {
        let incoming = message.total_len();
        let free_space = NETLINK_DEFAULT_BUF_SIZE - self.total_length;
        let fits = incoming <= free_space;

        if fits {
            self.messages.push_back(message);
            self.total_length += incoming;
        } else {
            // Only the kernel enqueues messages for now, because sending netlink messages from
            // one user-space socket to another is not supported. The failure is therefore
            // recorded here and reported as `ENOBUFS` when user space calls `recv`.
            self.error = Some(Error::with_message(
                Errno::ENOBUFS,
                "the receive buffer is full",
            ));
        }

        fits
    }
}
impl<Message: QueueableMessage> MessageReceiver<Message> {
    /// Offers `message` to the receiver's queue and wakes up pollers.
    ///
    /// Pollers are notified with `IoEvents::IN` if the message was queued, and with
    /// `IoEvents::ERR` if the queue rejected it.
    pub(super) fn enqueue_message(&self, message: Message) {
        let enqueued = self.message_queue.lock().enqueue(message);
        let events = if enqueued {
            IoEvents::IN
        } else {
            IoEvents::ERR
        };
        self.pollee.notify(events);
    }
}
/// The capacity of a message queue: the summed `total_len()` of all queued messages may not
/// exceed this value.
// NOTE(review): presumably measured in bytes, matching what `total_len()` reports — confirm.
const NETLINK_DEFAULT_BUF_SIZE: usize = 65536;

View File

@ -101,25 +101,16 @@ impl datagram_common::Bound for BoundNetlinkRoute {
let mut receive_queue = self.receive_queue.lock();
let Some(response) = receive_queue.peek() else {
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty");
};
receive_queue.dequeue_if(|response, response_len| {
let len = response_len.min(writer.sum_lens());
response.write_to(writer)?;
let len = {
let max_len = writer.sum_lens();
response.total_len().min(max_len)
};
// TODO: The message can only come from kernel socket currently.
let remote = NetlinkSocketAddr::new_unspecified();
response.write_to(writer)?;
if !flags.contains(SendRecvFlags::MSG_PEEK) {
receive_queue.dequeue().unwrap();
}
// TODO: The message can only come from kernel socket currently.
let remote = NetlinkSocketAddr::new_unspecified();
Ok((len, remote))
let should_dequeue = !flags.contains(SendRecvFlags::MSG_PEEK);
Ok((should_dequeue, (len, remote)))
})
}
fn check_io_events(&self) -> IoEvents {

View File

@ -4,7 +4,10 @@ use multicast::MulticastGroup;
pub(super) use multicast::MulticastMessage;
use spin::Once;
use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS};
use super::{
addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS},
receiver::QueueableMessage,
};
use crate::{
net::socket::netlink::{
addr::UNSPECIFIED_PORT, kobject_uevent::UeventMessage, receiver::MessageReceiver,
@ -46,7 +49,10 @@ pub trait SupportedNetlinkProtocol {
socket_table.bind(Self::socket_table(), addr, receiver)
}
fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()> {
fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()>
where
Self::Message: QueueableMessage,
{
let socket_table = Self::socket_table().read();
socket_table.unicast(dst_port, message)
}
@ -141,13 +147,17 @@ impl<Message: 'static> ProtocolSocketTable<Message> {
Ok(BoundHandle::new(socket_table, port, addr.groups()))
}
fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()> {
fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()>
where
Message: QueueableMessage,
{
let Some(receiver) = self.unicast_sockets.get(&dst_port) else {
// FIXME: Should we return error here?
return Ok(());
};
receiver.enqueue_message(message);
receiver.enqueue_message(message)
Ok(())
}
fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<()>
@ -163,9 +173,7 @@ impl<Message: 'static> ProtocolSocketTable<Message> {
let Some(receiver) = self.unicast_sockets.get(port_num) else {
continue;
};
// FIXME: Should we slightly ignore the error if the socket's buffer has no enough space?
receiver.enqueue_message(message.clone())?;
receiver.enqueue_message(message.clone());
}
}

View File

@ -1,6 +1,9 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{net::socket::netlink::addr::PortNum, prelude::*};
use crate::{
net::socket::netlink::{addr::PortNum, receiver::QueueableMessage},
prelude::*,
};
/// A netlink multicast group.
///
@ -34,4 +37,4 @@ impl MulticastGroup {
}
}
pub trait MulticastMessage: Clone {}
pub trait MulticastMessage: QueueableMessage + Clone {}

View File

@ -156,6 +156,14 @@ struct nl_req {
char abuf[4];
};
// Resets `req` and fills it in as an `RTM_GETADDR` request (`NLM_F_REQUEST`,
// sequence number 1, address family `AF_UNSPEC`).
//
// The body is wrapped in `do { ... } while (0)` so that `INIT_REQ(req);`
// expands to exactly one statement and stays correct inside an unbraced
// `if`/`else`. The argument is parenthesized so that any lvalue expression
// can be passed.
#define INIT_REQ(req) \
	do { \
		memset(&(req), 0, sizeof(req)); \
		(req).hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); \
		(req).hdr.nlmsg_type = RTM_GETADDR; \
		(req).hdr.nlmsg_flags = NLM_F_REQUEST; \
		(req).hdr.nlmsg_seq = 1; \
		(req).ifa.ifa_family = AF_UNSPEC; \
	} while (0)
FN_TEST(get_addr_error)
{
int sock_fd;
@ -170,12 +178,7 @@ FN_TEST(get_addr_error)
// 1. Without NLM_F_DUMP flag
struct nl_req req;
memset(&req, 0, sizeof(req));
req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg));
req.hdr.nlmsg_type = RTM_GETADDR;
req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = 1;
req.ifa.ifa_family = AF_UNSPEC;
INIT_REQ(req);
struct iovec iov = { &req, req.hdr.nlmsg_len };
struct msghdr msg = { &sa, sizeof(sa), &iov, 1, NULL, 0, 0 };
@ -187,18 +190,20 @@ FN_TEST(get_addr_error)
-EOPNOTSUPP);
int found_new_addr;
#define TEST_KERNEL_RESPONSE \
found_new_addr = 0; \
while (1) { \
size_t recv_len = \
TEST_SUCC(recv(sock_fd, buffer, BUFFER_SIZE, 0)); \
\
int found_done = TEST_SUCC(find_new_addr_until_done( \
buffer, recv_len, &found_new_addr)); \
\
if (found_done != 0) { \
break; \
} \
#define TEST_KERNEL_RESPONSE \
found_new_addr = 0; \
while (1) { \
size_t recv_len = \
TEST_SUCC(recv(sock_fd, buffer, BUFFER_SIZE, 0)); \
\
int found_done = \
TEST_RES(find_new_addr_until_done(buffer, recv_len, \
&found_new_addr), \
_ret >= 0); \
\
if (found_done != 0) { \
break; \
} \
}
// 2. Invalid required index
@ -232,13 +237,7 @@ FN_TEST(bufsize_msgsize)
sock_fd = TEST_SUCC(
socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE));
memset(&req, 0, sizeof(req));
req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg));
req.hdr.nlmsg_type = RTM_GETADDR;
req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = 1;
req.ifa.ifa_family = AF_UNSPEC;
INIT_REQ(req);
// Send the request
TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req));
@ -252,3 +251,47 @@ FN_TEST(bufsize_msgsize)
TEST_SUCC(close(sock_fd));
}
END_TEST()
// Sends `req` on `sock_fd` over and over (at most 4096 times) and polls the
// socket after each send.
//
// Returns 0 as soon as the poll result is exactly `POLLIN | POLLOUT | POLLERR`.
// Returns -1 if a send is short or fails, if `poll` fails, if the poll result
// is anything other than the two expected combinations, or if the error
// condition never shows up within the attempt limit.
int fill_receive_buffer(int sock_fd, const struct nl_req *req)
{
	struct pollfd pfd = { .fd = sock_fd, .events = POLLIN | POLLOUT };
	int remaining = 4096;

	while (remaining-- > 0) {
		ssize_t sent = send(sock_fd, req, sizeof(*req), 0);

		if (sent < 0 || (size_t)sent != sizeof(*req))
			return -1;

		if (poll(&pfd, 1, 0) < 0)
			return -1;

		// The error flag has appeared: stop here.
		if (pfd.revents == (POLLIN | POLLOUT | POLLERR))
			return 0;

		// Anything other than "readable and writable" is unexpected.
		if (pfd.revents != (POLLIN | POLLOUT))
			return -1;
	}

	return -1;
}
// Checks that overrunning a netlink socket's receive buffer is reported to
// user space as `ENOBUFS` on the next `recv`, and that the error is reported
// only once.
FN_TEST(enobufs)
{
int sock_fd;
struct nl_req req;
// A non-blocking socket, so that `recv` returns immediately.
sock_fd = TEST_SUCC(
socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE));
INIT_REQ(req);
// Keep sending requests until polling the socket reports `POLLERR`.
TEST_RES(fill_receive_buffer(sock_fd, &req), _ret >= 0);
// Now the receive buffer is full. We can still send a new message,
// but the first `recv` should fail with `ENOBUFS`.
TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req));
TEST_ERRNO(recv(sock_fd, buffer, 1, 0), ENOBUFS);
// The error has been consumed, so the next `recv` is expected to succeed.
TEST_SUCC(recv(sock_fd, buffer, 1, 0));
TEST_SUCC(close(sock_fd));
}
END_TEST()