From af2104078f25865a02e18712986ee4b988d7affb Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 21 Sep 2024 13:18:42 +0200
Subject: [PATCH] Metal commands refactoring (#2489)

* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
---
 candle-core/src/metal_backend/device.rs | 195 +++++++++++++-----------
 candle-core/src/metal_backend/mod.rs    |  17 +--
 candle-metal-kernels/src/utils.rs       |  33 +++-
 3 files changed, 141 insertions(+), 104 deletions(-)

diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index 3deb465b85..29b8995bc9 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
+use std::sync::{Arc, Mutex, RwLock};

 use super::MetalError;

@@ -22,19 +22,9 @@ impl DeviceId {
 }

 type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
-type AllocatedBuffers = Arc<RwLock<BufferMap>>;
-
-#[derive(Clone)]
-pub struct MetalDevice {
-    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
-    /// the device itself.
-    pub(crate) id: DeviceId,
-
-    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
-    pub(crate) device: metal::Device,
-
+pub(crate) struct Commands {
     /// Single command queue for the entire device.
-    pub(crate) command_queue: CommandQueue,
+    command_queue: CommandQueue,
     /// One command buffer at a time.
     /// The scheduler works by allowing multiple
     /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
@@ -44,16 +34,73 @@ pub struct MetalDevice {
     /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
     /// for their START time, but there's no guarantee that command buffer1 will finish before
     /// command buffer2 starts (or there are metal bugs there)
-    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
+    command_buffer: CommandBuffer,
     /// Keeps track of the current amount of compute command encoders on the current
     /// command buffer
     /// Arc, RwLock because of the interior mutability.
-    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
+    command_buffer_index: usize,
     /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
-    pub(crate) compute_per_buffer: usize,
-    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`]
-    pub(crate) kernels: Arc<Kernels>,
+    compute_per_buffer: usize,
+}
+
+impl Commands {
+    pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
+        let command_buffer = command_queue.new_command_buffer().to_owned();
+        command_buffer.enqueue();
+        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+            Ok(val) => val.parse()?,
+            _ => 50,
+        };
+        Ok(Self {
+            command_queue,
+            command_buffer,
+            command_buffer_index: 0,
+            compute_per_buffer,
+        })
+    }
+
+    pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
+        let mut command_buffer = self.command_buffer.to_owned();
+        let mut flushed = false;
+        if self.command_buffer_index > self.compute_per_buffer {
+            self.command_buffer.commit();
+            command_buffer = self.command_queue.new_command_buffer().to_owned();
+            self.command_buffer = command_buffer.clone();
+            self.command_buffer_index = 0;
+            flushed = true;
+        }
+        self.command_buffer_index += 1;
+        Ok((flushed, command_buffer))
+    }
+
+    pub fn wait_until_completed(&mut self) -> Result<()> {
+        match self.command_buffer.status() {
+            metal::MTLCommandBufferStatus::Committed
+            | metal::MTLCommandBufferStatus::Scheduled
+            | metal::MTLCommandBufferStatus::Completed => {
+                panic!("Already committed");
+            }
+            _ => {}
+        }
+        self.command_buffer.commit();
+        self.command_buffer.wait_until_completed();
+        self.command_buffer = self.command_queue.new_command_buffer().to_owned();
+
+        Ok(())
+    }
+}
+
+#[derive(Clone)]
+pub struct MetalDevice {
+    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
+    /// the device itself.
+    pub(crate) id: DeviceId,
+
+    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
+    pub(crate) device: metal::Device,
+
+    pub(crate) commands: Arc<RwLock<Commands>>,
+
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
     /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@@ -67,7 +114,11 @@ pub struct MetalDevice {
     ///
     /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
     /// (strong_count = 1).
-    pub(crate) buffers: AllocatedBuffers,
+    pub(crate) buffers: Arc<RwLock<BufferMap>>,
+
+    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
+    /// Heavily used by [`candle_metal_kernels`]
+    pub(crate) kernels: Arc<Kernels>,
     /// Seed for random number generation.
     pub(crate) seed: Arc<Mutex<Buffer>>,
     /// Whether to use the MLX matmul kernels instead of the MFA ones.
@@ -101,44 +152,31 @@ impl MetalDevice {
         &self.device
     }

-    pub fn command_queue(&self) -> &CommandQueue {
-        &self.command_queue
+    fn drop_unused_buffers(&self) -> Result<()> {
+        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+        for subbuffers in buffers.values_mut() {
+            let newbuffers = subbuffers
+                .iter()
+                .filter(|s| Arc::strong_count(*s) > 1)
+                .map(Arc::clone)
+                .collect();
+            *subbuffers = newbuffers;
+        }
+        Ok(())
     }

     pub fn command_buffer(&self) -> Result<CommandBuffer> {
-        let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
-        let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self
-            .command_buffer_index
-            .write()
-            .map_err(MetalError::from)?;
-        if *index > self.compute_per_buffer {
-            command_buffer.commit();
-            command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffer_lock = command_buffer.clone();
-            *index = 0;
-
-            self.drop_unused_buffers()?;
+        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        let (flushed, command_buffer) = commands.command_buffer()?;
+        if flushed {
+            self.drop_unused_buffers()?
         }
-        *index += 1;
         Ok(command_buffer)
     }

     pub fn wait_until_completed(&self) -> Result<()> {
-        let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
-        match command_buffer.status() {
-            metal::MTLCommandBufferStatus::Committed
-            | metal::MTLCommandBufferStatus::Scheduled
-            | metal::MTLCommandBufferStatus::Completed => {
-                panic!("Already committed");
-            }
-            _ => {}
-        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
-        Ok(())
+        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        commands.wait_until_completed()
     }

     pub fn kernels(&self) -> &Kernels {
@@ -186,6 +224,7 @@ impl MetalDevice {
             MTLResourceOptions::StorageModeManaged,
         );
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+
         let subbuffers = buffers
             .entry((size, MTLResourceOptions::StorageModeManaged))
             .or_insert(vec![]);
@@ -216,40 +255,6 @@ impl MetalDevice {
         Ok(buffer)
     }

-    fn find_available_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        buffers: &RwLockWriteGuard<BufferMap>,
-    ) -> Option<Arc<Buffer>> {
-        let mut best_buffer: Option<&Arc<Buffer>> = None;
-        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-                for sub in subbuffers {
-                    if Arc::strong_count(sub) == 1 {
-                        best_buffer = Some(sub);
-                        best_buffer_size = *buffer_size;
-                    }
-                }
-            }
-        }
-        best_buffer.cloned()
-    }
-
-    fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
-        Ok(())
-    }
-
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
         size: NSUInteger,
         option: MTLResourceOptions,
         _name: &str,
     ) -> Result<Arc<Buffer>> {
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
+        if let Some(b) = find_available_buffer(size, option, &buffers) {
             // Cloning also ensures we increment the strong count
             return Ok(b.clone());
         }
@@ -297,3 +302,23 @@ impl MetalDevice {
 fn buf_size(size: NSUInteger) -> NSUInteger {
     size.saturating_sub(1).next_power_of_two() as NSUInteger
 }
+
+fn find_available_buffer(
+    size: NSUInteger,
+    option: MTLResourceOptions,
+    buffers: &BufferMap,
+) -> Option<Arc<Buffer>> {
+    let mut best_buffer: Option<&Arc<Buffer>> = None;
+    let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
+    for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
+        if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
+            for sub in subbuffers {
+                if Arc::strong_count(sub) == 1 {
+                    best_buffer = Some(sub);
+                    best_buffer_size = *buffer_size;
+                }
+            }
+        }
+    }
+    best_buffer.cloned()
+}
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 9c980db804..69edd2d1d2 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1864,33 +1864,22 @@ impl BackendDevice for MetalDevice {
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
         let command_queue = device.new_command_queue();
-        let command_buffer = command_queue.new_command_buffer().to_owned();
-        command_buffer.enqueue();
-        let command_buffer = Arc::new(RwLock::new(command_buffer));
-        let command_buffer_index = Arc::new(RwLock::new(0));
         let kernels = Arc::new(Kernels::new());
-        let buffers = Arc::new(RwLock::new(HashMap::new()));
         let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
             Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
             Ok(_) => true,
         };
-        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
-            Ok(val) => val.parse()?,
-            _ => 50,
-        };
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
             4,
             MTLResourceOptions::StorageModeManaged,
         )));
+        let commands = device::Commands::new(command_queue)?;
         Ok(Self {
             id: DeviceId::new(),
             device,
-            command_queue,
-            command_buffer,
-            command_buffer_index,
-            compute_per_buffer,
-            buffers,
+            commands: Arc::new(RwLock::new(commands)),
+            buffers: Arc::new(RwLock::new(HashMap::new())),
             kernels,
             seed,
             use_mlx_mm,
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index 2ddd610b02..d2cc09f495 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -168,17 +168,22 @@ pub trait EncoderProvider {
     fn encoder(&self) -> Self::Encoder<'_>;
 }

-pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
+pub struct WrappedEncoder<'a> {
+    inner: &'a ComputeCommandEncoderRef,
+    end_encoding_on_drop: bool,
+}

 impl<'a> Drop for WrappedEncoder<'a> {
     fn drop(&mut self) {
-        self.0.end_encoding()
+        if self.end_encoding_on_drop {
+            self.inner.end_encoding()
+        }
     }
 }

 impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
     fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
-        self.0
+        self.inner
     }
 }

@@ -187,7 +192,10 @@ impl EncoderProvider for &metal::CommandBuffer {
     type Encoder<'a> = WrappedEncoder<'a>
     where
         Self: 'a;
     fn encoder(&self) -> Self::Encoder<'_> {
-        WrappedEncoder(self.new_compute_command_encoder())
+        WrappedEncoder {
+            inner: self.new_compute_command_encoder(),
+            end_encoding_on_drop: true,
+        }
     }
 }

@@ -196,6 +204,21 @@ impl EncoderProvider for &metal::CommandBufferRef {
     type Encoder<'a> = WrappedEncoder<'a>
     where
         Self: 'a;
     fn encoder(&self) -> Self::Encoder<'_> {
-        WrappedEncoder(self.new_compute_command_encoder())
+        WrappedEncoder {
+            inner: self.new_compute_command_encoder(),
+            end_encoding_on_drop: true,
+        }
+    }
+}
+
+impl EncoderProvider for &ComputeCommandEncoderRef {
+    type Encoder<'a> = WrappedEncoder<'a>
+    where
+        Self: 'a;
+    fn encoder(&self) -> Self::Encoder<'_> {
+        WrappedEncoder {
+            inner: self,
+            end_encoding_on_drop: false,
+        }
     }
 }
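
A standalone model of the batching policy that moves into `Commands::command_buffer` may help when reading the device.rs hunks above: callers keep receiving the current command buffer until more than `compute_per_buffer` encoders (read from the `CANDLE_METAL_COMPUTE_PER_BUFFER` environment variable, defaulting to 50) have been created on it, at which point the buffer is committed, replaced, and `flushed` is reported so `MetalDevice::command_buffer` can sweep unused buffers. The sketch below is illustrative only: `FlushPolicy` and `generation` are hypothetical stand-ins, with a plain `u32` replacing the real `CommandBuffer`.

// Simplified model of the flush policy in `Commands::command_buffer`.
struct FlushPolicy {
    index: usize,    // encoders created on the current buffer
    budget: usize,   // compute_per_buffer, default 50
    generation: u32, // stands in for the actual CommandBuffer
}

impl FlushPolicy {
    fn command_buffer(&mut self) -> (bool, u32) {
        let mut flushed = false;
        if self.index > self.budget {
            // commit() the old buffer and start a fresh one
            self.generation += 1;
            self.index = 0;
            flushed = true;
        }
        self.index += 1;
        (flushed, self.generation)
    }
}

fn main() {
    let mut p = FlushPolicy { index: 0, budget: 2, generation: 0 };
    let calls: Vec<(bool, u32)> = (0..5).map(|_| p.command_buffer()).collect();
    // With a budget of 2, the fourth request exceeds the budget and flushes:
    assert_eq!(
        calls,
        vec![(false, 0), (false, 0), (false, 0), (true, 1), (false, 1)]
    );
}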
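The last hunk's new `impl EncoderProvider for &ComputeCommandEncoderRef` differs from the two command-buffer impls in that dropping the `WrappedEncoder` does not end encoding. A minimal usage sketch of that distinction follows; it assumes the trait is importable as `candle_metal_kernels::utils::EncoderProvider` and elides the actual kernel dispatches, so treat it as illustrative rather than as the crate's documented API.

use candle_metal_kernels::utils::EncoderProvider;

fn main() {
    let device = metal::Device::system_default().expect("no Metal device");
    let queue = device.new_command_queue();
    let command_buffer = queue.new_command_buffer();

    // Provider built from the command buffer: each `encoder()` call creates
    // a fresh encoder and `end_encoding` runs when the wrapper is dropped.
    {
        let wrapped = command_buffer.encoder();
        let _enc: &metal::ComputeCommandEncoderRef = wrapped.as_ref();
        // ... set the pipeline state and dispatch a kernel here ...
    } // encoding ends here

    // Provider built from an existing encoder: dropping the wrapper is a
    // no-op, so one encoder can back several launches.
    let raw = command_buffer.new_compute_command_encoder();
    for _ in 0..2 {
        let wrapped = raw.encoder();
        let _enc: &metal::ComputeCommandEncoderRef = wrapped.as_ref();
        // ... dispatch another kernel on the same encoder ...
    }
    raw.end_encoding(); // the caller ends encoding exactly once

    command_buffer.commit();
    command_buffer.wait_until_completed();
}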