Replace the use of `read_volatile` and `write_volatile` with assembly code

Hsy-Intel 2026-02-10 11:49:52 +08:00
parent 5f8b019369
commit c87549ef0a
14 changed files with 924 additions and 69 deletions
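
Background: `read_volatile`/`write_volatile` only guarantee that the access is not elided or reordered; they do not guarantee a single, non-tearing load or store instruction. In Confidential VMs (Intel TDX and AMD SEV), an MMIO access traps into the guest kernel (#VE and #VC, respectively), which must decode the faulting instruction, so only plain `MOV`-style accesses can be handled. A sketch of the failure mode this commit rules out (the `Rgb` type and `read_pixel` function are hypothetical, purely for illustration):

```rust
#[repr(C)]
#[derive(Clone, Copy)]
struct Rgb {
    r: u8,
    g: u8,
    b: u8,
}

/// # Safety
///
/// `ptr` must be valid for volatile reads.
unsafe fn read_pixel(ptr: *const Rgb) -> Rgb {
    // The compiler is free to lower this 3-byte volatile read into several
    // narrow loads. Under TDX/SEV, each load faults separately, and a #VE/#VC
    // handler that only decodes simple MOVs cannot emulate the access.
    unsafe { core::ptr::read_volatile(ptr) }
}
```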

View File

@@ -5,12 +5,8 @@ use core::{fmt::Debug, marker::PhantomData};
 use aster_rights::{Dup, Exec, Full, Read, Signal, TRightSet, TRights, Write};
 use aster_rights_proc::require;
 use ostd::{
-    Error, Result,
-    mm::{
-        Daddr, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce,
-        dma::DmaDirection,
-        io_util::{HasVmReaderWriter, VmReaderWriterTypes},
-    },
+    Result,
+    mm::{Daddr, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce, dma::DmaDirection},
 };
 
 use ostd_pod::Pod;
@@ -239,18 +235,12 @@ impl<T: Pod, M: VmIo, R: TRights> SafePtr<T, M, TRightSet<R>> {
     /// - the Read right for another pointer.
     #[require(R > Write)]
     #[require(R1 > Read)]
-    pub fn copy_from<M1: HasVmReaderWriter, R1: TRights>(
+    pub fn copy_from<M1: VmIo, R1: TRights>(
         &self,
         ptr: &SafePtr<T, M1, TRightSet<R1>>,
     ) -> Result<()> {
-        let mut reader = M1::Types::to_reader_result(ptr.vm_obj.reader())?.to_fallible();
-
-        if reader.remain() < size_of::<T>() {
-            return Err(Error::InvalidArgs);
-        }
-
-        reader.limit(size_of::<T>());
-        self.vm_obj.write(self.offset, &mut reader)
+        let val = ptr.vm_obj.read_val::<T>(ptr.offset)?;
+        self.vm_obj.write_val(self.offset, &val)
     }
 }

View File

@@ -4,7 +4,7 @@ use alloc::sync::Arc;
 use aster_framebuffer::{ColorMapEntry, FRAMEBUFFER, FrameBuffer, MAX_CMAP_SIZE, PixelFormat};
 use device_id::{DeviceId, MajorId, MinorId};
-use ostd::mm::{HasPaddr, HasSize, VmIo, io_util::HasVmReaderWriter};
+use ostd::mm::{HasPaddr, HasSize, VmIo};
 
 use super::registry::char;
 use crate::{
@@ -412,20 +412,26 @@ impl InodeIo for FbHandle {
             return Ok(0);
         }
 
-        let mut reader = self.framebuffer.io_mem().reader();
-
-        if offset >= reader.remain() {
+        let io_mem = self.framebuffer.io_mem();
+        let size = io_mem.size();
+        if offset >= size {
             return Ok(0);
         }
-        reader.skip(offset);
 
-        let mut reader = reader.to_fallible();
-        let len = match reader.read_fallible(writer) {
-            Ok(len) => len,
-            Err((err, 0)) => return Err(err.into()),
-            Err((_err, len)) => len,
-        };
+        let len = writer.avail().min(size - offset);
+        if len == 0 {
+            return Ok(0);
+        }
+
+        {
+            // Create a new writer. We should preserve the end of the original writer.
+            let mut new_writer = writer.fork();
+            new_writer.limit(len);
+            io_mem.read(offset, &mut new_writer)?;
+        }
+        // Synchronize the cursor inside the original writer.
+        writer.skip(len);
 
         Ok(len)
     }
@@ -439,22 +445,27 @@ impl InodeIo for FbHandle {
             return Ok(0);
         }
 
-        let mut writer = self.framebuffer.io_mem().writer();
-        if offset >= writer.avail() {
+        let io_mem = self.framebuffer.io_mem();
+        let size = io_mem.size();
+        if offset >= size {
             return_errno_with_message!(
                 Errno::ENOSPC,
                 "the write offset is beyond the framebuffer size"
             );
         }
-        writer.skip(offset);
 
-        let mut writer = writer.to_fallible();
-        let len = match writer.write_fallible(reader) {
-            Ok(len) => len,
-            Err((err, 0)) => return Err(err.into()),
-            Err((_err, len)) => len,
-        };
+        let len = reader.remain().min(size - offset);
+        if len == 0 {
+            return Ok(0);
+        }
+
+        // Create a new reader. We should preserve the end of the original reader.
+        let mut new_reader = reader.clone();
+        new_reader.limit(len);
+        io_mem.write(offset, &mut new_reader)?;
+        // Synchronize the cursor inside the original reader.
+        reader.skip(len);
 
         Ok(len)
     }
 }
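
The fork/limit/skip pattern above deserves a note: `IoMem::read` consumes the cursor of the writer passed to it, so the code hands it a length-capped fork and then re-syncs the original cursor afterwards. A minimal sketch of the pattern (the `copy_bounded` helper and its signature are invented for illustration; `fork` is the `VmWriter` method added at the end of this commit):

```rust
use ostd::{
    Result,
    mm::{Fallible, VmIo, VmWriter},
};

/// Copies up to `len` bytes starting at `offset` from any `VmIo` backend
/// (such as the framebuffer's `IoMem`) into `writer`, without disturbing
/// `writer`'s end bound.
fn copy_bounded<M: VmIo>(
    io_mem: &M,
    offset: usize,
    len: usize,
    writer: &mut VmWriter<'_, Fallible>,
) -> Result<()> {
    {
        let mut view = writer.fork(); // same cursor and end as `writer`
        view.limit(len); // cap the view at `len` bytes
        io_mem.read(offset, &mut view)?;
    } // `view` is dropped here, releasing the mutable borrow of `writer`
    writer.skip(len); // re-sync the original cursor past the copied bytes
    Ok(())
}
```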

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MPL-2.0

use crate::mm::PodOnce;

pub(crate) unsafe fn read_once<T: PodOnce>(ptr: *const T) -> T {
    // TODO: Use arch-specific single-instruction load for LoongArch.
    // For detail, see https://github.com/asterinas/asterinas/issues/2948.
    unsafe { core::ptr::read_volatile(ptr) }
}

pub(crate) unsafe fn write_once<T: PodOnce>(ptr: *mut T, val: T) {
    // TODO: Use arch-specific single-instruction store for LoongArch.
    // For detail, see https://github.com/asterinas/asterinas/issues/2948.
    unsafe { core::ptr::write_volatile(ptr, val) }
}

View File

@@ -4,6 +4,8 @@ use alloc::vec::Vec;
 
 use crate::{boot::memory_region::MemoryRegionType, io::IoMemAllocatorBuilder};
 
+pub(crate) mod io_mem;
+
 /// Initializes the allocatable MMIO area based on the LoongArch memory
 /// distribution map.
 ///

View File

@@ -7,7 +7,7 @@
 pub mod boot;
 pub mod cpu;
 pub mod device;
-mod io;
+pub(crate) mod io;
 pub(crate) mod iommu;
 pub(crate) mod irq;
 pub(crate) mod mm;

View File

@@ -5,7 +5,11 @@
 use spin::Once;
 
 use crate::{
-    arch::{boot::DEVICE_TREE, mm::paddr_to_daddr},
+    arch::{
+        boot::DEVICE_TREE,
+        io::io_mem::{read_once, write_once},
+        mm::paddr_to_daddr,
+    },
     console::uart_ns16650a::{Ns16550aAccess, Ns16550aRegister, Ns16550aUart},
     sync::{LocalIrqDisabled, SpinLock},
 };
@@ -34,12 +38,12 @@ impl SerialAccess {
 impl Ns16550aAccess for SerialAccess {
     fn read(&self, reg: Ns16550aRegister) -> u8 {
         // SAFETY: `self.base + reg` is a valid register of the serial port.
-        unsafe { core::ptr::read_volatile(self.base.add(reg as u8 as usize)) }
+        unsafe { read_once(self.base.add(usize::from(reg as u8))) }
     }
 
     fn write(&mut self, reg: Ns16550aRegister, val: u8) {
         // SAFETY: `self.base + reg` is a valid register of the serial port.
-        unsafe { core::ptr::write_volatile(self.base.add(reg as u8 as usize), val) };
+        unsafe { write_once(self.base.add(usize::from(reg as u8)), val) }
     }
 }

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MPL-2.0

use crate::mm::PodOnce;

pub(crate) unsafe fn read_once<T: PodOnce>(ptr: *const T) -> T {
    // TODO: Use arch-specific single-instruction load for RISC-V.
    // For detail, see https://github.com/asterinas/asterinas/issues/2948.
    unsafe { core::ptr::read_volatile(ptr) }
}

pub(crate) unsafe fn write_once<T: PodOnce>(ptr: *mut T, val: T) {
    // TODO: Use arch-specific single-instruction store for RISC-V.
    // For detail, see https://github.com/asterinas/asterinas/issues/2948.
    unsafe { core::ptr::write_volatile(ptr, val) }
}

View File

@@ -4,6 +4,8 @@ use alloc::vec::Vec;
 
 use crate::{boot::memory_region::MemoryRegionType, io::IoMemAllocatorBuilder};
 
+pub(crate) mod io_mem;
+
 /// Initializes the allocatable MMIO area based on the RISC-V memory
 /// distribution map.
 ///

View File

@@ -7,7 +7,7 @@
 pub mod boot;
 pub mod cpu;
 pub mod device;
-mod io;
+pub(crate) mod io;
 pub(crate) mod iommu;
 pub mod irq;
 pub(crate) mod mm;

View File

@@ -0,0 +1,186 @@
// SPDX-License-Identifier: MPL-2.0

use core::arch::asm;

use crate::mm::PodOnce;

/// Reads from a pointer with a non-tearing memory load.
///
/// This function is semantically equivalent to `core::ptr::read_volatile`
/// but is manually implemented with a single memory load instruction.
///
/// # Safety
///
/// Same as the safety requirement of `core::ptr::read_volatile`.
///
/// # Guarantee
///
/// The single-memory-load-instruction guarantee is particularly useful for
/// Confidential VMs (CVMs) such as Intel TDX and AMD SEV,
/// where the memory load may cause CPU exceptions (#VE and #VC, respectively)
/// and the kernel has to handle such exceptions
/// by decoding the faulting CPU instruction.
/// As such, the kernel must be compiled to emit simple load/store CPU instructions.
pub(crate) unsafe fn read_once<T: PodOnce>(ptr: *const T) -> T {
    let mut val: u64 = 0;

    // SAFETY: The caller guarantees `ptr` is valid for reads.
    unsafe {
        // EFFICIENCY: The match should be optimized out for release builds.
        //
        // This match is resolved at compile-time via monomorphization.
        // Since `size_of::<T>()` is a constant for each concrete instance of this
        // function, the compiler eliminates the branch and only emits the
        // `MOV` instruction for the matching size.
        match size_of::<T>() {
            1 => {
                asm!("mov {0:l}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
            }
            2 => {
                asm!("mov {0:x}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
            }
            4 => {
                asm!("mov {0:e}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
            }
            8 => {
                asm!("mov {0}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
            }
            _ => core::hint::unreachable_unchecked(),
        }

        // EFFICIENCY: This should compile to a no-op for release builds.
        // This is because both the source and destination locations fit in a register.
        // So it only _re-interprets_ bits and no copying is needed.
        core::ptr::read((&val as *const u64).cast::<T>())
    }
}

/// Writes to a pointer with a non-tearing memory store.
///
/// # Safety
///
/// Same as the safety requirement of `core::ptr::write_volatile`.
///
/// # Guarantee
///
/// Refer to the "Guarantee" section of [`read_once`].
pub(crate) unsafe fn write_once<T: PodOnce>(ptr: *mut T, val: T) {
    let mut tmp: u64 = 0;

    // SAFETY: The caller guarantees `ptr` is valid for writes.
    unsafe {
        // EFFICIENCY: This should be a no-op for release builds.
        // This is because both the source and destination locations fit in a register.
        // So it only _re-interprets_ bits and no copying is needed.
        core::ptr::write((&mut tmp as *mut u64).cast::<T>(), val);

        // EFFICIENCY: The match here has no overhead for release builds.
        //
        // This match is resolved at compile-time via monomorphization.
        // Since `size_of::<T>()` is a constant for each concrete instance of this
        // function, the compiler eliminates the branch and only emits the
        // `MOV` instruction for the matching size.
        match size_of::<T>() {
            1 => {
                asm!("mov [{0}], {1:l}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
            }
            2 => {
                asm!("mov [{0}], {1:x}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
            }
            4 => {
                asm!("mov [{0}], {1:e}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
            }
            8 => {
                asm!("mov [{0}], {1}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
            }
            _ => core::hint::unreachable_unchecked(),
        }
    }
}

/// Copies from MMIO to regular memory using string move instructions.
///
/// # Safety
///
/// - `src` must be valid for MMIO reads of `count` bytes.
/// - `dst` must be valid for writes of `count` bytes.
#[cfg(not(feature = "cvm_guest"))]
pub(crate) unsafe fn copy_from_bulk(mut dst: *mut u8, mut src: *const u8, mut count: usize) {
    if count == 0 {
        return;
    }

    // Align source IO to 2 bytes.
    if src.addr() & 1 != 0 {
        // SAFETY: The caller guarantees both pointers are valid.
        unsafe {
            asm!("movsb", inout("rdi") dst, inout("rsi") src, options(nostack, preserves_flags));
        }
        count -= 1;
    }
    // Align source IO further to 4 bytes if at least 2 bytes remain.
    if count > 1 && src.addr() & 2 != 0 {
        // SAFETY: The caller guarantees both pointers are valid for 2 bytes.
        unsafe {
            asm!("movsw", inout("rdi") dst, inout("rsi") src, options(nostack, preserves_flags));
        }
        count -= 2;
    }

    if count > 0 {
        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
        unsafe {
            asm!(
                "rep movsb",
                inout("rdi") dst,
                inout("rsi") src,
                inout("rcx") count,
                options(nostack, preserves_flags)
            );
        }
        // Silence unused-assignment warnings from the final `inout` operands.
        let _ = (dst, src, count);
    }
}

/// Copies from regular memory to MMIO using string move instructions.
///
/// # Safety
///
/// - `src` must be valid for reads of `count` bytes.
/// - `dst` must be valid for MMIO writes of `count` bytes.
#[cfg(not(feature = "cvm_guest"))]
pub(crate) unsafe fn copy_to_bulk(mut src: *const u8, mut dst: *mut u8, mut count: usize) {
    if count == 0 {
        return;
    }

    // Align destination IO to 2 bytes.
    if dst.addr() & 1 != 0 {
        // SAFETY: The caller guarantees both pointers are valid.
        unsafe {
            asm!("movsb", inout("rdi") dst, inout("rsi") src, options(nostack, preserves_flags));
        }
        count -= 1;
    }
    // Align destination IO further to 4 bytes if at least 2 bytes remain.
    if count > 1 && dst.addr() & 2 != 0 {
        // SAFETY: The caller guarantees both pointers are valid for 2 bytes.
        unsafe {
            asm!("movsw", inout("rdi") dst, inout("rsi") src, options(nostack, preserves_flags));
        }
        count -= 2;
    }

    if count > 0 {
        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
        unsafe {
            asm!(
                "rep movsb",
                inout("rdi") dst,
                inout("rsi") src,
                inout("rcx") count,
                options(nostack, preserves_flags)
            );
        }
        // Silence unused-assignment warnings from the final `inout` operands.
        let _ = (dst, src, count);
    }
}
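
For a concrete feel of the match-elimination claim above: once monomorphized for `T = u32`, `read_once` is morally equivalent to the following x86-64 sketch (illustrative only, not code from this patch):

```rust
use core::arch::asm;

/// What `read_once::<u32>` collapses to after the size match is resolved
/// at compile time: a single 32-bit `MOV` from the MMIO address.
///
/// # Safety
///
/// Same as `core::ptr::read_volatile`.
unsafe fn read_once_u32(ptr: *const u32) -> u32 {
    let mut val: u64 = 0;
    // SAFETY: The caller guarantees `ptr` is valid for reads.
    unsafe {
        asm!(
            "mov {0:e}, [{1}]",
            out(reg) val,
            in(reg) ptr,
            options(nostack, readonly, preserves_flags),
        );
    }
    val as u32
}
```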

View File

@@ -6,6 +6,8 @@ use align_ext::AlignExt;
 
 use crate::{boot::memory_region::MemoryRegionType, io::IoMemAllocatorBuilder};
 
+pub(crate) mod io_mem;
+
 /// Initializes the allocatable MMIO area based on the x86-64 memory distribution map.
 ///
 /// In x86-64, the available physical memory area is divided into two regions below 32 bits (Low memory)

View File

@@ -3,6 +3,7 @@
 //! I/O memory and its allocator that allocates memory I/O (MMIO) to device drivers.
 
 mod allocator;
+mod util;
 
 use core::{
     marker::PhantomData,
@@ -10,6 +11,7 @@ use core::{
 };
 
 use align_ext::AlignExt;
+use inherit_methods_macro::inherit_methods;
 
 pub(crate) use self::allocator::IoMemAllocatorBuilder;
 pub(super) use self::allocator::init;
@@ -17,8 +19,7 @@ use crate::{
     Error,
     cpu::{AtomicCpuSet, CpuSet},
     mm::{
-        HasPaddr, HasSize, Infallible, PAGE_SIZE, Paddr, PodOnce, VmReader, VmWriter,
-        io_util::{HasVmReaderWriter, VmReaderWriterIdentity},
+        HasPaddr, HasSize, PAGE_SIZE, Paddr, PodOnce, VmIo, VmIoFill, VmIoOnce, VmReader, VmWriter,
         kspace::kvirt_area::KVirtArea,
         page_prop::{CachePolicy, PageFlags, PageProperty, PrivilegedPageFlags},
         tlb::{TlbFlushOp, TlbFlusher},
@@ -146,6 +147,19 @@ impl<SecuritySensitivity> IoMem<SecuritySensitivity> {
     pub fn cache_policy(&self) -> CachePolicy {
         self.cache_policy
     }
+
+    /// Returns the base virtual address of the MMIO range.
+    fn base(&self) -> usize {
+        self.kvirt_area.deref().start() + self.offset
+    }
+
+    /// Validates that the offset range lies within the MMIO window.
+    fn check_range(&self, offset: usize, len: usize) -> Result<()> {
+        if offset.checked_add(len).is_none_or(|end| end > self.limit) {
+            return Err(Error::InvalidArgs);
+        }
+        Ok(())
+    }
 }
 
 #[cfg_attr(target_arch = "loongarch64", expect(unused))]
@@ -163,10 +177,10 @@ impl IoMem<Sensitive> {
     /// not cause any out-of-bounds access, and does not cause unsound side
     /// effects (e.g., corrupting the kernel memory).
     pub(crate) unsafe fn read_once<T: PodOnce>(&self, offset: usize) -> T {
-        debug_assert!(offset + size_of::<T>() < self.limit);
+        debug_assert!(offset + size_of::<T>() <= self.limit);
         let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *const T;
 
         // SAFETY: The safety of the read operation's semantics is upheld by the caller.
-        unsafe { core::ptr::read_volatile(ptr) }
+        unsafe { crate::arch::io::io_mem::read_once(ptr) }
     }
 
     /// Writes a value of the `PodOnce` type at the specified offset using one
@@ -182,10 +196,10 @@ impl IoMem<Sensitive> {
     /// not cause any out-of-bounds access, and does not cause unsound side
     /// effects (e.g., corrupting the kernel memory).
     pub(crate) unsafe fn write_once<T: PodOnce>(&self, offset: usize, value: &T) {
-        debug_assert!(offset + size_of::<T>() < self.limit);
+        debug_assert!(offset + size_of::<T>() <= self.limit);
         let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *mut T;
 
         // SAFETY: The safety of the write operation's semantics is upheld by the caller.
-        unsafe { core::ptr::write_volatile(ptr, *value) };
+        unsafe { crate::arch::io::io_mem::write_once(ptr, *value) };
     }
 }
@@ -210,39 +224,160 @@ impl IoMem<Insensitive> {
     }
 }
 
-// For now, we reuse `VmReader` and `VmWriter` to access I/O memory.
-//
-// Note that I/O memory is not normal typed or untyped memory. Strictly speaking, it is not
-// "memory", but rather I/O ports that communicate directly with the hardware. However, this code
-// is in OSTD, so we can rely on the implementation details of `VmReader` and `VmWriter`, which we
-// know are also suitable for accessing I/O memory.
-impl HasVmReaderWriter for IoMem<Insensitive> {
-    type Types = VmReaderWriterIdentity;
-
-    fn reader(&self) -> VmReader<'_, Infallible> {
-        // SAFETY: The constructor of the `IoMem` structure has already ensured the
-        // safety of reading from the mapped physical address, and the mapping is valid.
-        unsafe {
-            VmReader::from_kernel_space(
-                (self.kvirt_area.deref().start() + self.offset) as *mut u8,
-                self.limit,
-            )
-        }
-    }
-
-    fn writer(&self) -> VmWriter<'_, Infallible> {
-        // SAFETY: The constructor of the `IoMem` structure has already ensured the
-        // safety of writing to the mapped physical address, and the mapping is valid.
-        unsafe {
-            VmWriter::from_kernel_space(
-                (self.kvirt_area.deref().start() + self.offset) as *mut u8,
-                self.limit,
-            )
-        }
-    }
-}
+impl VmIoOnce for IoMem<Insensitive> {
+    fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T> {
+        self.check_range(offset, size_of::<T>())?;
+        let ptr = (self.base() + offset) as *const T;
+        if !ptr.is_aligned() {
+            return Err(Error::InvalidArgs);
+        }
+
+        // SAFETY: The pointer is properly aligned and within the validated range.
+        let val = unsafe { crate::arch::io::io_mem::read_once(ptr) };
+        Ok(val)
+    }
+
+    fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()> {
+        self.check_range(offset, size_of::<T>())?;
+        let ptr = (self.base() + offset) as *mut T;
+        if !ptr.is_aligned() {
+            return Err(Error::InvalidArgs);
+        }
+
+        // SAFETY: The pointer is properly aligned and within the validated range.
+        unsafe { crate::arch::io::io_mem::write_once(ptr, *value) };
+        Ok(())
+    }
+}
+
+impl VmIo for IoMem<Insensitive> {
+    fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()> {
+        let len = writer.avail();
+        self.check_range(offset, len)?;
+        let src = (self.base() + offset) as *const u8;
+
+        // SAFETY: `check_range` guarantees a valid MMIO range for `len` bytes.
+        let result = unsafe { crate::io::io_mem::util::copy_from_fallible(writer, src, len) };
+        match result {
+            Ok(()) => Ok(()),
+            Err((err, _copied_len)) => Err(err),
+        }
+    }
+
+    fn read_bytes(&self, offset: usize, buf: &mut [u8]) -> Result<()> {
+        self.check_range(offset, buf.len())?;
+        let src = (self.base() + offset) as *const u8;
+        let dst = buf.as_mut_ptr();
+
+        // SAFETY: The `dst` and `src` buffers are valid to write and read, respectively.
+        unsafe {
+            crate::io::io_mem::util::copy_from(dst, src, buf.len());
+        }
+        Ok(())
+    }
+
+    fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()> {
+        let len = reader.remain();
+        self.check_range(offset, len)?;
+        let dst = (self.base() + offset) as *mut u8;
+
+        // SAFETY: `check_range` guarantees a valid MMIO range for `len` bytes.
+        let result = unsafe { crate::io::io_mem::util::copy_to_fallible(reader, dst, len) };
+        match result {
+            Ok(()) => Ok(()),
+            Err((err, _copied_len)) => Err(err),
+        }
+    }
+
+    fn write_bytes(&self, offset: usize, buf: &[u8]) -> Result<()> {
+        self.check_range(offset, buf.len())?;
+        let src = buf.as_ptr();
+        let dst = (self.base() + offset) as *mut u8;
+
+        // SAFETY: The `dst` and `src` buffers are valid to write and read, respectively.
+        unsafe {
+            crate::io::io_mem::util::copy_to(src, dst, buf.len());
+        }
+        Ok(())
+    }
+}
+
+impl VmIoFill for IoMem<Insensitive> {
+    fn fill_zeros(&self, offset: usize, len: usize) -> core::result::Result<(), (Error, usize)> {
+        if offset > self.limit {
+            return Err((Error::InvalidArgs, 0));
+        }
+
+        let available = self.limit - offset;
+        let write_len = core::cmp::min(len, available);
+        if write_len == 0 {
+            return Ok(());
+        }
+
+        let mut remaining = write_len;
+        let mut ptr = (self.base() + offset) as *mut u8;
+        let word_size = size_of::<usize>();
+
+        // Align destination to word size.
+        while remaining > 0 && !ptr.addr().is_multiple_of(word_size) {
+            // SAFETY: The bytes below `write_len` lie within the validated MMIO window.
+            unsafe { crate::arch::io::io_mem::write_once(ptr, 0u8) };
+            ptr = ptr.wrapping_add(1);
+            remaining -= 1;
+        }
+
+        while remaining >= word_size {
+            // SAFETY: The bytes below `write_len` lie within the validated MMIO window.
+            unsafe { crate::arch::io::io_mem::write_once(ptr.cast::<usize>(), 0usize) };
+            ptr = ptr.wrapping_add(word_size);
+            remaining -= word_size;
+        }
+
+        while remaining > 0 {
+            // SAFETY: The remaining range is within the validated MMIO window.
+            unsafe { crate::arch::io::io_mem::write_once(ptr, 0u8) };
+            ptr = ptr.wrapping_add(1);
+            remaining -= 1;
+        }
+
+        if write_len < len {
+            Err((Error::InvalidArgs, write_len))
+        } else {
+            Ok(())
+        }
+    }
+}
+
+macro_rules! impl_vm_io_pointer {
+    ($ty:ty, $from:tt) => {
+        #[inherit_methods(from = $from)]
+        impl VmIo for $ty {
+            fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()>;
+            fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()>;
+        }
+
+        #[inherit_methods(from = $from)]
+        impl VmIoOnce for $ty {
+            fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T>;
+            fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()>;
+        }
+
+        #[inherit_methods(from = $from)]
+        impl VmIoFill for $ty {
+            fn fill_zeros(
+                &self,
+                offset: usize,
+                len: usize,
+            ) -> core::result::Result<(), (Error, usize)>;
+        }
+    };
+}
+
+impl_vm_io_pointer!(&IoMem<Insensitive>, "(**self)");
+impl_vm_io_pointer!(&mut IoMem<Insensitive>, "(**self)");
 
 impl<SecuritySensitivity> HasPaddr for IoMem<SecuritySensitivity> {
     fn paddr(&self) -> Paddr {
         self.pa
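
Taken together, `IoMem<Insensitive>` now speaks the ordinary `VmIo`/`VmIoOnce`/`VmIoFill` vocabulary instead of handing out raw readers and writers. A rough usage sketch (the `probe` function, its trait-generic signature, and the `DEVICE_STATUS` register layout are all invented for illustration):

```rust
use ostd::{
    Result,
    mm::{VmIo, VmIoFill, VmIoOnce},
};

// Hypothetical register layout, purely for illustration.
const DEVICE_STATUS: usize = 0x10; // 4-byte aligned, as `read_once` requires

fn probe<M: VmIo + VmIoOnce + VmIoFill>(io_mem: &M) -> Result<()> {
    let mut header = [0u8; 16];
    io_mem.read_bytes(0, &mut header)?; // bulk path (`util::copy_from`)

    let status: u32 = io_mem.read_once(DEVICE_STATUS)?; // one bounds-checked load
    if status != 0 {
        io_mem.write_once(DEVICE_STATUS, &0u32)?; // one bounds-checked store
    }

    // Word-wise zeroing; the error tuple reports how many bytes were written.
    io_mem.fill_zeros(0x100, 64).map_err(|(err, _written)| err)?;
    Ok(())
}
```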

ostd/src/io/io_mem/util.rs (new file, 479 lines added)
View File

@@ -0,0 +1,479 @@
// SPDX-License-Identifier: MPL-2.0

use crate::{
    Error,
    arch::io::io_mem::{read_once, write_once},
    mm::{FallibleVmRead, FallibleVmWrite, VmReader, VmWriter},
};

#[cfg(all(target_arch = "x86_64", not(feature = "cvm_guest")))]
const BULK_THRESHOLD: usize = 128;

/// Attempts to copy from MMIO to regular memory using a bulk path.
///
/// # Safety
///
/// - `dst` must be valid for writes of `count` bytes.
/// - `src` must be valid for MMIO reads of `count` bytes.
#[cfg(all(target_arch = "x86_64", not(feature = "cvm_guest")))]
unsafe fn try_bulk_copy_from(dst: *mut u8, src: *const u8, count: usize) -> bool {
    if count >= BULK_THRESHOLD {
        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
        unsafe { crate::arch::io::io_mem::copy_from_bulk(dst, src, count) };
        return true;
    }
    false
}

/// Attempts to copy from MMIO to regular memory using a bulk path.
///
/// # Safety
///
/// - `dst` must be valid for writes of `count` bytes.
/// - `src` must be valid for MMIO reads of `count` bytes.
#[cfg(not(all(target_arch = "x86_64", not(feature = "cvm_guest"))))]
unsafe fn try_bulk_copy_from(_dst: *mut u8, _src: *const u8, _count: usize) -> bool {
    false
}

/// Attempts to copy from regular memory to MMIO using a bulk path.
///
/// # Safety
///
/// - `src` must be valid for reads of `count` bytes.
/// - `dst` must be valid for MMIO writes of `count` bytes.
#[cfg(all(target_arch = "x86_64", not(feature = "cvm_guest")))]
unsafe fn try_bulk_copy_to(src: *const u8, dst: *mut u8, count: usize) -> bool {
    if count >= BULK_THRESHOLD {
        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
        unsafe { crate::arch::io::io_mem::copy_to_bulk(src, dst, count) };
        return true;
    }
    false
}

/// Attempts to copy from regular memory to MMIO using a bulk path.
///
/// # Safety
///
/// - `src` must be valid for reads of `count` bytes.
/// - `dst` must be valid for MMIO writes of `count` bytes.
#[cfg(not(all(target_arch = "x86_64", not(feature = "cvm_guest"))))]
unsafe fn try_bulk_copy_to(_src: *const u8, _dst: *mut u8, _count: usize) -> bool {
    false
}

/// Copies one byte from MMIO to regular memory and advances both pointers.
///
/// # Safety
///
/// - `src_io_ptr` must be valid for one byte of MMIO read.
/// - `dst_ptr` must be valid for one byte of regular memory write.
unsafe fn copy_u8_from(dst_ptr: &mut *mut u8, src_io_ptr: &mut *const u8) {
    // SAFETY: The caller guarantees the MMIO and regular memory pointers are valid.
    unsafe {
        let val: u8 = read_once(*src_io_ptr);
        core::ptr::write(*dst_ptr, val);
        *src_io_ptr = (*src_io_ptr).add(1);
        *dst_ptr = (*dst_ptr).add(1);
    }
}

/// Copies a u16 from MMIO to regular memory and advances both pointers.
///
/// # Safety
///
/// - `src_io_ptr` must be aligned and valid for a u16 MMIO read.
/// - `dst_ptr` must be valid for a u16 write (aligned if `write_unaligned` is false).
unsafe fn copy_u16_from(dst_ptr: &mut *mut u8, src_io_ptr: &mut *const u8, write_unaligned: bool) {
    // SAFETY: The caller guarantees the MMIO and regular memory pointers are valid.
    unsafe {
        let val: u16 = read_once((*src_io_ptr).cast::<u16>());
        if write_unaligned {
            core::ptr::write_unaligned((*dst_ptr).cast::<u16>(), val);
        } else {
            core::ptr::write((*dst_ptr).cast::<u16>(), val);
        }
        *src_io_ptr = (*src_io_ptr).add(size_of::<u16>());
        *dst_ptr = (*dst_ptr).add(size_of::<u16>());
    }
}

/// Copies one byte from regular memory to MMIO and advances both pointers.
///
/// # Safety
///
/// - `src_ptr` must be valid for one byte of regular memory read.
/// - `dst_io_ptr` must be valid for one byte of MMIO write.
unsafe fn copy_u8_to(src_ptr: &mut *const u8, dst_io_ptr: &mut *mut u8) {
    // SAFETY: The caller guarantees the regular memory and MMIO pointers are valid.
    unsafe {
        let val: u8 = core::ptr::read(*src_ptr);
        write_once(*dst_io_ptr, val);
        *src_ptr = (*src_ptr).add(1);
        *dst_io_ptr = (*dst_io_ptr).add(1);
    }
}

/// Copies a u16 from regular memory to MMIO and advances both pointers.
///
/// # Safety
///
/// - `src_ptr` must be valid for a u16 read (aligned if `read_unaligned` is false).
/// - `dst_io_ptr` must be aligned and valid for a u16 MMIO write.
unsafe fn copy_u16_to(src_ptr: &mut *const u8, dst_io_ptr: &mut *mut u8, read_unaligned: bool) {
    // SAFETY: The caller guarantees the regular memory and MMIO pointers are valid.
    unsafe {
        let val: u16 = if read_unaligned {
            core::ptr::read_unaligned((*src_ptr).cast::<u16>())
        } else {
            core::ptr::read((*src_ptr).cast::<u16>())
        };
        write_once((*dst_io_ptr).cast::<u16>(), val);
        *src_ptr = (*src_ptr).add(size_of::<u16>());
        *dst_io_ptr = (*dst_io_ptr).add(size_of::<u16>());
    }
}

/// Copies from I/O memory to regular memory.
///
/// This uses simple load/store instructions, which is required on some platforms
/// (e.g., TDX).
///
/// # Safety
///
/// - `src_io_ptr` must be valid for MMIO reads of `count` bytes.
/// - `dst_ptr` must be valid for writes of `count` bytes.
pub(crate) unsafe fn copy_from(mut dst_ptr: *mut u8, mut src_io_ptr: *const u8, mut count: usize) {
    // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
    if unsafe { try_bulk_copy_from(dst_ptr, src_io_ptr, count) } {
        return;
    }

    if count == 0 {
        return;
    }

    // Align any unaligned source IO.
    if src_io_ptr.addr() & 1 != 0 {
        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
        unsafe { copy_u8_from(&mut dst_ptr, &mut src_io_ptr) };
        count -= 1;
    }
    // Align the source further to 4 bytes if at least one more `u16` remains.
    if count > 1 && src_io_ptr.addr() & 2 != 0 {
        let write_unaligned = dst_ptr.addr() & 1 != 0;
        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
        unsafe { copy_u16_from(&mut dst_ptr, &mut src_io_ptr, write_unaligned) };
        count -= 2;
    }

    let write_unaligned = dst_ptr.addr() & 1 != 0;
    let word_size = size_of::<u16>();
    while count >= word_size {
        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
        // The source pointer is aligned by the steps above.
        unsafe { copy_u16_from(&mut dst_ptr, &mut src_io_ptr, write_unaligned) };
        count -= word_size;
    }

    while count > 0 {
        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
        unsafe { copy_u8_from(&mut dst_ptr, &mut src_io_ptr) };
        count -= 1;
    }
}

/// Copies from regular memory to I/O memory.
///
/// # Safety
///
/// - `src_ptr` must be valid for reads of `count` bytes.
/// - `dst_io_ptr` must be valid for MMIO writes of `count` bytes.
pub(crate) unsafe fn copy_to(mut src_ptr: *const u8, mut dst_io_ptr: *mut u8, mut count: usize) {
    // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
    if unsafe { try_bulk_copy_to(src_ptr, dst_io_ptr, count) } {
        return;
    }

    if count == 0 {
        return;
    }

    // Align any unaligned destination IO.
    if dst_io_ptr.addr() & 1 != 0 {
        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
        unsafe { copy_u8_to(&mut src_ptr, &mut dst_io_ptr) };
        count -= 1;
    }
    // Align the destination further to 4 bytes if at least one more `u16` remains.
    if count > 1 && dst_io_ptr.addr() & 2 != 0 {
        let read_unaligned = src_ptr.addr() & 1 != 0;
        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
        unsafe { copy_u16_to(&mut src_ptr, &mut dst_io_ptr, read_unaligned) };
        count -= 2;
    }

    let read_unaligned = src_ptr.addr() & 1 != 0;
    let word_size = size_of::<u16>();
    while count >= word_size {
        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
        // The destination pointer is aligned by the steps above.
        unsafe { copy_u16_to(&mut src_ptr, &mut dst_io_ptr, read_unaligned) };
        count -= word_size;
    }

    while count > 0 {
        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
        unsafe { copy_u8_to(&mut src_ptr, &mut dst_io_ptr) };
        count -= 1;
    }
}

/// Copies from I/O memory to regular memory with fallible destination writes.
///
/// # Safety
///
/// - `src_io_ptr` must be valid for MMIO reads of `count` bytes.
pub(crate) unsafe fn copy_from_fallible(
    writer: &mut VmWriter,
    mut src_io_ptr: *const u8,
    mut count: usize,
) -> core::result::Result<(), (Error, usize)> {
    const BUF_SIZE: usize = 64;
    let mut buf = [0u8; BUF_SIZE];
    let mut copied_total = 0;

    while count > 0 {
        let chunk = core::cmp::min(count, buf.len());
        // SAFETY: The caller guarantees the MMIO range is valid for this chunk.
        unsafe { copy_from(buf.as_mut_ptr(), src_io_ptr, chunk) };

        let mut reader = VmReader::from(&buf[..chunk]);
        match writer.write_fallible(&mut reader) {
            Ok(written) => {
                copied_total += written;
                if written < chunk {
                    return Err((Error::PageFault, copied_total));
                }
            }
            Err((err, written)) => {
                copied_total += written;
                return Err((err, copied_total));
            }
        }

        // SAFETY: `src_io_ptr` is valid for `chunk` bytes and we advance by `chunk`.
        src_io_ptr = unsafe { src_io_ptr.add(chunk) };
        count -= chunk;
    }

    Ok(())
}

/// Copies from regular memory to I/O memory with fallible source reads.
///
/// # Safety
///
/// - `dst_io_ptr` must be valid for MMIO writes of `count` bytes.
pub(crate) unsafe fn copy_to_fallible(
    reader: &mut VmReader,
    mut dst_io_ptr: *mut u8,
    mut count: usize,
) -> core::result::Result<(), (Error, usize)> {
    const BUF_SIZE: usize = 64;
    let mut buf = [0u8; BUF_SIZE];
    let mut copied_total = 0;

    while count > 0 {
        let chunk = core::cmp::min(count, buf.len());
        let mut writer = VmWriter::from(&mut buf[..chunk]);
        match reader.read_fallible(&mut writer) {
            Ok(read_len) => {
                if read_len < chunk {
                    // SAFETY: The MMIO range is valid for `read_len` bytes.
                    unsafe { copy_to(buf.as_ptr(), dst_io_ptr, read_len) };
                    copied_total += read_len;
                    return Err((Error::PageFault, copied_total));
                }
            }
            Err((err, read_len)) => {
                if read_len > 0 {
                    // SAFETY: The MMIO range is valid for `read_len` bytes.
                    unsafe { copy_to(buf.as_ptr(), dst_io_ptr, read_len) };
                    copied_total += read_len;
                }
                return Err((err, copied_total));
            }
        }

        // SAFETY: The MMIO range is valid for `chunk` bytes.
        unsafe { copy_to(buf.as_ptr(), dst_io_ptr, chunk) };
        // SAFETY: `dst_io_ptr` is valid for `chunk` bytes and we advance by `chunk`.
        dst_io_ptr = unsafe { dst_io_ptr.add(chunk) };
        count -= chunk;
        copied_total += chunk;
    }

    Ok(())
}

/// CPU architecture-agnostic tests for `arch::io::io_mem::{read_once, write_once}`.
#[cfg(ktest)]
mod test_read_once_and_write_once {
    use super::{read_once, write_once};
    use crate::prelude::ktest;

    #[ktest]
    fn read_write_u8() {
        let mut data: u8 = 0;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, 42u8);
            assert_eq!(read_once(&data), 42u8);
        }
    }

    #[ktest]
    fn read_write_u16() {
        let mut data: u16 = 0;
        let val: u16 = 0x1234;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn read_write_u32() {
        let mut data: u32 = 0;
        let val: u32 = 0x12345678;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn read_write_u64() {
        let mut data: u64 = 0;
        let val: u64 = 0xDEADBEEFCAFEBABE;
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data, val);
            assert_eq!(read_once(&data), val);
        }
    }

    #[ktest]
    fn boundary_overlap() {
        // Ensure that writing a u8 doesn't corrupt neighboring bytes
        // in a larger structure, verifying our instruction sizing.
        let mut data: [u8; 2] = [0xAA, 0xBB];
        // SAFETY: `data` is valid for a single MMIO read/write.
        unsafe {
            write_once(&mut data[0], 0x11u8);
            assert_eq!(data[0], 0x11);
            assert_eq!(data[1], 0xBB); // Should remain untouched
        }
    }
}

#[cfg(ktest)]
mod test_copy_helpers {
    use super::{copy_from, copy_to};
    use crate::prelude::ktest;

    fn fill_pattern(buf: &mut [u8]) {
        for (idx, byte) in buf.iter_mut().enumerate() {
            *byte = (idx as u8).wrapping_mul(3).wrapping_add(1);
        }
    }

    fn run_copy_from_case(src_offset: usize, dst_offset: usize, len: usize) {
        let mut src = [0u8; 64];
        let mut dst = [0u8; 64];
        fill_pattern(&mut src);

        // SAFETY: The offsets are at most 2, well within the 64-byte buffers.
        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };
        // SAFETY: The test buffers are valid for the requested range.
        unsafe { copy_from(dst_ptr, src_ptr, len) };

        assert_eq!(
            &dst[dst_offset..dst_offset + len],
            &src[src_offset..src_offset + len]
        );
    }

    fn run_copy_to_case(src_offset: usize, dst_offset: usize, len: usize) {
        let mut src = [0u8; 64];
        let mut dst = [0u8; 64];
        fill_pattern(&mut src);

        // SAFETY: The offsets are at most 2, well within the 64-byte buffers.
        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };
        // SAFETY: The test buffers are valid for the requested range.
        unsafe { copy_to(src_ptr, dst_ptr, len) };

        assert_eq!(
            &dst[dst_offset..dst_offset + len],
            &src[src_offset..src_offset + len]
        );
    }

    #[ktest]
    fn copy_from_alignment_and_sizes() {
        let word_size = size_of::<usize>();
        let sizes = [
            0,
            1,
            word_size.saturating_sub(1),
            word_size,
            word_size + 1,
            word_size * 2 + 3,
        ];
        let offsets = [0, 1, 2];
        for &len in &sizes {
            for &src_offset in &offsets {
                for &dst_offset in &offsets {
                    if src_offset + len <= 64 && dst_offset + len <= 64 {
                        run_copy_from_case(src_offset, dst_offset, len);
                    }
                }
            }
        }
    }

    #[ktest]
    fn copy_to_alignment_and_sizes() {
        let word_size = size_of::<usize>();
        let sizes = [
            0,
            1,
            word_size.saturating_sub(1),
            word_size,
            word_size + 1,
            word_size * 2 + 3,
        ];
        let offsets = [0, 1, 2];
        for &len in &sizes {
            for &src_offset in &offsets {
                for &dst_offset in &offsets {
                    if src_offset + len <= 64 && dst_offset + len <= 64 {
                        run_copy_to_case(src_offset, dst_offset, len);
                    }
                }
            }
        }
    }
}

View File

@@ -956,6 +956,20 @@ impl<Fallibility> VmWriter<'_, Fallibility> {
         self.cursor
     }
 
+    /// Creates a temporary writer view with the same cursor and end.
+    ///
+    /// The returned writer borrows `self` mutably, so the original writer
+    /// cannot be used until the returned writer is dropped. This is useful
+    /// for creating a temporary, adjusted view without mutating the original
+    /// writer state.
+    pub fn fork(&mut self) -> VmWriter<'_, Fallibility> {
+        VmWriter {
+            cursor: self.cursor,
+            end: self.end,
+            phantom: PhantomData,
+        }
+    }
+
     /// Returns if it has available space to write.
     pub fn has_avail(&self) -> bool {
         self.avail() > 0