From 39a541fdeba53dd90cc85d13dcdd3fee916bb53a Mon Sep 17 00:00:00 2001 From: Zejun Zhao Date: Fri, 8 Aug 2025 04:53:33 +0800 Subject: [PATCH] Add RISC-V FPU support --- kernel/src/process/signal/c_types.rs | 5 + kernel/src/process/signal/mod.rs | 10 +- ostd/src/arch/riscv/cpu/context.rs | 156 ++++++++++++++++-- ostd/src/arch/riscv/cpu/fpu.S | 228 +++++++++++++++++++++++++++ ostd/src/arch/riscv/mod.rs | 5 - ostd/src/arch/riscv/trap/mod.rs | 2 +- ostd/src/arch/riscv/trap/trap.S | 21 ++- ostd/src/arch/riscv/trap/trap.rs | 35 +++- 8 files changed, 424 insertions(+), 38 deletions(-) create mode 100644 ostd/src/arch/riscv/cpu/fpu.S diff --git a/kernel/src/process/signal/c_types.rs b/kernel/src/process/signal/c_types.rs index 630cebb6d..d5d909f66 100644 --- a/kernel/src/process/signal/c_types.rs +++ b/kernel/src/process/signal/c_types.rs @@ -196,6 +196,11 @@ pub struct ucontext_t { pub uc_mcontext: mcontext_t, } +// FIXME: Currently Rust generates array impls for every size up to 32 manually +// and there is ongoing work on refactoring with const generics. We can just +// derive the `Default` implementation once that is done. +// +// See . #[cfg(any(target_arch = "riscv64", target_arch = "loongarch64"))] impl Default for ucontext_t { fn default() -> Self { diff --git a/kernel/src/process/signal/mod.rs b/kernel/src/process/signal/mod.rs index 21ce40318..bd94b44cc 100644 --- a/kernel/src/process/signal/mod.rs +++ b/kernel/src/process/signal/mod.rs @@ -258,9 +258,15 @@ pub fn handle_user_signal( const UC_FP_XSTATE: u64 = 1 << 0; ucontext.uc_flags = UC_FP_XSTATE; } else if #[cfg(target_arch = "riscv64")] { + // Reference: + // , + // . + const FP_STATE_SIZE: usize = + size_of::() + 3 * size_of::(); + let ucontext_addr = alloc_aligned_in_user_stack( stack_pointer, - size_of::() + fpu_context_bytes.len(), + size_of::() + FP_STATE_SIZE, align_of::(), )?; let fpu_context_addr = (ucontext_addr as usize) + size_of::(); @@ -280,7 +286,7 @@ pub fn handle_user_signal( } } - let mut fpu_context_reader = VmReader::from(fpu_context.as_bytes()); + let mut fpu_context_reader = VmReader::from(fpu_context_bytes); user_space.write_bytes(fpu_context_addr as _, &mut fpu_context_reader)?; user_space.write_val(ucontext_addr as _, &ucontext)?; diff --git a/ostd/src/arch/riscv/cpu/context.rs b/ostd/src/arch/riscv/cpu/context.rs index 20efa7534..c535a3f56 100644 --- a/ostd/src/arch/riscv/cpu/context.rs +++ b/ostd/src/arch/riscv/cpu/context.rs @@ -2,12 +2,17 @@ //! CPU execution context control. -use core::fmt::Debug; +use alloc::boxed::Box; +use core::{arch::global_asm, fmt::Debug}; +use ostd_pod::Pod; use riscv::register::scause::{Exception, Trap}; use crate::{ - arch::trap::{call_irq_callback_functions_by_scause, RawUserContext, TrapFrame}, + arch::{ + cpu::extension::{has_extensions, IsaExtensions}, + trap::{call_irq_callback_functions_by_scause, RawUserContext, TrapFrame, SSTATUS_FS_MASK}, + }, cpu::PrivilegeLevel, user::{ReturnReason, UserContextApi, UserContextApiInternal}, }; @@ -302,31 +307,154 @@ cpu_context_impl_getter_setter!( ); /// The FPU context of user task. -/// -/// This could be used for saving both legacy and modern state format. -// FIXME: Implement FPU context on RISC-V platforms. -#[derive(Clone, Debug, Default)] -pub struct FpuContext; +#[derive(Clone, Debug)] +pub enum FpuContext { + /// FPU context for F extension (32-bit floating point). + F(Box), + /// FPU context for D extension (64-bit floating point). + D(Box), + /// FPU context for Q extension (128-bit floating point). + Q(Box), + /// No FPU context (no FPU extensions enabled). + None, +} impl FpuContext { /// Creates a new FPU context. pub fn new() -> Self { - Self + if has_extensions(IsaExtensions::Q) { + Self::Q(Box::default()) + } else if has_extensions(IsaExtensions::D) { + Self::D(Box::default()) + } else if has_extensions(IsaExtensions::F) { + Self::F(Box::default()) + } else { + Self::None + } } - /// Saves CPU's current FPU context to this instance, if needed. - pub fn save(&mut self) {} + /// Saves CPU's current FPU context to this instance. + pub fn save(&mut self) { + match self { + Self::F(ctx) => ctx.save(), + Self::D(ctx) => ctx.save(), + Self::Q(ctx) => ctx.save(), + Self::None => {} + } + } - /// Loads CPU's FPU context from this instance, if needed. - pub fn load(&mut self) {} + /// Loads CPU's FPU context from this instance. + pub fn load(&self) { + match self { + Self::F(ctx) => ctx.load(), + Self::D(ctx) => ctx.load(), + Self::Q(ctx) => ctx.load(), + Self::None => {} + } + } /// Returns the FPU context as a byte slice. pub fn as_bytes(&self) -> &[u8] { - &[] + match self { + Self::F(ctx) => ctx.as_bytes(), + Self::D(ctx) => ctx.as_bytes(), + Self::Q(ctx) => ctx.as_bytes(), + Self::None => &[], + } } /// Returns the FPU context as a mutable byte slice. pub fn as_bytes_mut(&mut self) -> &mut [u8] { - &mut [] + match self { + Self::F(ctx) => ctx.as_bytes_mut(), + Self::D(ctx) => ctx.as_bytes_mut(), + Self::Q(ctx) => ctx.as_bytes_mut(), + Self::None => &mut [], + } } } + +impl Default for FpuContext { + fn default() -> Self { + Self::new() + } +} + +/// FPU context for F extension (32-bit floating point). +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, Pod)] +pub struct FFpuContext { + f: [u32; 32], + fcsr: u32, +} + +/// FPU context for D extension (64-bit floating point). +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, Pod)] +pub struct DFpuContext { + f: [u64; 32], + fcsr: u32, +} + +/// FPU context for Q extension (128-bit floating point). +#[repr(C)] +#[derive(Clone, Copy, Debug, Pod)] +pub struct QFpuContext { + f: [u64; 64], + fcsr: u32, +} + +// FIXME: Currently Rust generates array impls for every size up to 32 manually +// and there is ongoing work on refactoring with const generics. We can just +// derive the `Default` implementation once that is done. +// +// See . +impl Default for QFpuContext { + fn default() -> Self { + Self { + f: [0; 64], + fcsr: 0, + } + } +} + +impl FFpuContext { + fn save(&mut self) { + unsafe { save_fpu_context_f(self as *mut _) }; + } + + fn load(&self) { + unsafe { load_fpu_context_f(self as *const _) }; + } +} + +impl DFpuContext { + fn save(&mut self) { + unsafe { save_fpu_context_d(self as *mut _) }; + } + + fn load(&self) { + unsafe { load_fpu_context_d(self as *const _) }; + } +} + +impl QFpuContext { + fn save(&mut self) { + unsafe { save_fpu_context_q(self as *mut _) }; + } + + fn load(&self) { + unsafe { load_fpu_context_q(self as *const _) }; + } +} + +global_asm!(include_str!("fpu.S"), SSTATUS_FS_MASK = const SSTATUS_FS_MASK); + +unsafe extern "C" { + unsafe fn save_fpu_context_f(ctx: *mut FFpuContext); + unsafe fn load_fpu_context_f(ctx: *const FFpuContext); + unsafe fn save_fpu_context_d(ctx: *mut DFpuContext); + unsafe fn load_fpu_context_d(ctx: *const DFpuContext); + unsafe fn save_fpu_context_q(ctx: *mut QFpuContext); + unsafe fn load_fpu_context_q(ctx: *const QFpuContext); +} diff --git a/ostd/src/arch/riscv/cpu/fpu.S b/ostd/src/arch/riscv/cpu/fpu.S new file mode 100644 index 000000000..9fb0c93ae --- /dev/null +++ b/ostd/src/arch/riscv/cpu/fpu.S @@ -0,0 +1,228 @@ +/* SPDX-License-Identifier: MPL-2.0 */ + +.macro SAVE_FPU_CONTEXT base, store_insn, reg_size + # Enable FPU + li t1, {SSTATUS_FS_MASK} + csrs sstatus, t1 + + .if \reg_size == 16 + # Q extension - use manual encoding + FSQ 0, 0, 10 # a0 is register x10 + FSQ 1, 16, 10 + FSQ 2, 32, 10 + FSQ 3, 48, 10 + FSQ 4, 64, 10 + FSQ 5, 80, 10 + FSQ 6, 96, 10 + FSQ 7, 112, 10 + FSQ 8, 128, 10 + FSQ 9, 144, 10 + FSQ 10, 160, 10 + FSQ 11, 176, 10 + FSQ 12, 192, 10 + FSQ 13, 208, 10 + FSQ 14, 224, 10 + FSQ 15, 240, 10 + FSQ 16, 256, 10 + FSQ 17, 272, 10 + FSQ 18, 288, 10 + FSQ 19, 304, 10 + FSQ 20, 320, 10 + FSQ 21, 336, 10 + FSQ 22, 352, 10 + FSQ 23, 368, 10 + FSQ 24, 384, 10 + FSQ 25, 400, 10 + FSQ 26, 416, 10 + FSQ 27, 432, 10 + FSQ 28, 448, 10 + FSQ 29, 464, 10 + FSQ 30, 480, 10 + FSQ 31, 496, 10 + .else + # F/D extensions - use regular instructions + \store_insn f0, 0*\reg_size(\base) + \store_insn f1, 1*\reg_size(\base) + \store_insn f2, 2*\reg_size(\base) + \store_insn f3, 3*\reg_size(\base) + \store_insn f4, 4*\reg_size(\base) + \store_insn f5, 5*\reg_size(\base) + \store_insn f6, 6*\reg_size(\base) + \store_insn f7, 7*\reg_size(\base) + \store_insn f8, 8*\reg_size(\base) + \store_insn f9, 9*\reg_size(\base) + \store_insn f10, 10*\reg_size(\base) + \store_insn f11, 11*\reg_size(\base) + \store_insn f12, 12*\reg_size(\base) + \store_insn f13, 13*\reg_size(\base) + \store_insn f14, 14*\reg_size(\base) + \store_insn f15, 15*\reg_size(\base) + \store_insn f16, 16*\reg_size(\base) + \store_insn f17, 17*\reg_size(\base) + \store_insn f18, 18*\reg_size(\base) + \store_insn f19, 19*\reg_size(\base) + \store_insn f20, 20*\reg_size(\base) + \store_insn f21, 21*\reg_size(\base) + \store_insn f22, 22*\reg_size(\base) + \store_insn f23, 23*\reg_size(\base) + \store_insn f24, 24*\reg_size(\base) + \store_insn f25, 25*\reg_size(\base) + \store_insn f26, 26*\reg_size(\base) + \store_insn f27, 27*\reg_size(\base) + \store_insn f28, 28*\reg_size(\base) + \store_insn f29, 29*\reg_size(\base) + \store_insn f30, 30*\reg_size(\base) + \store_insn f31, 31*\reg_size(\base) + .endif + + # Save fcsr + frcsr t0 + sw t0, 32*\reg_size(\base) + + # Disable FPU + csrc sstatus, t1 + + ret +.endm + +.macro LOAD_FPU_CONTEXT base, load_insn, reg_size + # Enable FPU + li t1, {SSTATUS_FS_MASK} + csrs sstatus, t1 + + .if \reg_size == 16 + # Q extension - use manual encoding + FLQ 0, 0, 10 # a0 = register 10 + FLQ 1, 16, 10 + FLQ 2, 32, 10 + FLQ 3, 48, 10 + FLQ 4, 64, 10 + FLQ 5, 80, 10 + FLQ 6, 96, 10 + FLQ 7, 112, 10 + FLQ 8, 128, 10 + FLQ 9, 144, 10 + FLQ 10, 160, 10 + FLQ 11, 176, 10 + FLQ 12, 192, 10 + FLQ 13, 208, 10 + FLQ 14, 224, 10 + FLQ 15, 240, 10 + FLQ 16, 256, 10 + FLQ 17, 272, 10 + FLQ 18, 288, 10 + FLQ 19, 304, 10 + FLQ 20, 320, 10 + FLQ 21, 336, 10 + FLQ 22, 352, 10 + FLQ 23, 368, 10 + FLQ 24, 384, 10 + FLQ 25, 400, 10 + FLQ 26, 416, 10 + FLQ 27, 432, 10 + FLQ 28, 448, 10 + FLQ 29, 464, 10 + FLQ 30, 480, 10 + FLQ 31, 496, 10 + .else + # F/D extensions - use regular instructions + \load_insn f0, 0*\reg_size(\base) + \load_insn f1, 1*\reg_size(\base) + \load_insn f2, 2*\reg_size(\base) + \load_insn f3, 3*\reg_size(\base) + \load_insn f4, 4*\reg_size(\base) + \load_insn f5, 5*\reg_size(\base) + \load_insn f6, 6*\reg_size(\base) + \load_insn f7, 7*\reg_size(\base) + \load_insn f8, 8*\reg_size(\base) + \load_insn f9, 9*\reg_size(\base) + \load_insn f10, 10*\reg_size(\base) + \load_insn f11, 11*\reg_size(\base) + \load_insn f12, 12*\reg_size(\base) + \load_insn f13, 13*\reg_size(\base) + \load_insn f14, 14*\reg_size(\base) + \load_insn f15, 15*\reg_size(\base) + \load_insn f16, 16*\reg_size(\base) + \load_insn f17, 17*\reg_size(\base) + \load_insn f18, 18*\reg_size(\base) + \load_insn f19, 19*\reg_size(\base) + \load_insn f20, 20*\reg_size(\base) + \load_insn f21, 21*\reg_size(\base) + \load_insn f22, 22*\reg_size(\base) + \load_insn f23, 23*\reg_size(\base) + \load_insn f24, 24*\reg_size(\base) + \load_insn f25, 25*\reg_size(\base) + \load_insn f26, 26*\reg_size(\base) + \load_insn f27, 27*\reg_size(\base) + \load_insn f28, 28*\reg_size(\base) + \load_insn f29, 29*\reg_size(\base) + \load_insn f30, 30*\reg_size(\base) + \load_insn f31, 31*\reg_size(\base) + .endif + + # Load fcsr + lw t0, 32*\reg_size(\base) + fscsr t0 + + # Disable FPU + csrc sstatus, t1 + + ret +.endm + +# Currently LLVM assembler doesn't support Q extension, we manually +# encode the FSQ and FLQ instructions here. + +# FSQ: store freg to offset(basereg) +.macro FSQ freg, offset, basereg + .4byte (((\offset & 0xFE0) << 20) | (\freg << 20) | (\basereg << 15) | (0x4 << 12) | ((\offset & 0x1F) << 7) | 0x27) +.endm + +# FLQ: load freg from offset(basereg) +.macro FLQ freg, offset, basereg + .4byte (((\offset & 0xFFF) << 20) | (\basereg << 15) | (0x4 << 12) | (\freg << 7) | 0x07) +.endm + +.text + +.option push +.option arch, +f +.option arch, +d + +.global save_fpu_context_f +.type save_fpu_context_f, @function +save_fpu_context_f: + SAVE_FPU_CONTEXT a0, fsw, 4 +.size save_fpu_context_f, .-save_fpu_context_f + +.global load_fpu_context_f +.type load_fpu_context_f, @function +load_fpu_context_f: + LOAD_FPU_CONTEXT a0, flw, 4 +.size load_fpu_context_f, .-load_fpu_context_f + +.global save_fpu_context_d +.type save_fpu_context_d, @function +save_fpu_context_d: + SAVE_FPU_CONTEXT a0, fsd, 8 +.size save_fpu_context_d, .-save_fpu_context_d + +.global load_fpu_context_d +.type load_fpu_context_d, @function +load_fpu_context_d: + LOAD_FPU_CONTEXT a0, fld, 8 +.size load_fpu_context_d, .-load_fpu_context_d + +.global save_fpu_context_q +.type save_fpu_context_q, @function +save_fpu_context_q: + SAVE_FPU_CONTEXT a0, fsq, 16 +.size save_fpu_context_q, .-save_fpu_context_q + +.global load_fpu_context_q +.type load_fpu_context_q, @function +load_fpu_context_q: + LOAD_FPU_CONTEXT a0, flq, 16 +.size load_fpu_context_q, .-load_fpu_context_q + +.option pop diff --git a/ostd/src/arch/riscv/mod.rs b/ostd/src/arch/riscv/mod.rs index 09e9afe78..6e6cf0b76 100644 --- a/ostd/src/arch/riscv/mod.rs +++ b/ostd/src/arch/riscv/mod.rs @@ -83,9 +83,4 @@ pub fn read_random() -> Option { pub(crate) fn enable_cpu_features() { cpu::extension::init(); - unsafe { - // We adopt a lazy approach to enable the floating-point unit; it's not - // enabled before the first FPU trap. - riscv::register::sstatus::set_fs(riscv::register::sstatus::FS::Off); - } } diff --git a/ostd/src/arch/riscv/trap/mod.rs b/ostd/src/arch/riscv/trap/mod.rs index be7d9f5ee..94734c99e 100644 --- a/ostd/src/arch/riscv/trap/mod.rs +++ b/ostd/src/arch/riscv/trap/mod.rs @@ -9,8 +9,8 @@ use core::sync::atomic::Ordering; use riscv::register::scause::{Interrupt, Trap}; use spin::Once; -pub(super) use trap::RawUserContext; pub use trap::TrapFrame; +pub(super) use trap::{RawUserContext, SSTATUS_FS_MASK}; use crate::{ arch::{ diff --git a/ostd/src/arch/riscv/trap/trap.S b/ostd/src/arch/riscv/trap/trap.S index f544eec1f..110d3c426 100644 --- a/ostd/src/arch/riscv/trap/trap.S +++ b/ostd/src/arch/riscv/trap/trap.S @@ -68,22 +68,19 @@ trap_from_user: STORE_SP x31, 31 # save sp, sstatus, sepc - csrrw t0, sscratch, x0 # sscratch = 0 (kernel) - csrr t1, sstatus + csrrw t0, sscratch, x0 # sscratch = 0 (kernel) + li t3, {SSTATUS_FS_MASK} # disable FPU to prevent unexpected usage of floating point in kernel space + csrrc t1, sstatus, t3 csrr t2, sepc - STORE_SP t0, 2 # save sp - STORE_SP t1, 32 # save sstatus - STORE_SP t2, 33 # save sepc + STORE_SP t0, 2 # save sp + STORE_SP t1, 32 # save sstatus + STORE_SP t2, 33 # save sepc - li t0, 3 << 13 - or t1, t1, t0 # sstatus.FS = Dirty (3) - csrw sstatus, t1 - - andi t1, t1, 1 << 8 # sstatus.SPP == 1 + andi t1, t1, 1 << 8 # sstatus.SPP == 1 beqz t1, end_trap_from_user end_trap_from_kernel: - mv a0, sp # first arg is TrapFrame - lla ra, trap_return # set return address + mv a0, sp # first arg is TrapFrame + lla ra, trap_return # set return address .extern trap_handler j trap_handler diff --git a/ostd/src/arch/riscv/trap/trap.rs b/ostd/src/arch/riscv/trap/trap.rs index 05dfbc695..cb327d147 100644 --- a/ostd/src/arch/riscv/trap/trap.rs +++ b/ostd/src/arch/riscv/trap/trap.rs @@ -16,7 +16,10 @@ use core::arch::{asm, global_asm}; -use crate::arch::cpu::context::GeneralRegs; +use crate::arch::cpu::{ + context::GeneralRegs, + extension::{has_extensions, IsaExtensions}, +}; #[cfg(target_arch = "riscv32")] global_asm!( @@ -43,7 +46,11 @@ global_asm!( " ); -global_asm!(include_str!("trap.S")); +/// FPU status bits. +/// Reference: . +pub(in crate::arch) const SSTATUS_FS_MASK: usize = 0b11 << 13; + +global_asm!(include_str!("trap.S"), SSTATUS_FS_MASK = const SSTATUS_FS_MASK); /// Initialize interrupt handling for the current HART. /// @@ -76,7 +83,7 @@ pub unsafe fn init() { /// println!("TRAP! tf: {:#x?}", tf); /// } /// ``` -#[derive(Debug, Default, Clone, Copy)] +#[derive(Debug, Clone, Copy)] #[repr(C)] pub struct TrapFrame { /// General registers @@ -88,7 +95,7 @@ pub struct TrapFrame { } /// Saved registers on a trap. -#[derive(Debug, Default, Clone, Copy)] +#[derive(Debug, Clone, Copy)] #[repr(C)] pub(in crate::arch) struct RawUserContext { /// General registers @@ -99,6 +106,26 @@ pub(in crate::arch) struct RawUserContext { pub(in crate::arch) sepc: usize, } +impl Default for RawUserContext { + fn default() -> Self { + let sstatus = if has_extensions(IsaExtensions::F) + || has_extensions(IsaExtensions::D) + || has_extensions(IsaExtensions::Q) + { + const SSTATUS_FS_INITIAL: usize = 0b01 << 13; + SSTATUS_FS_INITIAL + } else { + 0 + }; + + Self { + general: GeneralRegs::default(), + sstatus, + sepc: 0, + } + } +} + impl RawUserContext { /// Goes to user space with the context, and comes back when a trap occurs. ///