diff --git a/kernel/libs/aster-util/src/safe_ptr.rs b/kernel/libs/aster-util/src/safe_ptr.rs
index e1f924341..1b6181638 100644
--- a/kernel/libs/aster-util/src/safe_ptr.rs
+++ b/kernel/libs/aster-util/src/safe_ptr.rs
@@ -5,12 +5,8 @@ use core::{fmt::Debug, marker::PhantomData};
 use aster_rights::{Dup, Exec, Full, Read, Signal, TRightSet, TRights, Write};
 use aster_rights_proc::require;
 use ostd::{
-    Error, Result,
-    mm::{
-        Daddr, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce,
-        dma::DmaDirection,
-        io_util::{HasVmReaderWriter, VmReaderWriterTypes},
-    },
+    Result,
+    mm::{Daddr, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce, dma::DmaDirection},
 };
 use ostd_pod::Pod;
 
@@ -239,18 +235,12 @@ impl<T: Pod, M: VmIo, R: TRights> SafePtr<T, M, TRightSet<R>> {
     /// - the Read right for another pointer.
     #[require(R > Write)]
     #[require(R1 > Read)]
-    pub fn copy_from<M1: HasVmReaderWriter, R1: TRights>(
+    pub fn copy_from<M1: VmIo, R1: TRights>(
         &self,
         ptr: &SafePtr<T, M1, TRightSet<R1>>,
     ) -> Result<()> {
-        let mut reader = M1::Types::to_reader_result(ptr.vm_obj.reader())?.to_fallible();
-
-        if reader.remain() < size_of::<T>() {
-            return Err(Error::InvalidArgs);
-        }
-        reader.limit(size_of::<T>());
-
-        self.vm_obj.write(self.offset, &mut reader)
+        let val = ptr.vm_obj.read_val::<T>(ptr.offset)?;
+        self.vm_obj.write_val(self.offset, &val)
     }
 }
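The rewritten `copy_from` is now a typed round-trip: one bounds-checked `read_val` from the source object and one `write_val` to the destination, so the manual `remain()`/`limit()` bookkeeping (and the `Error` import) disappears. A hypothetical call site, assuming `DmaCoherent` backing objects and the default rights parameter (both are illustrative, not taken from this diff):

```rust
use aster_util::safe_ptr::SafePtr;
use ostd::mm::dma::DmaCoherent;

// Hypothetical: mirror a device-status word between two DMA-coherent
// objects. `copy_from` needs the Write right on `dst` and the Read right
// on `src`; both hold for the default (full) rights.
fn mirror_status(
    dst: &SafePtr<u32, DmaCoherent>,
    src: &SafePtr<u32, DmaCoherent>,
) -> ostd::Result<()> {
    // One typed read from `src`, one typed write to `dst`; both ends are
    // bounds-checked by the underlying `VmIo` implementation.
    dst.copy_from(src)
}
```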
diff --git a/kernel/src/device/fb.rs b/kernel/src/device/fb.rs
index 6d05b5ea4..884a9575f 100644
--- a/kernel/src/device/fb.rs
+++ b/kernel/src/device/fb.rs
@@ -4,7 +4,7 @@ use alloc::sync::Arc;
 
 use aster_framebuffer::{ColorMapEntry, FRAMEBUFFER, FrameBuffer, MAX_CMAP_SIZE, PixelFormat};
 use device_id::{DeviceId, MajorId, MinorId};
-use ostd::mm::{HasPaddr, HasSize, VmIo, io_util::HasVmReaderWriter};
+use ostd::mm::{HasPaddr, HasSize, VmIo};
 
 use super::registry::char;
 use crate::{
@@ -412,20 +412,26 @@ impl InodeIo for FbHandle {
             return Ok(0);
         }
 
-        let mut reader = self.framebuffer.io_mem().reader();
-
-        if offset >= reader.remain() {
+        let io_mem = self.framebuffer.io_mem();
+        let size = io_mem.size();
+        if offset >= size {
             return Ok(0);
         }
-        reader.skip(offset);
 
-        let mut reader = reader.to_fallible();
-        let len = match reader.read_fallible(writer) {
-            Ok(len) => len,
-            Err((err, 0)) => return Err(err.into()),
-            Err((_err, len)) => len,
-        };
+        let len = writer.avail().min(size - offset);
+        if len == 0 {
+            return Ok(0);
+        }
+
+        {
+            // Create a temporary writer so that the end of the original
+            // writer is preserved.
+            let mut new_writer = writer.fork();
+            new_writer.limit(len);
+            io_mem.read(offset, &mut new_writer)?;
+        }
+
+        // Synchronize the cursor of the original writer.
+        writer.skip(len);
 
         Ok(len)
     }
@@ -439,22 +445,27 @@ impl InodeIo for FbHandle {
             return Ok(0);
         }
 
-        let mut writer = self.framebuffer.io_mem().writer();
-        if offset >= writer.avail() {
+        let io_mem = self.framebuffer.io_mem();
+        let size = io_mem.size();
+        if offset >= size {
             return_errno_with_message!(
                 Errno::ENOSPC,
                 "the write offset is beyond the framebuffer size"
            );
         }
-        writer.skip(offset);
 
-        let mut writer = writer.to_fallible();
-        let len = match writer.write_fallible(reader) {
-            Ok(len) => len,
-            Err((err, 0)) => return Err(err.into()),
-            Err((_err, len)) => len,
-        };
+        let len = reader.remain().min(size - offset);
+        if len == 0 {
+            return Ok(0);
+        }
+
+        // Create a temporary reader so that the end of the original
+        // reader is preserved.
+        let mut new_reader = reader.clone();
+        new_reader.limit(len);
+        io_mem.write(offset, &mut new_reader)?;
+
+        // Synchronize the cursor of the original reader.
+        reader.skip(len);
 
         Ok(len)
     }
 }
diff --git a/ostd/src/arch/loongarch/io/io_mem.rs b/ostd/src/arch/loongarch/io/io_mem.rs
new file mode 100644
index 000000000..229067f9b
--- /dev/null
+++ b/ostd/src/arch/loongarch/io/io_mem.rs
@@ -0,0 +1,15 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::mm::PodOnce;
+
+pub(crate) unsafe fn read_once<T: PodOnce>(ptr: *const T) -> T {
+    // TODO: Use an arch-specific single-instruction load for LoongArch.
+    // For details, see https://github.com/asterinas/asterinas/issues/2948.
+    unsafe { core::ptr::read_volatile(ptr) }
+}
+
+pub(crate) unsafe fn write_once<T: PodOnce>(ptr: *mut T, val: T) {
+    // TODO: Use an arch-specific single-instruction store for LoongArch.
+    // For details, see https://github.com/asterinas/asterinas/issues/2948.
+    unsafe { core::ptr::write_volatile(ptr, val) }
+}
diff --git a/ostd/src/arch/loongarch/io.rs b/ostd/src/arch/loongarch/io/mod.rs
similarity index 98%
rename from ostd/src/arch/loongarch/io.rs
rename to ostd/src/arch/loongarch/io/mod.rs
index 549ab793f..eb0c22b8b 100644
--- a/ostd/src/arch/loongarch/io.rs
+++ b/ostd/src/arch/loongarch/io/mod.rs
@@ -4,6 +4,8 @@
 use alloc::vec::Vec;
 
 use crate::{boot::memory_region::MemoryRegionType, io::IoMemAllocatorBuilder};
 
+pub(crate) mod io_mem;
+
 /// Initializes the allocatable MMIO area based on the LoongArch memory
 /// distribution map.
 ///
diff --git a/ostd/src/arch/loongarch/mod.rs b/ostd/src/arch/loongarch/mod.rs
index cf3dae50e..56e666122 100644
--- a/ostd/src/arch/loongarch/mod.rs
+++ b/ostd/src/arch/loongarch/mod.rs
@@ -7,7 +7,7 @@
 pub mod boot;
 pub mod cpu;
 pub mod device;
-mod io;
+pub(crate) mod io;
 pub(crate) mod iommu;
 pub(crate) mod irq;
 pub(crate) mod mm;
diff --git a/ostd/src/arch/loongarch/serial.rs b/ostd/src/arch/loongarch/serial.rs
index 59a2ce291..e1b35926b 100644
--- a/ostd/src/arch/loongarch/serial.rs
+++ b/ostd/src/arch/loongarch/serial.rs
@@ -5,7 +5,11 @@
 use spin::Once;
 
 use crate::{
-    arch::{boot::DEVICE_TREE, mm::paddr_to_daddr},
+    arch::{
+        boot::DEVICE_TREE,
+        io::io_mem::{read_once, write_once},
+        mm::paddr_to_daddr,
+    },
     console::uart_ns16650a::{Ns16550aAccess, Ns16550aRegister, Ns16550aUart},
     sync::{LocalIrqDisabled, SpinLock},
 };
@@ -34,12 +38,12 @@ impl SerialAccess {
 impl Ns16550aAccess for SerialAccess {
     fn read(&self, reg: Ns16550aRegister) -> u8 {
         // SAFETY: `self.base + reg` is a valid register of the serial port.
-        unsafe { core::ptr::read_volatile(self.base.add(reg as u8 as usize)) }
+        unsafe { read_once(self.base.add(usize::from(reg as u8))) }
     }
 
     fn write(&mut self, reg: Ns16550aRegister, val: u8) {
         // SAFETY: `self.base + reg` is a valid register of the serial port.
-        unsafe { core::ptr::write_volatile(self.base.add(reg as u8 as usize), val) };
+        unsafe { write_once(self.base.add(usize::from(reg as u8)), val) }
     }
 }
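On LoongArch and RISC-V, `read_once`/`write_once` currently just delegate to the volatile intrinsics, so callers like the NS16550 accessor above migrate with no behavioral change. Below is a hypothetical driver-side sketch of the calling convention; the base address and register offsets are invented for illustration, while real values come from the device tree as in `serial.rs`:

```rust
use crate::arch::io::io_mem::{read_once, write_once};

// Hypothetical MMIO layout for illustration only.
const UART_BASE: *mut u8 = 0x1fe0_01e0 as *mut u8;
const REG_THR: usize = 0; // transmit holding register
const REG_LSR: usize = 5; // line status register
const LSR_THRE: u8 = 0x20; // transmitter holding register empty

/// Busy-waits until the UART can accept a byte, then writes it.
///
/// # Safety
///
/// `UART_BASE` must point to a mapped NS16550-compatible register block.
unsafe fn putc(byte: u8) {
    // SAFETY: The caller guarantees the register block is mapped and valid.
    unsafe {
        while read_once(UART_BASE.add(REG_LSR)) & LSR_THRE == 0 {
            core::hint::spin_loop();
        }
        write_once(UART_BASE.add(REG_THR), byte);
    }
}
```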
diff --git a/ostd/src/arch/riscv/io/io_mem.rs b/ostd/src/arch/riscv/io/io_mem.rs
new file mode 100644
index 000000000..9799c1b1d
--- /dev/null
+++ b/ostd/src/arch/riscv/io/io_mem.rs
@@ -0,0 +1,15 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::mm::PodOnce;
+
+pub(crate) unsafe fn read_once<T: PodOnce>(ptr: *const T) -> T {
+    // TODO: Use an arch-specific single-instruction load for RISC-V.
+    // For details, see https://github.com/asterinas/asterinas/issues/2948.
+    unsafe { core::ptr::read_volatile(ptr) }
+}
+
+pub(crate) unsafe fn write_once<T: PodOnce>(ptr: *mut T, val: T) {
+    // TODO: Use an arch-specific single-instruction store for RISC-V.
+    // For details, see https://github.com/asterinas/asterinas/issues/2948.
+    unsafe { core::ptr::write_volatile(ptr, val) }
+}
diff --git a/ostd/src/arch/riscv/io.rs b/ostd/src/arch/riscv/io/mod.rs
similarity index 98%
rename from ostd/src/arch/riscv/io.rs
rename to ostd/src/arch/riscv/io/mod.rs
index 91105d2a8..0a19db428 100644
--- a/ostd/src/arch/riscv/io.rs
+++ b/ostd/src/arch/riscv/io/mod.rs
@@ -4,6 +4,8 @@
 use alloc::vec::Vec;
 
 use crate::{boot::memory_region::MemoryRegionType, io::IoMemAllocatorBuilder};
 
+pub(crate) mod io_mem;
+
 /// Initializes the allocatable MMIO area based on the RISC-V memory
 /// distribution map.
 ///
diff --git a/ostd/src/arch/riscv/mod.rs b/ostd/src/arch/riscv/mod.rs
index 32907dabd..123c44b81 100644
--- a/ostd/src/arch/riscv/mod.rs
+++ b/ostd/src/arch/riscv/mod.rs
@@ -7,7 +7,7 @@
 pub mod boot;
 pub mod cpu;
 pub mod device;
-mod io;
+pub(crate) mod io;
 pub(crate) mod iommu;
 pub mod irq;
 pub(crate) mod mm;
diff --git a/ostd/src/arch/x86/io/io_mem.rs b/ostd/src/arch/x86/io/io_mem.rs
new file mode 100644
index 000000000..5b052a5dd
--- /dev/null
+++ b/ostd/src/arch/x86/io/io_mem.rs
@@ -0,0 +1,186 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use core::arch::asm;
+
+use crate::mm::PodOnce;
+
+/// Reads from a pointer with a non-tearing memory load.
+///
+/// This function is semantically equivalent to `core::ptr::read_volatile`,
+/// but it is manually implemented with a single memory load instruction.
+///
+/// # Safety
+///
+/// Same as the safety requirements of `core::ptr::read_volatile`.
+///
+/// # Guarantee
+///
+/// The single-memory-load-instruction guarantee is particularly useful for
+/// Confidential VMs (CVMs) such as Intel TDX and AMD SEV, where a memory
+/// load may cause a CPU exception (#VE or #VC, respectively) and the kernel
+/// has to handle such exceptions by decoding the faulting CPU instruction.
+/// As such, the kernel must be compiled to emit simple load/store CPU
+/// instructions.
+pub(crate) unsafe fn read_once<T: PodOnce>(ptr: *const T) -> T {
+    let mut val: u64 = 0;
+
+    // SAFETY: The caller guarantees `ptr` is valid for reads.
+    unsafe {
+        // EFFICIENCY: The match should be optimized out for release builds.
+        //
+        // This match is resolved at compile time via monomorphization.
+        // Since `size_of::<T>()` is a constant for each concrete instance of
+        // this function, the compiler eliminates the branch and only emits
+        // the `MOV` instruction for the matching size.
+        match size_of::<T>() {
+            1 => {
+                asm!("mov {0:l}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
+            }
+            2 => {
+                asm!("mov {0:x}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
+            }
+            4 => {
+                asm!("mov {0:e}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
+            }
+            8 => {
+                asm!("mov {0}, [{1}]", out(reg) val, in(reg) ptr, options(nostack, readonly, preserves_flags))
+            }
+            _ => core::hint::unreachable_unchecked(),
+        }
+        // EFFICIENCY: This should compile to a no-op for release builds,
+        // because both the source and destination locations fit in a
+        // register, so it only _re-interprets_ bits and no copying is needed.
+        core::ptr::read((&val as *const u64).cast::<T>())
+    }
+}
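The `u64` staging slot plus the trailing `core::ptr::read` cast is what lets one body serve all four access sizes. A plain-Rust sketch of the same size-dispatch structure (`read_volatile` stands in for the single `MOV`; the low-byte re-read assumes little-endian, which holds on x86):

```rust
use core::mem::size_of;

/// Mirrors the structure of the asm-based `read_once` above. The branch is
/// on a compile-time constant, so each instantiation keeps exactly one arm.
unsafe fn read_via_staging<T: Copy>(ptr: *const T) -> T {
    // SAFETY: The caller upholds `read_volatile`'s contract for `ptr`.
    unsafe {
        let val: u64 = match size_of::<T>() {
            1 => core::ptr::read_volatile(ptr.cast::<u8>()) as u64,
            2 => core::ptr::read_volatile(ptr.cast::<u16>()) as u64,
            4 => core::ptr::read_volatile(ptr.cast::<u32>()) as u64,
            8 => core::ptr::read_volatile(ptr.cast::<u64>()),
            _ => unreachable!("`PodOnce` admits only register-sized types"),
        };
        // Re-interpret the low bytes of the staging slot as `T`, the same
        // trick as in the diff (little-endian only).
        core::ptr::read((&val as *const u64).cast::<T>())
    }
}

fn main() {
    let x: u32 = 0xDEAD_BEEF;
    // SAFETY: `&x` is valid for a 4-byte read.
    assert_eq!(unsafe { read_via_staging(&x) }, 0xDEAD_BEEF_u32);
}
```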
+/// Writes to a pointer with a non-tearing memory store.
+///
+/// # Safety
+///
+/// Same as the safety requirements of `core::ptr::write_volatile`.
+///
+/// # Guarantee
+///
+/// Refer to the "Guarantee" section of [`read_once`].
+pub(crate) unsafe fn write_once<T: PodOnce>(ptr: *mut T, val: T) {
+    let mut tmp: u64 = 0;
+
+    // SAFETY: The caller guarantees `ptr` is valid for writes.
+    unsafe {
+        // EFFICIENCY: This should be a no-op for release builds,
+        // because both the source and destination locations fit in a
+        // register, so it only _re-interprets_ bits and no copying is needed.
+        core::ptr::write((&mut tmp as *mut u64).cast::<T>(), val);
+
+        // EFFICIENCY: The match here has no overhead for release builds.
+        //
+        // This match is resolved at compile time via monomorphization.
+        // Since `size_of::<T>()` is a constant for each concrete instance of
+        // this function, the compiler eliminates the branch and only emits
+        // the `MOV` instruction for the matching size.
+        match size_of::<T>() {
+            1 => {
+                asm!("mov [{0}], {1:l}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
+            }
+            2 => {
+                asm!("mov [{0}], {1:x}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
+            }
+            4 => {
+                asm!("mov [{0}], {1:e}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
+            }
+            8 => {
+                asm!("mov [{0}], {1}", in(reg) ptr, in(reg) tmp, options(nostack, preserves_flags))
+            }
+            _ => core::hint::unreachable_unchecked(),
+        }
+    }
+}
+
+/// Copies from MMIO to regular memory using string move instructions.
+///
+/// # Safety
+///
+/// - `src` must be valid for MMIO reads of `count` bytes.
+/// - `dst` must be valid for writes of `count` bytes.
+#[cfg(not(feature = "cvm_guest"))]
+pub(crate) unsafe fn copy_from_bulk(mut dst: *mut u8, mut src: *const u8, mut count: usize) {
+    if count == 0 {
+        return;
+    }
+
+    // Align the source I/O pointer to 2 bytes.
+    if src.addr() & 1 != 0 {
+        // SAFETY: The caller guarantees both pointers are valid for 1 byte.
+        unsafe {
+            asm!("movsb", inout("rdi") dst, inout("rsi") src, options(nostack, preserves_flags));
+        }
+        count -= 1;
+    }
+
+    // Then align it to 4 bytes.
+    if count > 1 && src.addr() & 2 != 0 {
+        // SAFETY: The caller guarantees both pointers are valid for 2 bytes.
+        unsafe {
+            asm!("movsw", inout("rdi") dst, inout("rsi") src, options(nostack, preserves_flags));
+        }
+        count -= 2;
+    }
+
+    if count > 0 {
+        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
+        unsafe {
+            asm!(
+                "rep movsb",
+                inout("rdi") dst,
+                inout("rsi") src,
+                inout("rcx") count,
+                options(nostack, preserves_flags)
+            );
+        }
+        let _ = (dst, src, count);
+    }
+}
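The `movsb`/`movsw` prologue exists only to 4-byte-align the MMIO side before `rep movsb` takes over. A small standalone helper showing how the prologue lengths fall out of the low address bits (the real code keeps this implicit in the two `if`s):

```rust
/// Returns (byte_step, word_step, bulk): how many bytes the one-byte and
/// two-byte prologue steps consume before the `rep movsb` tail.
fn split_for_alignment(addr: usize, count: usize) -> (usize, usize, usize) {
    let byte_step = usize::from(addr & 1 != 0).min(count);
    let remaining = count - byte_step;
    let word_step = if remaining > 1 && (addr + byte_step) & 2 != 0 {
        2
    } else {
        0
    };
    (byte_step, word_step, remaining - word_step)
}

fn main() {
    // Copying 10 bytes from MMIO address 0x1001: 1 byte, then 2 bytes,
    // then 7 bytes of `rep movsb` from a 4-byte-aligned source.
    assert_eq!(split_for_alignment(0x1001, 10), (1, 2, 7));
    // An already-aligned source goes straight to the bulk tail.
    assert_eq!(split_for_alignment(0x1000, 10), (0, 0, 10));
}
```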
+ unsafe { + asm!( + "rep movsb", + inout("rdi") dst, + inout("rsi") src, + inout("rcx") count, + options(nostack, preserves_flags) + ); + } + let _ = (dst, src, count); + } +} diff --git a/ostd/src/arch/x86/io.rs b/ostd/src/arch/x86/io/mod.rs similarity index 99% rename from ostd/src/arch/x86/io.rs rename to ostd/src/arch/x86/io/mod.rs index 4417b0cfa..9ca694125 100644 --- a/ostd/src/arch/x86/io.rs +++ b/ostd/src/arch/x86/io/mod.rs @@ -6,6 +6,8 @@ use align_ext::AlignExt; use crate::{boot::memory_region::MemoryRegionType, io::IoMemAllocatorBuilder}; +pub(crate) mod io_mem; + /// Initializes the allocatable MMIO area based on the x86-64 memory distribution map. /// /// In x86-64, the available physical memory area is divided into two regions below 32 bits (Low memory) diff --git a/ostd/src/io/io_mem/mod.rs b/ostd/src/io/io_mem/mod.rs index eb307cf26..4961435ab 100644 --- a/ostd/src/io/io_mem/mod.rs +++ b/ostd/src/io/io_mem/mod.rs @@ -3,6 +3,7 @@ //! I/O memory and its allocator that allocates memory I/O (MMIO) to device drivers. mod allocator; +mod util; use core::{ marker::PhantomData, @@ -10,6 +11,7 @@ use core::{ }; use align_ext::AlignExt; +use inherit_methods_macro::inherit_methods; pub(crate) use self::allocator::IoMemAllocatorBuilder; pub(super) use self::allocator::init; @@ -17,8 +19,7 @@ use crate::{ Error, cpu::{AtomicCpuSet, CpuSet}, mm::{ - HasPaddr, HasSize, Infallible, PAGE_SIZE, Paddr, PodOnce, VmReader, VmWriter, - io_util::{HasVmReaderWriter, VmReaderWriterIdentity}, + HasPaddr, HasSize, PAGE_SIZE, Paddr, PodOnce, VmIo, VmIoFill, VmIoOnce, VmReader, VmWriter, kspace::kvirt_area::KVirtArea, page_prop::{CachePolicy, PageFlags, PageProperty, PrivilegedPageFlags}, tlb::{TlbFlushOp, TlbFlusher}, @@ -146,6 +147,19 @@ impl IoMem { pub fn cache_policy(&self) -> CachePolicy { self.cache_policy } + + /// Returns the base virtual address of the MMIO range. + fn base(&self) -> usize { + self.kvirt_area.deref().start() + self.offset + } + + /// Validates that the offset range lies within the MMIO window. + fn check_range(&self, offset: usize, len: usize) -> Result<()> { + if offset.checked_add(len).is_none_or(|end| end > self.limit) { + return Err(Error::InvalidArgs); + } + Ok(()) + } } #[cfg_attr(target_arch = "loongarch64", expect(unused))] @@ -163,10 +177,10 @@ impl IoMem { /// not cause any out-of-bounds access, and does not cause unsound side /// effects (e.g., corrupting the kernel memory). pub(crate) unsafe fn read_once(&self, offset: usize) -> T { - debug_assert!(offset + size_of::() < self.limit); + debug_assert!(offset + size_of::() <= self.limit); let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *const T; // SAFETY: The safety of the read operation's semantics is upheld by the caller. - unsafe { core::ptr::read_volatile(ptr) } + unsafe { crate::arch::io::io_mem::read_once(ptr) } } /// Writes a value of the `PodOnce` type at the specified offset using one @@ -182,10 +196,10 @@ impl IoMem { /// not cause any out-of-bounds access, and does not cause unsound side /// effects (e.g., corrupting the kernel memory). pub(crate) unsafe fn write_once(&self, offset: usize, value: &T) { - debug_assert!(offset + size_of::() < self.limit); + debug_assert!(offset + size_of::() <= self.limit); let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *mut T; // SAFETY: The safety of the write operation's semantics is upheld by the caller. 
 
     /// Writes a value of the `PodOnce` type at the specified offset using one
@@ -182,10 +196,10 @@ impl IoMem {
     /// not cause any out-of-bounds access, and does not cause unsound side
     /// effects (e.g., corrupting the kernel memory).
     pub(crate) unsafe fn write_once<T: PodOnce>(&self, offset: usize, value: &T) {
-        debug_assert!(offset + size_of::<T>() < self.limit);
+        debug_assert!(offset + size_of::<T>() <= self.limit);
         let ptr = (self.kvirt_area.deref().start() + self.offset + offset) as *mut T;
         // SAFETY: The safety of the write operation's semantics is upheld by the caller.
-        unsafe { core::ptr::write_volatile(ptr, *value) };
+        unsafe { crate::arch::io::io_mem::write_once(ptr, *value) };
     }
 }
 
@@ -210,39 +224,160 @@ impl IoMem {
     }
 }
 
-// For now, we reuse `VmReader` and `VmWriter` to access I/O memory.
-//
-// Note that I/O memory is not normal typed or untyped memory. Strictly speaking, it is not
-// "memory", but rather I/O ports that communicate directly with the hardware. However, this code
-// is in OSTD, so we can rely on the implementation details of `VmReader` and `VmWriter`, which we
-// know are also suitable for accessing I/O memory.
+impl VmIoOnce for IoMem {
+    fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T> {
+        self.check_range(offset, size_of::<T>())?;
+        let ptr = (self.base() + offset) as *const T;
+        if !ptr.is_aligned() {
+            return Err(Error::InvalidArgs);
+        }
 
-impl HasVmReaderWriter for IoMem {
-    type Types = VmReaderWriterIdentity;
+        // SAFETY: The pointer is properly aligned and within the validated range.
+        let val = unsafe { crate::arch::io::io_mem::read_once(ptr) };
+        Ok(val)
+    }
 
-    fn reader(&self) -> VmReader<'_, Infallible> {
-        // SAFETY: The constructor of the `IoMem` structure has already ensured the
-        // safety of reading from the mapped physical address, and the mapping is valid.
-        unsafe {
-            VmReader::from_kernel_space(
-                (self.kvirt_area.deref().start() + self.offset) as *mut u8,
-                self.limit,
-            )
+    fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()> {
+        self.check_range(offset, size_of::<T>())?;
+        let ptr = (self.base() + offset) as *mut T;
+        if !ptr.is_aligned() {
+            return Err(Error::InvalidArgs);
+        }
+
+        // SAFETY: The pointer is properly aligned and within the validated range.
+        unsafe { crate::arch::io::io_mem::write_once(ptr, *value) };
+        Ok(())
+    }
+}
+
+impl VmIo for IoMem {
+    fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()> {
+        let len = writer.avail();
+        self.check_range(offset, len)?;
+
+        let src = (self.base() + offset) as *const u8;
+
+        // SAFETY: `check_range` guarantees a valid MMIO range for `len` bytes.
+        let result = unsafe { crate::io::io_mem::util::copy_from_fallible(writer, src, len) };
+        match result {
+            Ok(()) => Ok(()),
+            Err((err, _copied_len)) => Err(err),
         }
     }
 
-    fn writer(&self) -> VmWriter<'_, Infallible> {
-        // SAFETY: The constructor of the `IoMem` structure has already ensured the
-        // safety of writing to the mapped physical address, and the mapping is valid.
+    fn read_bytes(&self, offset: usize, buf: &mut [u8]) -> Result<()> {
+        self.check_range(offset, buf.len())?;
+        let src = (self.base() + offset) as *const u8;
+        let dst = buf.as_mut_ptr();
+
+        // SAFETY: The `dst` and `src` buffers are valid to write and read, respectively.
         unsafe {
-            VmWriter::from_kernel_space(
-                (self.kvirt_area.deref().start() + self.offset) as *mut u8,
-                self.limit,
-            )
+            crate::io::io_mem::util::copy_from(dst, src, buf.len());
         }
+        Ok(())
     }
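With `VmIo` and `VmIoOnce` implemented directly on `IoMem`, drivers get the same typed interface as other VM objects instead of borrowing raw readers and writers. A hypothetical driver fragment (the register layout and offsets are invented, and the `ostd::io::IoMem` path is assumed; the trait paths follow the `ostd::mm` imports in this diff):

```rust
use ostd::mm::{VmIo, VmIoOnce};

// Hypothetical device probe: `io_mem` is assumed to map a register window
// of at least 0x18 bytes.
fn probe(io_mem: &ostd::io::IoMem) -> ostd::Result<()> {
    // Bulk, byte-oriented access goes through `VmIo`...
    let mut header = [0u8; 16];
    io_mem.read_bytes(0, &mut header)?;

    // ...while single registers use the non-tearing `VmIoOnce` path,
    // which also checks alignment and bounds.
    let status: u32 = io_mem.read_once(0x10)?;
    if status & 0x1 != 0 {
        io_mem.write_once(0x14, &1u32)?;
    }
    Ok(())
}
```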
 
+    fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()> {
+        let len = reader.remain();
+        self.check_range(offset, len)?;
+
+        let dst = (self.base() + offset) as *mut u8;
+
+        // SAFETY: `check_range` guarantees a valid MMIO range for `len` bytes.
+        let result = unsafe { crate::io::io_mem::util::copy_to_fallible(reader, dst, len) };
+        match result {
+            Ok(()) => Ok(()),
+            Err((err, _copied_len)) => Err(err),
+        }
+    }
+
+    fn write_bytes(&self, offset: usize, buf: &[u8]) -> Result<()> {
+        self.check_range(offset, buf.len())?;
+        let src = buf.as_ptr();
+        let dst = (self.base() + offset) as *mut u8;
+
+        // SAFETY: The `dst` and `src` buffers are valid to write and read, respectively.
+        unsafe {
+            crate::io::io_mem::util::copy_to(src, dst, buf.len());
+        }
+        Ok(())
+    }
+}
+
+impl VmIoFill for IoMem {
+    fn fill_zeros(&self, offset: usize, len: usize) -> core::result::Result<(), (Error, usize)> {
+        if offset > self.limit {
+            return Err((Error::InvalidArgs, 0));
+        }
+
+        let available = self.limit - offset;
+        let write_len = core::cmp::min(len, available);
+        if len == 0 {
+            return Ok(());
+        }
+
+        let mut remaining = write_len;
+        let mut ptr = (self.base() + offset) as *mut u8;
+        let word_size = size_of::<usize>();
+
+        // Align the destination to the word size.
+        while remaining > 0 && !ptr.addr().is_multiple_of(word_size) {
+            // SAFETY: The byte is within the MMIO range validated above.
+            unsafe { crate::arch::io::io_mem::write_once(ptr, 0u8) };
+            ptr = ptr.wrapping_add(1);
+            remaining -= 1;
+        }
+
+        while remaining >= word_size {
+            // SAFETY: The word is within the MMIO range validated above.
+            unsafe { crate::arch::io::io_mem::write_once(ptr.cast::<usize>(), 0usize) };
+            ptr = ptr.wrapping_add(word_size);
+            remaining -= word_size;
+        }
+
+        while remaining > 0 {
+            // SAFETY: The byte is within the MMIO range validated above.
+            unsafe { crate::arch::io::io_mem::write_once(ptr, 0u8) };
+            ptr = ptr.wrapping_add(1);
+            remaining -= 1;
+        }
+
+        if write_len < len {
+            // Only part of the requested range fits in the MMIO window.
+            Err((Error::InvalidArgs, write_len))
+        } else {
+            Ok(())
+        }
+    }
+}
+
+macro_rules! impl_vm_io_pointer {
+    ($ty:ty, $from:tt) => {
+        #[inherit_methods(from = $from)]
+        impl VmIo for $ty {
+            fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()>;
+            fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()>;
+        }
+
+        #[inherit_methods(from = $from)]
+        impl VmIoOnce for $ty {
+            fn read_once<T: PodOnce>(&self, offset: usize) -> Result<T>;
+            fn write_once<T: PodOnce>(&self, offset: usize, value: &T) -> Result<()>;
+        }
+
+        #[inherit_methods(from = $from)]
+        impl VmIoFill for $ty {
+            fn fill_zeros(
+                &self,
+                offset: usize,
+                len: usize,
+            ) -> core::result::Result<(), (Error, usize)>;
+        }
+    };
+}
+
+impl_vm_io_pointer!(&IoMem, "(**self)");
+impl_vm_io_pointer!(&mut IoMem, "(**self)");
+
 impl HasPaddr for IoMem {
     fn paddr(&self) -> Paddr {
         self.pa
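`impl_vm_io_pointer!` relies on `inherit_methods_macro` to forward each listed method through `(**self)`. For the `VmIo` impl on `&IoMem`, the generated code is roughly equivalent to this hand-written delegation (a sketch of the expansion, not its literal output):

```rust
// Roughly what `#[inherit_methods(from = "(**self)")]` generates: each
// listed signature gains a body that delegates to the pointee's impl.
impl VmIo for &IoMem {
    fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()> {
        (**self).read(offset, writer)
    }

    fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()> {
        (**self).write(offset, reader)
    }
}
```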
diff --git a/ostd/src/io/io_mem/util.rs b/ostd/src/io/io_mem/util.rs
new file mode 100644
index 000000000..65020bf6b
--- /dev/null
+++ b/ostd/src/io/io_mem/util.rs
@@ -0,0 +1,479 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::{
+    Error,
+    arch::io::io_mem::{read_once, write_once},
+    mm::{FallibleVmRead, FallibleVmWrite, VmReader, VmWriter},
+};
+
+#[cfg(all(target_arch = "x86_64", not(feature = "cvm_guest")))]
+const BULK_THRESHOLD: usize = 128;
+
+/// Attempts to copy from MMIO to regular memory using a bulk path.
+///
+/// # Safety
+///
+/// - `dst` must be valid for writes of `count` bytes.
+/// - `src` must be valid for MMIO reads of `count` bytes.
+#[cfg(all(target_arch = "x86_64", not(feature = "cvm_guest")))]
+unsafe fn try_bulk_copy_from(dst: *mut u8, src: *const u8, count: usize) -> bool {
+    if count >= BULK_THRESHOLD {
+        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
+        unsafe { crate::arch::io::io_mem::copy_from_bulk(dst, src, count) };
+        return true;
+    }
+    false
+}
+
+/// Attempts to copy from MMIO to regular memory using a bulk path.
+///
+/// # Safety
+///
+/// - `dst` must be valid for writes of `count` bytes.
+/// - `src` must be valid for MMIO reads of `count` bytes.
+#[cfg(not(all(target_arch = "x86_64", not(feature = "cvm_guest"))))]
+unsafe fn try_bulk_copy_from(_dst: *mut u8, _src: *const u8, _count: usize) -> bool {
+    false
+}
+
+/// Attempts to copy from regular memory to MMIO using a bulk path.
+///
+/// # Safety
+///
+/// - `src` must be valid for reads of `count` bytes.
+/// - `dst` must be valid for MMIO writes of `count` bytes.
+#[cfg(all(target_arch = "x86_64", not(feature = "cvm_guest")))]
+unsafe fn try_bulk_copy_to(src: *const u8, dst: *mut u8, count: usize) -> bool {
+    if count >= BULK_THRESHOLD {
+        // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
+        unsafe { crate::arch::io::io_mem::copy_to_bulk(src, dst, count) };
+        return true;
+    }
+    false
+}
+
+/// Attempts to copy from regular memory to MMIO using a bulk path.
+///
+/// # Safety
+///
+/// - `src` must be valid for reads of `count` bytes.
+/// - `dst` must be valid for MMIO writes of `count` bytes.
+#[cfg(not(all(target_arch = "x86_64", not(feature = "cvm_guest"))))]
+unsafe fn try_bulk_copy_to(_src: *const u8, _dst: *mut u8, _count: usize) -> bool {
+    false
+}
+
+/// Copies one byte from MMIO to regular memory and advances both pointers.
+///
+/// # Safety
+///
+/// - `src_io_ptr` must be valid for one byte of MMIO read.
+/// - `dst_ptr` must be valid for one byte of regular memory write.
+unsafe fn copy_u8_from(dst_ptr: &mut *mut u8, src_io_ptr: &mut *const u8) {
+    // SAFETY: The caller guarantees the MMIO and regular memory pointers are valid.
+    unsafe {
+        let val: u8 = read_once(*src_io_ptr);
+        core::ptr::write(*dst_ptr, val);
+        *src_io_ptr = (*src_io_ptr).add(1);
+        *dst_ptr = (*dst_ptr).add(1);
+    }
+}
+
+/// Copies a `u16` from MMIO to regular memory and advances both pointers.
+///
+/// # Safety
+///
+/// - `src_io_ptr` must be aligned and valid for a `u16` MMIO read.
+/// - `dst_ptr` must be valid for a `u16` write (aligned if `write_unaligned` is false).
+unsafe fn copy_u16_from(dst_ptr: &mut *mut u8, src_io_ptr: &mut *const u8, write_unaligned: bool) {
+    // SAFETY: The caller guarantees the MMIO and regular memory pointers are valid.
+    unsafe {
+        let val: u16 = read_once((*src_io_ptr).cast::<u16>());
+        if write_unaligned {
+            core::ptr::write_unaligned((*dst_ptr).cast::<u16>(), val);
+        } else {
+            core::ptr::write((*dst_ptr).cast::<u16>(), val);
+        }
+        *src_io_ptr = (*src_io_ptr).add(size_of::<u16>());
+        *dst_ptr = (*dst_ptr).add(size_of::<u16>());
+    }
+}
+
+/// Copies one byte from regular memory to MMIO and advances both pointers.
+///
+/// # Safety
+///
+/// - `src_ptr` must be valid for one byte of regular memory read.
+/// - `dst_io_ptr` must be valid for one byte of MMIO write.
+unsafe fn copy_u8_to(src_ptr: &mut *const u8, dst_io_ptr: &mut *mut u8) {
+    // SAFETY: The caller guarantees the regular memory and MMIO pointers are valid.
+    unsafe {
+        let val: u8 = core::ptr::read(*src_ptr);
+        write_once(*dst_io_ptr, val);
+        *src_ptr = (*src_ptr).add(1);
+        *dst_io_ptr = (*dst_io_ptr).add(1);
+    }
+}
+
+/// Copies a `u16` from regular memory to MMIO and advances both pointers.
+///
+/// # Safety
+///
+/// - `src_ptr` must be valid for a `u16` read (aligned if `read_unaligned` is false).
+/// - `dst_io_ptr` must be aligned and valid for a `u16` MMIO write.
+unsafe fn copy_u16_to(src_ptr: &mut *const u8, dst_io_ptr: &mut *mut u8, read_unaligned: bool) {
+    // SAFETY: The caller guarantees the regular memory and MMIO pointers are valid.
+    unsafe {
+        let val: u16 = if read_unaligned {
+            core::ptr::read_unaligned((*src_ptr).cast::<u16>())
+        } else {
+            core::ptr::read((*src_ptr).cast::<u16>())
+        };
+        write_once((*dst_io_ptr).cast::<u16>(), val);
+        *src_ptr = (*src_ptr).add(size_of::<u16>());
+        *dst_io_ptr = (*dst_io_ptr).add(size_of::<u16>());
+    }
+}
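Only the MMIO side of these helpers must be access-size aligned (that is what `read_once`/`write_once` assume); the regular-memory side may sit at any address, hence the `read_unaligned`/`write_unaligned` flags threaded through. A standalone illustration of the distinction:

```rust
fn main() {
    // 2-byte-aligned storage by construction.
    let words: [u16; 2] = [
        u16::from_ne_bytes([0x11, 0x22]),
        u16::from_ne_bytes([0x33, 0x44]),
    ];
    let base = words.as_ptr().cast::<u8>();

    // Aligned u16 load: a plain `read` is sound (the fast path).
    // SAFETY: `base` is 2-byte aligned and points to initialized bytes.
    let aligned = unsafe { core::ptr::read(base.cast::<u16>()) };
    assert_eq!(aligned, u16::from_ne_bytes([0x11, 0x22]));

    // The u16 straddling offset 1 is misaligned; only `read_unaligned`
    // may load it as one unit (the slow path the helpers select).
    // SAFETY: Offsets 1..3 are initialized bytes of `words`.
    let misaligned = unsafe { core::ptr::read_unaligned(base.add(1).cast::<u16>()) };
    assert_eq!(misaligned, u16::from_ne_bytes([0x22, 0x33]));
}
```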
+/// Copies from I/O memory to regular memory.
+///
+/// This uses simple load/store instructions, which is required on some
+/// platforms (e.g., TDX).
+///
+/// # Safety
+///
+/// - `src_io_ptr` must be valid for MMIO reads of `count` bytes.
+/// - `dst_ptr` must be valid for writes of `count` bytes.
+pub(crate) unsafe fn copy_from(mut dst_ptr: *mut u8, mut src_io_ptr: *const u8, mut count: usize) {
+    // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
+    if unsafe { try_bulk_copy_from(dst_ptr, src_io_ptr, count) } {
+        return;
+    }
+
+    if count == 0 {
+        return;
+    }
+
+    // Align any unaligned source I/O pointer.
+    if src_io_ptr.addr() & 1 != 0 {
+        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
+        unsafe { copy_u8_from(&mut dst_ptr, &mut src_io_ptr) };
+        count -= 1;
+    }
+
+    if count > 1 && src_io_ptr.addr() & 2 != 0 {
+        let write_unaligned = dst_ptr.addr() & 1 != 0;
+        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
+        unsafe { copy_u16_from(&mut dst_ptr, &mut src_io_ptr, write_unaligned) };
+        count -= 2;
+    }
+
+    let write_unaligned = dst_ptr.addr() & 1 != 0;
+    let word_size = size_of::<u16>();
+    while count >= word_size {
+        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
+        // The source pointer is aligned by the steps above.
+        unsafe { copy_u16_from(&mut dst_ptr, &mut src_io_ptr, write_unaligned) };
+        count -= word_size;
+    }
+
+    while count > 0 {
+        // SAFETY: The caller guarantees valid MMIO reads and regular memory writes.
+        unsafe { copy_u8_from(&mut dst_ptr, &mut src_io_ptr) };
+        count -= 1;
+    }
+}
+
+/// Copies from regular memory to I/O memory.
+///
+/// # Safety
+///
+/// - `src_ptr` must be valid for reads of `count` bytes.
+/// - `dst_io_ptr` must be valid for MMIO writes of `count` bytes.
+pub(crate) unsafe fn copy_to(mut src_ptr: *const u8, mut dst_io_ptr: *mut u8, mut count: usize) {
+    // SAFETY: The caller guarantees both pointers are valid for `count` bytes.
+    if unsafe { try_bulk_copy_to(src_ptr, dst_io_ptr, count) } {
+        return;
+    }
+
+    if count == 0 {
+        return;
+    }
+
+    // Align any unaligned destination I/O pointer.
+    if dst_io_ptr.addr() & 1 != 0 {
+        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
+        unsafe { copy_u8_to(&mut src_ptr, &mut dst_io_ptr) };
+        count -= 1;
+    }
+
+    if count > 1 && dst_io_ptr.addr() & 2 != 0 {
+        let read_unaligned = src_ptr.addr() & 1 != 0;
+        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
+        unsafe { copy_u16_to(&mut src_ptr, &mut dst_io_ptr, read_unaligned) };
+        count -= 2;
+    }
+
+    let read_unaligned = src_ptr.addr() & 1 != 0;
+    let word_size = size_of::<u16>();
+    while count >= word_size {
+        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
+        // The destination pointer is aligned by the steps above.
+        unsafe { copy_u16_to(&mut src_ptr, &mut dst_io_ptr, read_unaligned) };
+        count -= word_size;
+    }
+
+    while count > 0 {
+        // SAFETY: The caller guarantees valid reads from regular memory and MMIO writes.
+        unsafe { copy_u8_to(&mut src_ptr, &mut dst_io_ptr) };
+        count -= 1;
+    }
+}
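To see the resulting access sequence concretely, this standalone replay of `copy_to`'s decision loop reports the widths it would emit; for example, nine bytes to an odd MMIO address become one byte store plus four 16-bit stores:

```rust
/// Replays `copy_to`'s alignment decisions for a given MMIO address and
/// length, returning the access widths it would emit in order.
fn access_pattern(mut dst: usize, mut count: usize) -> Vec<usize> {
    let mut widths = Vec::new();
    if count > 0 && dst & 1 != 0 {
        widths.push(1); // one-byte step to reach 2-byte alignment
        dst += 1;
        count -= 1;
    }
    if count > 1 && dst & 2 != 0 {
        widths.push(2); // one u16 step to reach 4-byte alignment
        dst += 2;
        count -= 2;
    }
    while count >= 2 {
        widths.push(2); // aligned u16 word loop
        dst += 2;
        count -= 2;
    }
    if count > 0 {
        widths.push(1); // trailing byte
    }
    widths
}

fn main() {
    assert_eq!(access_pattern(0x1001, 9), vec![1, 2, 2, 2, 2]);
    assert_eq!(access_pattern(0x1000, 4), vec![2, 2]);
}
```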
+/// Copies from I/O memory to regular memory with fallible destination writes.
+///
+/// # Safety
+///
+/// - `src_io_ptr` must be valid for MMIO reads of `count` bytes.
+pub(crate) unsafe fn copy_from_fallible(
+    writer: &mut VmWriter,
+    mut src_io_ptr: *const u8,
+    mut count: usize,
+) -> core::result::Result<(), (Error, usize)> {
+    const BUF_SIZE: usize = 64;
+    let mut buf = [0u8; BUF_SIZE];
+    let mut copied_total = 0;
+
+    while count > 0 {
+        let chunk = core::cmp::min(count, buf.len());
+
+        // SAFETY: The caller guarantees the MMIO range is valid for this chunk.
+        unsafe { copy_from(buf.as_mut_ptr(), src_io_ptr, chunk) };
+
+        let mut reader = VmReader::from(&buf[..chunk]);
+        match writer.write_fallible(&mut reader) {
+            Ok(written) => {
+                copied_total += written;
+                if written < chunk {
+                    return Err((Error::PageFault, copied_total));
+                }
+            }
+            Err((err, written)) => {
+                copied_total += written;
+                return Err((err, copied_total));
+            }
+        }
+
+        // SAFETY: `src_io_ptr` is valid for `chunk` bytes and we advance by `chunk`.
+        src_io_ptr = unsafe { src_io_ptr.add(chunk) };
+        count -= chunk;
+    }
+
+    Ok(())
+}
+
+/// Copies from regular memory to I/O memory with fallible source reads.
+///
+/// # Safety
+///
+/// - `dst_io_ptr` must be valid for MMIO writes of `count` bytes.
+pub(crate) unsafe fn copy_to_fallible(
+    reader: &mut VmReader,
+    mut dst_io_ptr: *mut u8,
+    mut count: usize,
+) -> core::result::Result<(), (Error, usize)> {
+    const BUF_SIZE: usize = 64;
+    let mut buf = [0u8; BUF_SIZE];
+    let mut copied_total = 0;
+
+    while count > 0 {
+        let chunk = core::cmp::min(count, buf.len());
+
+        let mut writer = VmWriter::from(&mut buf[..chunk]);
+        match reader.read_fallible(&mut writer) {
+            Ok(read_len) => {
+                if read_len < chunk {
+                    // SAFETY: The MMIO range is valid for `read_len` bytes.
+                    unsafe { copy_to(buf.as_ptr(), dst_io_ptr, read_len) };
+                    copied_total += read_len;
+                    return Err((Error::PageFault, copied_total));
+                }
+            }
+            Err((err, read_len)) => {
+                if read_len > 0 {
+                    // SAFETY: The MMIO range is valid for `read_len` bytes.
+                    unsafe { copy_to(buf.as_ptr(), dst_io_ptr, read_len) };
+                    copied_total += read_len;
+                }
+                return Err((err, copied_total));
+            }
+        }
+
+        // SAFETY: The MMIO range is valid for `chunk` bytes.
+        unsafe { copy_to(buf.as_ptr(), dst_io_ptr, chunk) };
+
+        // SAFETY: `dst_io_ptr` is valid for `chunk` bytes and we advance by `chunk`.
+        dst_io_ptr = unsafe { dst_io_ptr.add(chunk) };
+        count -= chunk;
+        copied_total += chunk;
+    }
+
+    Ok(())
+}
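`copy_from_fallible` cannot hand the MMIO pointer to `write_fallible` directly: the fallible path performs ordinary memory accesses that may fault partway and are not MMIO-safe (and, under TDX/SEV, might not decode cleanly in the #VE/#VC handler). Each chunk is therefore staged through a 64-byte stack buffer. A standalone schematic of the loop shape, with closures standing in for the MMIO copy and the fallible writer:

```rust
const BUF_SIZE: usize = 64;

/// Schematic of the staging loop in `copy_from_fallible`. `read_mmio`
/// stands in for the infallible, MMIO-safe `copy_from`; `write_user` for
/// the fallible writer, which reports how many bytes it managed.
fn staged_copy(
    read_mmio: &dyn Fn(usize, &mut [u8]),
    write_user: &mut dyn FnMut(&[u8]) -> usize,
    mut count: usize,
) -> Result<(), usize> {
    let mut buf = [0u8; BUF_SIZE];
    let (mut offset, mut copied) = (0, 0);
    while count > 0 {
        let chunk = count.min(BUF_SIZE);
        // Stage 1: MMIO -> stack buffer, with controlled instructions only.
        read_mmio(offset, &mut buf[..chunk]);
        // Stage 2: stack buffer -> destination, which may fault partway.
        let written = write_user(&buf[..chunk]);
        copied += written;
        if written < chunk {
            return Err(copied); // partial progress, like `(Error, usize)`
        }
        offset += chunk;
        count -= chunk;
    }
    Ok(())
}

fn main() {
    let src: Vec<u8> = (0..200u8).collect();
    let mut dst = vec![0u8; 200];
    let mut pos = 0;
    let res = staged_copy(
        &|off, buf| buf.copy_from_slice(&src[off..off + buf.len()]),
        &mut |bytes| {
            dst[pos..pos + bytes.len()].copy_from_slice(bytes);
            pos += bytes.len();
            bytes.len()
        },
        200,
    );
    assert_eq!(res, Ok(()));
    assert_eq!(dst, src);
}
```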
+/// CPU-architecture-agnostic tests for `arch::io::io_mem::{read_once, write_once}`.
+#[cfg(ktest)]
+mod test_read_once_and_write_once {
+    use super::{read_once, write_once};
+    use crate::prelude::ktest;
+
+    #[ktest]
+    fn read_write_u8() {
+        let mut data: u8 = 0;
+        // SAFETY: `data` is valid for a single read/write.
+        unsafe {
+            write_once(&mut data, 42u8);
+            assert_eq!(read_once(&data), 42u8);
+        }
+    }
+
+    #[ktest]
+    fn read_write_u16() {
+        let mut data: u16 = 0;
+        let val: u16 = 0x1234;
+        // SAFETY: `data` is valid for a single read/write.
+        unsafe {
+            write_once(&mut data, val);
+            assert_eq!(read_once(&data), val);
+        }
+    }
+
+    #[ktest]
+    fn read_write_u32() {
+        let mut data: u32 = 0;
+        let val: u32 = 0x12345678;
+        // SAFETY: `data` is valid for a single read/write.
+        unsafe {
+            write_once(&mut data, val);
+            assert_eq!(read_once(&data), val);
+        }
+    }
+
+    #[ktest]
+    fn read_write_u64() {
+        let mut data: u64 = 0;
+        let val: u64 = 0xDEADBEEFCAFEBABE;
+        // SAFETY: `data` is valid for a single read/write.
+        unsafe {
+            write_once(&mut data, val);
+            assert_eq!(read_once(&data), val);
+        }
+    }
+
+    #[ktest]
+    fn boundary_overlap() {
+        // Ensure that writing a u8 doesn't corrupt neighboring bytes
+        // in a larger structure, verifying our instruction sizing.
+        let mut data: [u8; 2] = [0xAA, 0xBB];
+        // SAFETY: `data[0]` is valid for a single read/write.
+        unsafe {
+            write_once(&mut data[0], 0x11u8);
+            assert_eq!(data[0], 0x11);
+            assert_eq!(data[1], 0xBB); // Should remain untouched
+        }
+    }
+}
+
+#[cfg(ktest)]
+mod test_copy_helpers {
+    use super::{copy_from, copy_to};
+    use crate::prelude::ktest;
+
+    fn fill_pattern(buf: &mut [u8]) {
+        for (idx, byte) in buf.iter_mut().enumerate() {
+            *byte = (idx as u8).wrapping_mul(3).wrapping_add(1);
+        }
+    }
+
+    fn run_copy_from_case(src_offset: usize, dst_offset: usize, len: usize) {
+        let mut src = [0u8; 64];
+        let mut dst = [0u8; 64];
+        fill_pattern(&mut src);
+
+        // SAFETY: The offsets are within the 64-byte buffers.
+        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
+        // SAFETY: Same as above.
+        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };
+
+        // SAFETY: The test buffers are valid for the requested range.
+        unsafe { copy_from(dst_ptr, src_ptr, len) };
+
+        assert_eq!(
+            &dst[dst_offset..dst_offset + len],
+            &src[src_offset..src_offset + len]
+        );
+    }
+
+    fn run_copy_to_case(src_offset: usize, dst_offset: usize, len: usize) {
+        let mut src = [0u8; 64];
+        let mut dst = [0u8; 64];
+        fill_pattern(&mut src);
+
+        // SAFETY: The offsets are within the 64-byte buffers.
+        let src_ptr = unsafe { src.as_ptr().add(src_offset) };
+        // SAFETY: Same as above.
+        let dst_ptr = unsafe { dst.as_mut_ptr().add(dst_offset) };
+
+        // SAFETY: The test buffers are valid for the requested range.
+        unsafe { copy_to(src_ptr, dst_ptr, len) };
+
+        assert_eq!(
+            &dst[dst_offset..dst_offset + len],
+            &src[src_offset..src_offset + len]
+        );
+    }
+
+    #[ktest]
+    fn copy_from_alignment_and_sizes() {
+        let word_size = size_of::<u16>();
+        let sizes = [
+            0,
+            1,
+            word_size.saturating_sub(1),
+            word_size,
+            word_size + 1,
+            word_size * 2 + 3,
+        ];
+        let offsets = [0, 1, 2];
+
+        for &len in &sizes {
+            for &src_offset in &offsets {
+                for &dst_offset in &offsets {
+                    if src_offset + len <= 64 && dst_offset + len <= 64 {
+                        run_copy_from_case(src_offset, dst_offset, len);
+                    }
+                }
+            }
+        }
+    }
+
+    #[ktest]
+    fn copy_to_alignment_and_sizes() {
+        let word_size = size_of::<u16>();
+        let sizes = [
+            0,
+            1,
+            word_size.saturating_sub(1),
+            word_size,
+            word_size + 1,
+            word_size * 2 + 3,
+        ];
+        let offsets = [0, 1, 2];
+
+        for &len in &sizes {
+            for &src_offset in &offsets {
+                for &dst_offset in &offsets {
+                    if src_offset + len <= 64 && dst_offset + len <= 64 {
+                        run_copy_to_case(src_offset, dst_offset, len);
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/ostd/src/mm/io.rs b/ostd/src/mm/io.rs
index a405c00c4..fabaf8d9f 100644
--- a/ostd/src/mm/io.rs
+++ b/ostd/src/mm/io.rs
@@ -956,6 +956,20 @@ impl VmWriter<'_, Fallibility> {
         self.cursor
     }
 
+    /// Creates a temporary writer view with the same cursor and end.
+    ///
+    /// The returned writer borrows `self` mutably, so the original writer
+    /// cannot be used until the returned writer is dropped. This is useful
+    /// for creating a temporary, adjusted view without mutating the original
+    /// writer's state.
+    pub fn fork(&mut self) -> VmWriter<'_, Fallibility> {
+        VmWriter {
+            cursor: self.cursor,
+            end: self.end,
+            phantom: PhantomData,
+        }
+    }
+
     /// Returns if it has available space to write.
     pub fn has_avail(&self) -> bool {
         self.avail() > 0
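A usage sketch for the new `fork`, mirroring the framebuffer read path earlier in this diff. The `VmReader::from`/`write_fallible` combination follows the same pattern as `util.rs` above; treat this as a sketch of the calling convention, not kernel code:

```rust
use ostd::mm::{FallibleVmWrite, VmReader, VmWriter};

/// Writes up to `bytes.len()` bytes through a bounded fork of `writer`,
/// then synchronizes the original cursor.
fn bounded_write(writer: &mut VmWriter, bytes: &[u8]) -> usize {
    let len = writer.avail().min(bytes.len());
    {
        // The fork mutably borrows `writer`; the original stays untouched
        // until the fork is dropped.
        let mut view = writer.fork();
        view.limit(len);
        let mut reader = VmReader::from(&bytes[..len]);
        let _ = view.write_fallible(&mut reader);
    }
    // Propagate the progress to the original writer, as `fb.rs` does.
    writer.skip(len);
    len
}
```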