Report `ENOBUFS` if netlink messages overrun
This commit is contained in:
parent
b57c94d05d
commit
c289f96d23
|
@ -50,6 +50,9 @@ impl<Message: 'static> BoundNetlink<Message> {
|
|||
if !receive_queue.is_empty() {
|
||||
events |= IoEvents::IN;
|
||||
}
|
||||
if receive_queue.has_errors() {
|
||||
events |= IoEvents::ERR;
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
|
|
@ -69,24 +69,15 @@ impl datagram_common::Bound for BoundNetlinkUevent {
|
|||
|
||||
let mut receive_queue = self.receive_queue.lock();
|
||||
|
||||
let Some(response) = receive_queue.peek() else {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty");
|
||||
};
|
||||
receive_queue.dequeue_if(|response, response_len| {
|
||||
let len = response_len.min(writer.sum_lens());
|
||||
response.write_to(writer)?;
|
||||
|
||||
let len = {
|
||||
let max_len = writer.sum_lens();
|
||||
response.total_len().min(max_len)
|
||||
};
|
||||
let remote = *response.src_addr();
|
||||
|
||||
response.write_to(writer)?;
|
||||
|
||||
let remote = *response.src_addr();
|
||||
|
||||
if !flags.contains(SendRecvFlags::MSG_PEEK) {
|
||||
receive_queue.dequeue().unwrap();
|
||||
}
|
||||
|
||||
Ok((len, remote))
|
||||
let should_dequeue = !flags.contains(SendRecvFlags::MSG_PEEK);
|
||||
Ok((should_dequeue, (len, remote)))
|
||||
})
|
||||
}
|
||||
|
||||
fn check_io_events(&self) -> IoEvents {
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
use uevent::Uevent;
|
||||
|
||||
use crate::{
|
||||
net::socket::netlink::{table::MulticastMessage, NetlinkSocketAddr},
|
||||
net::socket::netlink::{
|
||||
receiver::QueueableMessage, table::MulticastMessage, NetlinkSocketAddr,
|
||||
},
|
||||
prelude::*,
|
||||
util::MultiWrite,
|
||||
};
|
||||
|
@ -39,11 +41,6 @@ impl UeventMessage {
|
|||
&self.src_addr
|
||||
}
|
||||
|
||||
/// Returns the total length of the uevent.
|
||||
pub(super) fn total_len(&self) -> usize {
|
||||
self.uevent.len()
|
||||
}
|
||||
|
||||
/// Writes the uevent to the given `writer`.
|
||||
pub(super) fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> {
|
||||
let _nbytes = writer.write(&mut VmReader::from(self.uevent.as_bytes()))?;
|
||||
|
@ -53,4 +50,10 @@ impl UeventMessage {
|
|||
}
|
||||
}
|
||||
|
||||
impl QueueableMessage for UeventMessage {
    /// Reports the length of the wrapped uevent, as given by its `len` method.
    fn total_len(&self) -> usize {
        self.uevent.len()
    }
}
|
||||
|
||||
impl MulticastMessage for UeventMessage {}
|
||||
|
|
|
@ -16,6 +16,7 @@ pub(super) use segment::{
|
|||
CSegmentType, SegmentBody,
|
||||
};
|
||||
|
||||
use super::receiver::QueueableMessage;
|
||||
use crate::{
|
||||
prelude::*,
|
||||
util::{MultiRead, MultiWrite},
|
||||
|
@ -26,11 +27,11 @@ use crate::{
|
|||
/// A netlink message can be transmitted to and from user space using a single send/receive syscall.
|
||||
/// It consists of one or more [`ProtocolSegment`]s.
|
||||
#[derive(Debug)]
|
||||
pub struct Message<T: ProtocolSegment> {
|
||||
pub struct Message<T> {
|
||||
segments: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: ProtocolSegment> Message<T> {
|
||||
impl<T> Message<T> {
|
||||
pub(super) const fn new(segments: Vec<T>) -> Self {
|
||||
Self { segments }
|
||||
}
|
||||
|
@ -42,7 +43,9 @@ impl<T: ProtocolSegment> Message<T> {
|
|||
pub(super) fn segments_mut(&mut self) -> &mut [T] {
|
||||
&mut self.segments
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ProtocolSegment> Message<T> {
|
||||
pub(super) fn read_from(reader: &mut dyn MultiRead) -> Result<Self> {
|
||||
// FIXME: Does a request contain only one segment? We need to investigate further.
|
||||
let segments = {
|
||||
|
@ -60,8 +63,10 @@ impl<T: ProtocolSegment> Message<T> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn total_len(&self) -> usize {
|
||||
impl<T: ProtocolSegment> QueueableMessage for Message<T> {
|
||||
fn total_len(&self) -> usize {
|
||||
self.segments
|
||||
.iter()
|
||||
.map(|segment| segment.header().len as usize)
|
||||
|
|
|
@ -7,11 +7,20 @@ pub struct MessageReceiver<Message> {
|
|||
pollee: Pollee,
|
||||
}
|
||||
|
||||
pub(super) struct MessageQueue<Message>(VecDeque<Message>);
|
||||
pub(super) struct MessageQueue<Message> {
|
||||
messages: VecDeque<Message>,
|
||||
total_length: usize,
|
||||
error: Option<Error>,
|
||||
}
|
||||
|
||||
impl<Message> MessageQueue<Message> {
|
||||
/// Creates a pair of a [`MessageQueue`] and a [`MessageReceiver`].
|
||||
pub(super) fn new_pair(pollee: Pollee) -> (Arc<Mutex<Self>>, MessageReceiver<Message>) {
|
||||
let queue = Arc::new(Mutex::new(Self(VecDeque::new())));
|
||||
let queue = Arc::new(Mutex::new(Self {
|
||||
messages: VecDeque::new(),
|
||||
total_length: 0,
|
||||
error: None,
|
||||
}));
|
||||
let receiver = MessageReceiver {
|
||||
message_queue: queue.clone(),
|
||||
pollee,
|
||||
|
@ -19,31 +28,88 @@ impl<Message> MessageQueue<Message> {
|
|||
(queue, receiver)
|
||||
}
|
||||
|
||||
/// Returns whether the message queue is empty.
|
||||
pub(super) fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
self.messages.is_empty()
|
||||
}
|
||||
|
||||
pub(super) fn peek(&self) -> Option<&Message> {
|
||||
self.0.front()
|
||||
}
|
||||
|
||||
pub(super) fn dequeue(&mut self) -> Option<Message> {
|
||||
self.0.pop_front()
|
||||
}
|
||||
|
||||
pub(self) fn enqueue(&mut self, message: Message) -> Result<()> {
|
||||
// FIXME: We should verify the socket buffer length to ensure
|
||||
// that adding the message doesn't exceed the buffer capacity.
|
||||
self.0.push_back(message);
|
||||
Ok(())
|
||||
/// Returns whether the message queue contains errors.
|
||||
///
|
||||
/// Currently, the message queue contains errors only if the queue is full but the kernel still
|
||||
/// wants to enqueue new messages.
|
||||
pub(super) fn has_errors(&self) -> bool {
|
||||
self.error.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Message> MessageReceiver<Message> {
|
||||
pub(super) fn enqueue_message(&self, message: Message) -> Result<()> {
|
||||
self.message_queue.lock().enqueue(message)?;
|
||||
self.pollee.notify(IoEvents::IN);
|
||||
/// A message that can be stored in a [`MessageQueue`].
///
/// The queue uses the reported length to account for the space taken up by its messages.
pub trait QueueableMessage {
    /// Returns the total length of this message.
    fn total_len(&self) -> usize;
}
|
||||
|
||||
Ok(())
|
||||
impl<Message: QueueableMessage> MessageQueue<Message> {
|
||||
/// Dequeues a message if executing the closure returns `Ok((true, _))`.
|
||||
///
|
||||
/// The closure will be executed with a reference to the message that is ready to be dequeued
|
||||
/// and the length of the message.
|
||||
///
|
||||
/// If the queue contains errors (see [`Self::has_errors`]), the error will be cleared and
|
||||
/// returned. In this case, the closure will not be executed.
|
||||
pub(super) fn dequeue_if<F, R>(&mut self, f: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce(&Message, usize) -> Result<(bool, R)>,
|
||||
{
|
||||
if let Some(error) = self.error.take() {
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let Some(message) = self.messages.front() else {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty");
|
||||
};
|
||||
|
||||
let length = message.total_len();
|
||||
let (should_pop, result) = f(message, length)?;
|
||||
if should_pop {
|
||||
self.messages.pop_front().unwrap();
|
||||
self.total_length -= length;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Tries to enqueue a new message. Returns `false` if the buffer is full.
|
||||
#[must_use]
|
||||
pub(self) fn enqueue(&mut self, message: Message) -> bool {
|
||||
let length = message.total_len();
|
||||
|
||||
// Currently, we don't support sending netlink messages between user spaces, so only the
|
||||
// kernel can enqueue new messages. If the kernel fails to enqueue a new message, `ENOBUFS`
|
||||
// will be returned when userspace calls `recv`.
|
||||
if NETLINK_DEFAULT_BUF_SIZE - self.total_length < length {
|
||||
self.error = Some(Error::with_message(
|
||||
Errno::ENOBUFS,
|
||||
"the receive buffer is full",
|
||||
));
|
||||
return false;
|
||||
}
|
||||
|
||||
self.messages.push_back(message);
|
||||
self.total_length += length;
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<Message: QueueableMessage> MessageReceiver<Message> {
|
||||
pub(super) fn enqueue_message(&self, message: Message) {
|
||||
let is_ok = self.message_queue.lock().enqueue(message);
|
||||
if is_ok {
|
||||
self.pollee.notify(IoEvents::IN);
|
||||
} else {
|
||||
self.pollee.notify(IoEvents::ERR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Upper bound on the summed length of the messages that a `MessageQueue` holds at once.
// NOTE(review): presumably a byte count; confirm against the `total_len` implementations.
const NETLINK_DEFAULT_BUF_SIZE: usize = 65536;
|
||||
|
|
|
@ -101,25 +101,16 @@ impl datagram_common::Bound for BoundNetlinkRoute {
|
|||
|
||||
let mut receive_queue = self.receive_queue.lock();
|
||||
|
||||
let Some(response) = receive_queue.peek() else {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty");
|
||||
};
|
||||
receive_queue.dequeue_if(|response, response_len| {
|
||||
let len = response_len.min(writer.sum_lens());
|
||||
response.write_to(writer)?;
|
||||
|
||||
let len = {
|
||||
let max_len = writer.sum_lens();
|
||||
response.total_len().min(max_len)
|
||||
};
|
||||
// TODO: The message can only come from kernel socket currently.
|
||||
let remote = NetlinkSocketAddr::new_unspecified();
|
||||
|
||||
response.write_to(writer)?;
|
||||
|
||||
if !flags.contains(SendRecvFlags::MSG_PEEK) {
|
||||
receive_queue.dequeue().unwrap();
|
||||
}
|
||||
|
||||
// TODO: The message can only come from kernel socket currently.
|
||||
let remote = NetlinkSocketAddr::new_unspecified();
|
||||
|
||||
Ok((len, remote))
|
||||
let should_dequeue = !flags.contains(SendRecvFlags::MSG_PEEK);
|
||||
Ok((should_dequeue, (len, remote)))
|
||||
})
|
||||
}
|
||||
|
||||
fn check_io_events(&self) -> IoEvents {
|
||||
|
|
|
@ -4,7 +4,10 @@ use multicast::MulticastGroup;
|
|||
pub(super) use multicast::MulticastMessage;
|
||||
use spin::Once;
|
||||
|
||||
use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS};
|
||||
use super::{
|
||||
addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS},
|
||||
receiver::QueueableMessage,
|
||||
};
|
||||
use crate::{
|
||||
net::socket::netlink::{
|
||||
addr::UNSPECIFIED_PORT, kobject_uevent::UeventMessage, receiver::MessageReceiver,
|
||||
|
@ -46,7 +49,10 @@ pub trait SupportedNetlinkProtocol {
|
|||
socket_table.bind(Self::socket_table(), addr, receiver)
|
||||
}
|
||||
|
||||
fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()> {
|
||||
fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()>
|
||||
where
|
||||
Self::Message: QueueableMessage,
|
||||
{
|
||||
let socket_table = Self::socket_table().read();
|
||||
socket_table.unicast(dst_port, message)
|
||||
}
|
||||
|
@ -141,13 +147,17 @@ impl<Message: 'static> ProtocolSocketTable<Message> {
|
|||
Ok(BoundHandle::new(socket_table, port, addr.groups()))
|
||||
}
|
||||
|
||||
fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()> {
|
||||
fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()>
|
||||
where
|
||||
Message: QueueableMessage,
|
||||
{
|
||||
let Some(receiver) = self.unicast_sockets.get(&dst_port) else {
|
||||
// FIXME: Should we return error here?
|
||||
return Ok(());
|
||||
};
|
||||
receiver.enqueue_message(message);
|
||||
|
||||
receiver.enqueue_message(message)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<()>
|
||||
|
@ -163,9 +173,7 @@ impl<Message: 'static> ProtocolSocketTable<Message> {
|
|||
let Some(receiver) = self.unicast_sockets.get(port_num) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// FIXME: Should we slightly ignore the error if the socket's buffer has no enough space?
|
||||
receiver.enqueue_message(message.clone())?;
|
||||
receiver.enqueue_message(message.clone());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use crate::{net::socket::netlink::addr::PortNum, prelude::*};
|
||||
use crate::{
|
||||
net::socket::netlink::{addr::PortNum, receiver::QueueableMessage},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
/// A netlink multicast group.
|
||||
///
|
||||
|
@ -34,4 +37,4 @@ impl MulticastGroup {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait MulticastMessage: Clone {}
|
||||
pub trait MulticastMessage: QueueableMessage + Clone {}
|
||||
|
|
|
@ -156,6 +156,14 @@ struct nl_req {
|
|||
char abuf[4];
|
||||
};
|
||||
|
||||
// Resets `req` to zero and fills it in as an `RTM_GETADDR` request with
// `NLM_F_REQUEST` set, sequence number 1, and address family `AF_UNSPEC`.
//
// The body is wrapped in `do { ... } while (0)` so that `INIT_REQ(req);`
// expands to exactly one statement, which keeps it safe inside an unbraced
// `if`/`else`. The argument is parenthesized so that expressions such as
// `*ptr` expand correctly.
#define INIT_REQ(req)                                                         \
	do {                                                                  \
		memset(&(req), 0, sizeof(req));                               \
		(req).hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); \
		(req).hdr.nlmsg_type = RTM_GETADDR;                           \
		(req).hdr.nlmsg_flags = NLM_F_REQUEST;                        \
		(req).hdr.nlmsg_seq = 1;                                      \
		(req).ifa.ifa_family = AF_UNSPEC;                             \
	} while (0)
|
||||
|
||||
FN_TEST(get_addr_error)
|
||||
{
|
||||
int sock_fd;
|
||||
|
@ -170,12 +178,7 @@ FN_TEST(get_addr_error)
|
|||
|
||||
// 1. Without NLM_F_DUMP flag
|
||||
struct nl_req req;
|
||||
memset(&req, 0, sizeof(req));
|
||||
req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg));
|
||||
req.hdr.nlmsg_type = RTM_GETADDR;
|
||||
req.hdr.nlmsg_flags = NLM_F_REQUEST;
|
||||
req.hdr.nlmsg_seq = 1;
|
||||
req.ifa.ifa_family = AF_UNSPEC;
|
||||
INIT_REQ(req);
|
||||
|
||||
struct iovec iov = { &req, req.hdr.nlmsg_len };
|
||||
struct msghdr msg = { &sa, sizeof(sa), &iov, 1, NULL, 0, 0 };
|
||||
|
@ -187,18 +190,20 @@ FN_TEST(get_addr_error)
|
|||
-EOPNOTSUPP);
|
||||
|
||||
int found_new_addr;
|
||||
#define TEST_KERNEL_RESPONSE \
|
||||
found_new_addr = 0; \
|
||||
while (1) { \
|
||||
size_t recv_len = \
|
||||
TEST_SUCC(recv(sock_fd, buffer, BUFFER_SIZE, 0)); \
|
||||
\
|
||||
int found_done = TEST_SUCC(find_new_addr_until_done( \
|
||||
buffer, recv_len, &found_new_addr)); \
|
||||
\
|
||||
if (found_done != 0) { \
|
||||
break; \
|
||||
} \
|
||||
#define TEST_KERNEL_RESPONSE \
|
||||
found_new_addr = 0; \
|
||||
while (1) { \
|
||||
size_t recv_len = \
|
||||
TEST_SUCC(recv(sock_fd, buffer, BUFFER_SIZE, 0)); \
|
||||
\
|
||||
int found_done = \
|
||||
TEST_RES(find_new_addr_until_done(buffer, recv_len, \
|
||||
&found_new_addr), \
|
||||
_ret >= 0); \
|
||||
\
|
||||
if (found_done != 0) { \
|
||||
break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// 2. Invalid required index
|
||||
|
@ -232,13 +237,7 @@ FN_TEST(bufsize_msgsize)
|
|||
|
||||
sock_fd = TEST_SUCC(
|
||||
socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE));
|
||||
|
||||
memset(&req, 0, sizeof(req));
|
||||
req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg));
|
||||
req.hdr.nlmsg_type = RTM_GETADDR;
|
||||
req.hdr.nlmsg_flags = NLM_F_REQUEST;
|
||||
req.hdr.nlmsg_seq = 1;
|
||||
req.ifa.ifa_family = AF_UNSPEC;
|
||||
INIT_REQ(req);
|
||||
|
||||
// Send the request
|
||||
TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req));
|
||||
|
@ -252,3 +251,47 @@ FN_TEST(bufsize_msgsize)
|
|||
TEST_SUCC(close(sock_fd));
|
||||
}
|
||||
END_TEST()
|
||||
|
||||
int fill_receive_buffer(int sock_fd, const struct nl_req *req)
|
||||
{
|
||||
struct pollfd pfd = { .fd = sock_fd, .events = POLLIN | POLLOUT };
|
||||
int i;
|
||||
|
||||
for (i = 0; i < 4096; ++i) {
|
||||
if (send(sock_fd, req, sizeof(*req), 0) != sizeof(*req))
|
||||
return -1;
|
||||
if (poll(&pfd, 1, 0) < 0)
|
||||
return -1;
|
||||
switch (pfd.revents) {
|
||||
case POLLIN | POLLOUT:
|
||||
continue;
|
||||
case POLLIN | POLLOUT | POLLERR:
|
||||
return 0;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
FN_TEST(enobufs)
|
||||
{
|
||||
int sock_fd;
|
||||
struct nl_req req;
|
||||
|
||||
sock_fd = TEST_SUCC(
|
||||
socket(AF_NETLINK, SOCK_RAW | SOCK_NONBLOCK, NETLINK_ROUTE));
|
||||
INIT_REQ(req);
|
||||
|
||||
TEST_RES(fill_receive_buffer(sock_fd, &req), _ret >= 0);
|
||||
|
||||
// Now the receive buffer is full. We can still send a new message,
|
||||
// but the first `recv` should fail with `ENOBUFS`.
|
||||
TEST_RES(send(sock_fd, &req, sizeof(req), 0), _ret == sizeof(req));
|
||||
TEST_ERRNO(recv(sock_fd, buffer, 1, 0), ENOBUFS);
|
||||
TEST_SUCC(recv(sock_fd, buffer, 1, 0));
|
||||
|
||||
TEST_SUCC(close(sock_fd));
|
||||
}
|
||||
END_TEST()
|
||||
|
|
Loading…
Reference in New Issue