From c2562d01382ee69094f9706ebd6d3f0cc1cd9484 Mon Sep 17 00:00:00 2001
From: Ralf Jung <post@ralfj.de>
Date: Thu, 18 Apr 2024 10:42:46 +0200
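Subject: [PATCH] move allocator shim logic into its own file

`min_align`, `emulate_allocator`, `malloc`, `free`, `realloc` and
`check_alloc_request` move from `src/shims/foreign_items.rs` to the new
`src/shims/alloc.rs`. The directory-entry buffer in `src/shims/unix/fs.rs`
no longer goes through `malloc`/`free`: it is stored as an `Option` and is
allocated and freed directly via `allocate_ptr`/`deallocate_ptr`, using the
alignment of the dirent layout.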
---
 src/shims/alloc.rs                 | 152 +++++++++++++++++++++++++++++
 src/shims/foreign_items.rs         | 151 +---------------------------
 src/shims/mod.rs                   |   1 +
 src/shims/unix/foreign_items.rs    |   1 +
 src/shims/unix/fs.rs               |  29 +++---
 src/shims/windows/foreign_items.rs |   1 +
 6 files changed, 178 insertions(+), 157 deletions(-)
 create mode 100644 src/shims/alloc.rs

diff --git a/src/shims/alloc.rs b/src/shims/alloc.rs
new file mode 100644
index 0000000000..b5ae06c2a4
--- /dev/null
+++ b/src/shims/alloc.rs
@@ -0,0 +1,152 @@
+use std::iter;
+
+use rustc_ast::expand::allocator::AllocatorKind;
+use rustc_target::abi::{Align, Size};
+
+use crate::*;
+use shims::foreign_items::EmulateForeignItemResult;
+
+/// Checks the basic requirements of an allocation request: `size` must be non-zero and
+/// `align` must be a power of two. Violating either one is reported as Undefined Behavior.
+pub(super) fn check_alloc_request<'tcx>(size: u64, align: u64) -> InterpResult<'tcx> {
+    if size == 0 {
+        throw_ub_format!("creating allocation with size 0");
+    }
+    if !align.is_power_of_two() {
+        throw_ub_format!("creating allocation with non-power-of-two alignment {}", align);
+    }
+    Ok(())
+}
+
+impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
+pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
+    /// Returns the minimum alignment on the target architecture for allocations of the given size and kind.
+    fn min_align(&self, size: u64, kind: MiriMemoryKind) -> Align {
+        let this = self.eval_context_ref();
+        // List taken from `library/std/src/sys/pal/common/alloc.rs`.
+        // This list should be kept in sync with the one from libstd.
+        let min_align = match this.tcx.sess.target.arch.as_ref() {
+            "x86" | "arm" | "mips" | "mips32r6" | "powerpc" | "powerpc64" | "wasm32" => 8,
+            "x86_64" | "aarch64" | "mips64" | "mips64r6" | "s390x" | "sparc64" | "loongarch64" =>
+                16,
+            arch => bug!("unsupported target architecture for malloc: `{}`", arch),
+        };
+        // The Windows heap always aligns to `min_align`, even for small allocations.
+        // Source: <https://support.microsoft.com/en-us/help/286470/how-to-use-pageheap-exe-in-windows-xp-windows-2000-and-windows-server>
+        // But jemalloc does not, so for the C heap we only align if the allocation is sufficiently big.
+        if kind == MiriMemoryKind::WinHeap || size >= min_align {
+            return Align::from_bytes(min_align).unwrap();
+        }
+        // We have `size < min_align`. Use the largest power of two that is `<= size` as the alignment.
+        fn prev_power_of_two(x: u64) -> u64 {
+            let next_pow2 = x.next_power_of_two();
+            if next_pow2 == x {
+                // x *is* a power of two, just use that.
+                x
+            } else {
+                // x is between two powers, so next = 2*prev. (For `x == 0` this yields 0.)
+                next_pow2 / 2
+            }
+        }
+        Align::from_bytes(prev_power_of_two(size)).unwrap()
+    }
+
+    /// Emulates calling the internal `__rust_*` allocator functions: `default` runs iff the default allocator is used.
+    fn emulate_allocator(
+        &mut self,
+        default: impl FnOnce(&mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx>,
+    ) -> InterpResult<'tcx, EmulateForeignItemResult> {
+        let this = self.eval_context_mut();
+
+        let Some(allocator_kind) = this.tcx.allocator_kind(()) else {
+            // In real code, this symbol does not exist without an allocator.
+            return Ok(EmulateForeignItemResult::NotSupported);
+        };
+
+        match allocator_kind {
+            AllocatorKind::Global => {
+                // When `#[global_allocator]` is used, `__rust_*` is defined by the macro expansion
+                // of this attribute. As such we have to call an exported Rust function,
+                // and not execute any Miri shim. Somewhat unintuitively, this is achieved
+                // by returning `NotSupported`, which triggers the `lookup_exported_symbol`
+                // fallback case in `emulate_foreign_item`.
+                return Ok(EmulateForeignItemResult::NotSupported);
+            }
+            AllocatorKind::Default => {
+                default(this)?;
+                Ok(EmulateForeignItemResult::NeedsJumping)
+            }
+        }
+    }
+
+    fn malloc(
+        &mut self,
+        size: u64,
+        zero_init: bool,
+        kind: MiriMemoryKind,
+    ) -> InterpResult<'tcx, Pointer<Option<Provenance>>> {
+        let this = self.eval_context_mut();
+        if size == 0 {
+            Ok(Pointer::null())
+        } else {
+            let align = this.min_align(size, kind);
+            let ptr = this.allocate_ptr(Size::from_bytes(size), align, kind.into())?;
+            if zero_init {
+                // We just allocated this: `size` fits into our address space and the write is in-bounds, hence the `unwrap`s.
+                this.write_bytes_ptr(
+                    ptr.into(),
+                    iter::repeat(0u8).take(usize::try_from(size).unwrap()),
+                )
+                .unwrap();
+            }
+            Ok(ptr.into())
+        }
+    }
+
+    fn free(
+        &mut self,
+        ptr: Pointer<Option<Provenance>>,
+        kind: MiriMemoryKind,
+    ) -> InterpResult<'tcx> {
+        let this = self.eval_context_mut();
+        if !this.ptr_is_null(ptr)? {
+            this.deallocate_ptr(ptr, None, kind.into())?;
+        }
+        Ok(())
+    }
+
+    fn realloc(
+        &mut self,
+        old_ptr: Pointer<Option<Provenance>>,
+        new_size: u64,
+        kind: MiriMemoryKind,
+    ) -> InterpResult<'tcx, Pointer<Option<Provenance>>> {
+        let this = self.eval_context_mut();
+        let new_align = this.min_align(new_size, kind);
+        if this.ptr_is_null(old_ptr)? {
+            // `realloc` of a null pointer: here we must behave like `malloc`.
+            if new_size == 0 {
+                Ok(Pointer::null())
+            } else {
+                let new_ptr =
+                    this.allocate_ptr(Size::from_bytes(new_size), new_align, kind.into())?;
+                Ok(new_ptr.into())
+            }
+        } else {
+            if new_size == 0 {
+                // C, in their infinite wisdom, made this UB.
+                // <https://www.open-std.org/jtc1/sc22/wg14/www/docs/n2464.pdf>
+                throw_ub_format!("`realloc` with a size of zero");
+            } else {
+                let new_ptr = this.reallocate_ptr(
+                    old_ptr,
+                    None,
+                    Size::from_bytes(new_size),
+                    new_align,
+                    kind.into(),
+                )?;
+                Ok(new_ptr.into())
+            }
+        }
+    }
+}
diff --git a/src/shims/foreign_items.rs b/src/shims/foreign_items.rs
index 6a6ad33e52..636361148a 100644
--- a/src/shims/foreign_items.rs
+++ b/src/shims/foreign_items.rs
@@ -1,7 +1,7 @@
 use std::{collections::hash_map::Entry, io::Write, iter, path::Path};
 
 use rustc_apfloat::Float;
-use rustc_ast::expand::allocator::{alloc_error_handler_name, AllocatorKind};
+use rustc_ast::expand::allocator::alloc_error_handler_name;
 use rustc_hir::{def::DefKind, def_id::CrateNum};
 use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
 use rustc_middle::mir;
@@ -12,6 +12,7 @@ use rustc_target::{
     spec::abi::Abi,
 };
 
+use super::alloc::{check_alloc_request, EvalContextExt as _};
 use super::backtrace::EvalContextExt as _;
 use crate::*;
 use helpers::{ToHost, ToSoft};
@@ -232,140 +233,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
             Some(instance) => Ok(Some((this.load_mir(instance.def, None)?, instance))),
         }
     }
-
-    fn malloc(
-        &mut self,
-        size: u64,
-        zero_init: bool,
-        kind: MiriMemoryKind,
-    ) -> InterpResult<'tcx, Pointer<Option<Provenance>>> {
-        let this = self.eval_context_mut();
-        if size == 0 {
-            Ok(Pointer::null())
-        } else {
-            let align = this.min_align(size, kind);
-            let ptr = this.allocate_ptr(Size::from_bytes(size), align, kind.into())?;
-            if zero_init {
-                // We just allocated this, the access is definitely in-bounds and fits into our address space.
-                this.write_bytes_ptr(
-                    ptr.into(),
-                    iter::repeat(0u8).take(usize::try_from(size).unwrap()),
-                )
-                .unwrap();
-            }
-            Ok(ptr.into())
-        }
-    }
-
-    fn free(
-        &mut self,
-        ptr: Pointer<Option<Provenance>>,
-        kind: MiriMemoryKind,
-    ) -> InterpResult<'tcx> {
-        let this = self.eval_context_mut();
-        if !this.ptr_is_null(ptr)? {
-            this.deallocate_ptr(ptr, None, kind.into())?;
-        }
-        Ok(())
-    }
-
-    fn realloc(
-        &mut self,
-        old_ptr: Pointer<Option<Provenance>>,
-        new_size: u64,
-        kind: MiriMemoryKind,
-    ) -> InterpResult<'tcx, Pointer<Option<Provenance>>> {
-        let this = self.eval_context_mut();
-        let new_align = this.min_align(new_size, kind);
-        if this.ptr_is_null(old_ptr)? {
-            // Here we must behave like `malloc`.
-            if new_size == 0 {
-                Ok(Pointer::null())
-            } else {
-                let new_ptr =
-                    this.allocate_ptr(Size::from_bytes(new_size), new_align, kind.into())?;
-                Ok(new_ptr.into())
-            }
-        } else {
-            if new_size == 0 {
-                // C, in their infinite wisdom, made this UB.
-                // <https://www.open-std.org/jtc1/sc22/wg14/www/docs/n2464.pdf>
-                throw_ub_format!("`realloc` with a size of zero");
-            } else {
-                let new_ptr = this.reallocate_ptr(
-                    old_ptr,
-                    None,
-                    Size::from_bytes(new_size),
-                    new_align,
-                    kind.into(),
-                )?;
-                Ok(new_ptr.into())
-            }
-        }
-    }
 }
 
 impl<'mir, 'tcx: 'mir> EvalContextExtPriv<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
 trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
-    /// Returns the minimum alignment for the target architecture for allocations of the given size.
-    fn min_align(&self, size: u64, kind: MiriMemoryKind) -> Align {
-        let this = self.eval_context_ref();
-        // List taken from `library/std/src/sys/pal/common/alloc.rs`.
-        // This list should be kept in sync with the one from libstd.
-        let min_align = match this.tcx.sess.target.arch.as_ref() {
-            "x86" | "arm" | "mips" | "mips32r6" | "powerpc" | "powerpc64" | "wasm32" => 8,
-            "x86_64" | "aarch64" | "mips64" | "mips64r6" | "s390x" | "sparc64" | "loongarch64" =>
-                16,
-            arch => bug!("unsupported target architecture for malloc: `{}`", arch),
-        };
-        // Windows always aligns, even small allocations.
-        // Source: <https://support.microsoft.com/en-us/help/286470/how-to-use-pageheap-exe-in-windows-xp-windows-2000-and-windows-server>
-        // But jemalloc does not, so for the C heap we only align if the allocation is sufficiently big.
-        if kind == MiriMemoryKind::WinHeap || size >= min_align {
-            return Align::from_bytes(min_align).unwrap();
-        }
-        // We have `size < min_align`. Round `size` *down* to the next power of two and use that.
-        fn prev_power_of_two(x: u64) -> u64 {
-            let next_pow2 = x.next_power_of_two();
-            if next_pow2 == x {
-                // x *is* a power of two, just use that.
-                x
-            } else {
-                // x is between two powers, so next = 2*prev.
-                next_pow2 / 2
-            }
-        }
-        Align::from_bytes(prev_power_of_two(size)).unwrap()
-    }
-
-    /// Emulates calling the internal __rust_* allocator functions
-    fn emulate_allocator(
-        &mut self,
-        default: impl FnOnce(&mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx>,
-    ) -> InterpResult<'tcx, EmulateForeignItemResult> {
-        let this = self.eval_context_mut();
-
-        let Some(allocator_kind) = this.tcx.allocator_kind(()) else {
-            // in real code, this symbol does not exist without an allocator
-            return Ok(EmulateForeignItemResult::NotSupported);
-        };
-
-        match allocator_kind {
-            AllocatorKind::Global => {
-                // When `#[global_allocator]` is used, `__rust_*` is defined by the macro expansion
-                // of this attribute. As such we have to call an exported Rust function,
-                // and not execute any Miri shim. Somewhat unintuitively doing so is done
-                // by returning `NotSupported`, which triggers the `lookup_exported_symbol`
-                // fallback case in `emulate_foreign_item`.
-                return Ok(EmulateForeignItemResult::NotSupported);
-            }
-            AllocatorKind::Default => {
-                default(this)?;
-                Ok(EmulateForeignItemResult::NeedsJumping)
-            }
-        }
-    }
-
     fn emulate_foreign_item_inner(
         &mut self,
         link_name: Symbol,
@@ -612,7 +483,7 @@ trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     let size = this.read_target_usize(size)?;
                     let align = this.read_target_usize(align)?;
 
-                    Self::check_alloc_request(size, align)?;
+                    check_alloc_request(size, align)?;
 
                     let memory_kind = match link_name.as_str() {
                         "__rust_alloc" => MiriMemoryKind::Rust,
@@ -646,7 +517,7 @@ trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     let size = this.read_target_usize(size)?;
                     let align = this.read_target_usize(align)?;
 
-                    Self::check_alloc_request(size, align)?;
+                    check_alloc_request(size, align)?;
 
                     let ptr = this.allocate_ptr(
                         Size::from_bytes(size),
@@ -710,7 +581,7 @@ trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     let new_size = this.read_target_usize(new_size)?;
                     // No need to check old_size; we anyway check that they match the allocation.
 
-                    Self::check_alloc_request(new_size, align)?;
+                    check_alloc_request(new_size, align)?;
 
                     let align = Align::from_bytes(align).unwrap();
                     let new_ptr = this.reallocate_ptr(
@@ -1102,16 +973,4 @@ trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         // i.e., if we actually emulated the function with one of the shims.
         Ok(EmulateForeignItemResult::NeedsJumping)
     }
-
-    /// Check some basic requirements for this allocation request:
-    /// non-zero size, power-of-two alignment.
-    fn check_alloc_request(size: u64, align: u64) -> InterpResult<'tcx> {
-        if size == 0 {
-            throw_ub_format!("creating allocation with size 0");
-        }
-        if !align.is_power_of_two() {
-            throw_ub_format!("creating allocation with non-power-of-two alignment {}", align);
-        }
-        Ok(())
-    }
 }
diff --git a/src/shims/mod.rs b/src/shims/mod.rs
index ea6120f757..85c9a202f7 100644
--- a/src/shims/mod.rs
+++ b/src/shims/mod.rs
@@ -1,5 +1,6 @@
 #![warn(clippy::arithmetic_side_effects)]
 
+mod alloc;
 mod backtrace;
 #[cfg(target_os = "linux")]
 pub mod ffi_support;
diff --git a/src/shims/unix/foreign_items.rs b/src/shims/unix/foreign_items.rs
index 3a56aa9138..c72d3bb3df 100644
--- a/src/shims/unix/foreign_items.rs
+++ b/src/shims/unix/foreign_items.rs
@@ -6,6 +6,7 @@ use rustc_span::Symbol;
 use rustc_target::abi::{Align, Size};
 use rustc_target::spec::abi::Abi;
 
+use crate::shims::alloc::EvalContextExt as _;
 use crate::shims::unix::*;
 use crate::*;
 use shims::foreign_items::EmulateForeignItemResult;
diff --git a/src/shims/unix/fs.rs b/src/shims/unix/fs.rs
index 31076fdfaf..ebf9f43c19 100644
--- a/src/shims/unix/fs.rs
+++ b/src/shims/unix/fs.rs
@@ -196,13 +196,12 @@ struct OpenDir {
     read_dir: ReadDir,
     /// The most recent entry returned by readdir().
     /// Will be freed by the next call.
-    entry: Pointer<Option<Provenance>>,
+    entry: Option<Pointer<Option<Provenance>>>,
 }
 
 impl OpenDir {
     fn new(read_dir: ReadDir) -> Self {
-        // We rely on `free` being a NOP on null pointers.
-        Self { read_dir, entry: Pointer::null() }
+        Self { read_dir, entry: None }
     }
 }
 
@@ -924,8 +923,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 let d_name_offset = dirent64_layout.fields.offset(4 /* d_name */).bytes();
                 let size = d_name_offset.checked_add(name_len).unwrap();
 
-                let entry =
-                    this.malloc(size, /*zero_init:*/ false, MiriMemoryKind::Runtime)?;
+                let entry = this.allocate_ptr(
+                    Size::from_bytes(size),
+                    dirent64_layout.align.abi,
+                    MiriMemoryKind::Runtime.into(),
+                )?;
+                let entry: Pointer<Option<Provenance>> = entry.into();
 
                 // If the host is a Unix system, fill in the inode number with its real value.
                 // If not, use 0 as a fallback value.
@@ -949,23 +952,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 let name_ptr = entry.offset(Size::from_bytes(d_name_offset), this)?;
                 this.write_bytes_ptr(name_ptr, name_bytes.iter().copied())?;
 
-                entry
+                Some(entry)
             }
             None => {
                 // end of stream: return NULL
-                Pointer::null()
+                None
             }
             Some(Err(e)) => {
                 this.set_last_error_from_io_error(e.kind())?;
-                Pointer::null()
+                None
             }
         };
 
         let open_dir = this.machine.dirs.streams.get_mut(&dirp).unwrap();
         let old_entry = std::mem::replace(&mut open_dir.entry, entry);
-        this.free(old_entry, MiriMemoryKind::Runtime)?;
+        if let Some(old_entry) = old_entry {
+            this.deallocate_ptr(old_entry, None, MiriMemoryKind::Runtime.into())?;
+        }
 
-        Ok(Scalar::from_maybe_pointer(entry, this))
+        Ok(Scalar::from_maybe_pointer(entry.unwrap_or_else(Pointer::null), this))
     }
 
     fn macos_fbsd_readdir_r(
@@ -1106,7 +1111,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         }
 
         if let Some(open_dir) = this.machine.dirs.streams.remove(&dirp) {
-            this.free(open_dir.entry, MiriMemoryKind::Runtime)?;
+            if let Some(entry) = open_dir.entry {
+                this.deallocate_ptr(entry, None, MiriMemoryKind::Runtime.into())?;
+            }
             drop(open_dir);
             Ok(0)
         } else {
diff --git a/src/shims/windows/foreign_items.rs b/src/shims/windows/foreign_items.rs
index de80df3c80..ec4c610148 100644
--- a/src/shims/windows/foreign_items.rs
+++ b/src/shims/windows/foreign_items.rs
@@ -8,6 +8,7 @@ use rustc_span::Symbol;
 use rustc_target::abi::Size;
 use rustc_target::spec::abi::Abi;
 
+use crate::shims::alloc::EvalContextExt as _;
 use crate::shims::os_str::bytes_to_os_str;
 use crate::*;
 use shims::foreign_items::EmulateForeignItemResult;