diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 144f1e98a5..b6d668dca8 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -58,7 +58,7 @@ pub enum UnaryOp {
 }
 
 #[derive(Clone)]
-pub(crate) enum Op {
+pub enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
@@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
         v
     }
 }
+
+/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
+/// properly checked when creating a new value
+#[derive(Clone)]
+pub struct BackpropOp(Option<Op>);
+
+impl BackpropOp {
+    pub(crate) fn none() -> Self {
+        BackpropOp(None)
+    }
+
+    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
+        let op = if arg.track_op() {
+            Some(f(arg.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() {
+            Some(f(arg1.clone(), arg2.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new3(
+        arg1: &Tensor,
+        arg2: &Tensor,
+        arg3: &Tensor,
+        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
+    ) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
+            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
+            Some(f(args))
+        } else {
+            None
+        };
+        Self(op)
+    }
+}
+
+impl std::ops::Deref for BackpropOp {
+    type Target = Option<Op>;
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a0b21299ba..28ecc35774 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,5 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
+use crate::op::{
+    BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
+};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -33,7 +35,7 @@ pub struct Tensor_ {
     // that's tricky to encode in the current setup.
     storage: Arc<RwLock<Storage>>,
     layout: Layout,
-    op: Option<Op>,
+    op: BackpropOp,
     is_variable: bool,
     dtype: DType,
     device: Device,
@@ -79,11 +81,7 @@ macro_rules! unary_op {
             let storage = self
                 .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
-            let op = if self.track_op() {
-                Some(Op::Unary(self.clone(), UnaryOp::$op_name))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
@@ -98,11 +96,7 @@ macro_rules! binary_op {
                 self.layout(),
                 rhs.layout(),
             )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
-            } else {
-                None
-            };
+            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
@@ -131,7 +125,7 @@ macro_rules! broadcast_binary_op {
 fn from_storage<S: Into<Shape>>(
     storage: Storage,
     shape: S,
-    op: Option<Op>,
+    op: BackpropOp,
     is_variable: bool,
 ) -> Tensor {
     let dtype = storage.dtype();
@@ -155,13 +149,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
+        let none = BackpropOp::none();
         if is_variable {
             let shape = shape.into();
             let storage = device.ones(&shape, dtype)?;
-            Ok(from_storage(storage, shape, None, is_variable))
+            Ok(from_storage(storage, shape, none, is_variable))
         } else {
             let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
         }
     }
 
@@ -199,13 +194,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
+        let none = BackpropOp::none();
         if is_variable {
             let shape = shape.into();
             let storage = device.zeros(&shape, dtype)?;
-            Ok(from_storage(storage, shape, None, is_variable))
+            Ok(from_storage(storage, shape, none, is_variable))
         } else {
             let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
         }
     }
 
@@ -246,7 +242,8 @@ impl Tensor {
     ) -> Result<Self> {
         let s = s.into();
         let storage = device.rand_uniform(&s, dtype, lo, up)?;
-        Ok(from_storage(storage, s, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
@@ -270,7 +267,8 @@ impl Tensor {
     ) -> Result<Self> {
         let s = s.into();
         let storage = device.rand_normal(&s, dtype, mean, std)?;
-        Ok(from_storage(storage, s, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
@@ -297,7 +295,8 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage(array)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor on the specified device using the content and shape of the input.
@@ -352,7 +351,8 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
        }
         let storage = device.storage_owned(data)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values from the input vector. The number of elements
@@ -500,26 +500,14 @@ impl Tensor {
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
         let storage = self.storage().affine(self.layout(), mul, add)?;
-        let op = if self.track_op() {
-            Some(Op::Affine {
-                arg: self.clone(),
-                mul,
-                add,
-            })
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
         let storage = self.storage().elu(self.layout(), alpha)?;
-        let op = if self.track_op() {
-            Some(Op::Elu(self.clone(), alpha))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
@@ -554,11 +542,7 @@ impl Tensor {
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
         } else {
-            let op = if self.track_op() {
-                Some(Op::Narrow(self.clone(), dim, start, len))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
             let layout = self.layout().narrow(dim, start, len)?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -602,11 +586,7 @@ impl Tensor {
             let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
-            let op = if self.track_op() {
-                Some(Op::Softmax(self.clone(), dim))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -638,11 +618,7 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), op, dims.to_vec()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
@@ -660,11 +636,7 @@ impl Tensor {
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
         let sum = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(sum)
@@ -738,11 +710,7 @@ impl Tensor {
         let storage = self
             .storage()
             .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
-        let op = if self.track_op() {
-            Some(Op::Cmp(self.clone(), op))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
         Ok(from_storage(storage, shape.dims(), op, false))
     }
 
@@ -807,16 +775,12 @@ impl Tensor {
         let storage = self
             .storage()
             .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
-        let op = if self.track_op() || kernel.track_op() {
-            Some(Op::Conv1D {
-                arg: self.clone(),
-                kernel: kernel.clone(),
-                padding,
-                stride,
-            })
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+            arg,
+            kernel,
+            padding,
+            stride,
+        });
         let out_dims = params.out_dims();
         Ok(from_storage(storage, out_dims, op, false))
     }
 
@@ -867,11 +831,7 @@ impl Tensor {
             self.layout(),
             rhs.layout(),
         )?;
-        let op = if self.track_op() || rhs.track_op() {
-            Some(Op::Matmul(self.clone(), rhs.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, rhs, Op::Matmul);
         Ok(from_storage(storage, c_shape, op, false))
     }
 
@@ -888,15 +848,7 @@ impl Tensor {
             &on_false.storage(),
             on_false.layout(),
         )?;
-        let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
-            Some(Op::WhereCond(
-                self.clone(),
-                on_true.clone(),
-                on_false.clone(),
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
         Ok(from_storage(storage, shape, op, false))
     }
 
@@ -937,11 +889,7 @@ impl Tensor {
             .storage()
             .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
-        let op = if ids.track_op() || rhs.track_op() {
-            Some(Op::Embedding(ids.clone(), rhs.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(ids, rhs, Op::Embedding);
         Ok(from_storage(storage, shape, op, false))
     }
 
@@ -983,16 +931,9 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::ScatterAdd(
-                self.clone(),
-                indexes.clone(),
-                source.clone(),
-                dim,
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::ScatterAdd(t1, t2, t3, dim)
+        });
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
@@ -1038,16 +979,9 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::IndexAdd(
-                self.clone(),
-                indexes.clone(),
-                source.clone(),
-                dim,
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::IndexAdd(t1, t2, t3, dim)
+        });
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
@@ -1077,11 +1011,7 @@ impl Tensor {
         let storage = self
             .storage()
             .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::Gather(self.clone(), indexes.clone(), dim))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
         Ok(from_storage(storage, indexes.shape(), op, false))
     }
 
@@ -1104,11 +1034,7 @@ impl Tensor {
         )?;
         let mut dims = self.dims().to_vec();
         dims[dim] = indexes_len;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
         Ok(from_storage(storage, dims, op, false))
     }
 
@@ -1404,11 +1330,7 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
-        let op = if self.track_op() {
-            Some(Op::Transpose(self.clone(), dim1, dim2))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
@@ -1434,11 +1356,7 @@ impl Tensor {
     /// Compared to clone, this copies the actual storage but may fail because of running out of
     /// memory.
     pub fn copy(&self) -> Result<Tensor> {
-        let op = if self.track_op() {
-            Some(Op::Copy(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::Copy);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
@@ -1458,7 +1376,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: self.storage.clone(),
             layout: self.layout.clone(),
-            op: None,
+            op: BackpropOp::none(),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1484,11 +1402,7 @@ impl Tensor {
                 }
                 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
             };
-            let op = if self.track_op() {
-                Some(Op::ToDevice(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::ToDevice);
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage: Arc::new(RwLock::new(storage)),
@@ -1519,16 +1433,11 @@ impl Tensor {
     /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
     /// `i_a` is equal to 1, any value can be used.
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
-        let op = if self.track_op() {
-            Some(Op::Broadcast(self.clone()))
-        } else {
-            None
-        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
             layout: self.layout.broadcast_as(shape)?,
-            op,
+            op: BackpropOp::new1(self, Op::Broadcast),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1557,11 +1466,7 @@ impl Tensor {
         } else {
             let shape = self.shape();
             let storage = self.storage().to_dtype(self.layout(), dtype)?;
-            let op = if self.track_op() {
-                Some(Op::ToDType(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::ToDType);
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -1576,11 +1481,7 @@ impl Tensor {
             let mut storage = self.device().zeros(shape, self.dtype())?;
             self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
-            let op = if self.track_op() {
-                Some(Op::Copy(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::Copy);
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -1612,11 +1513,7 @@ impl Tensor {
             }
             .bt());
         }
-        let op = if self.track_op() {
-            Some(Op::Reshape(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::Reshape);
         if self.is_contiguous() {
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -1820,12 +1717,7 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
-            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
-            Some(Op::Cat(args, 0))
-        } else {
-            None
-        };
+        let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
@@ -1865,11 +1757,7 @@ impl Tensor {
         let (storage, shape) = self
             .storage()
             .custom_op1(self.layout(), c.as_ref().as_ref())?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp1(self.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }
 
@@ -1885,11 +1773,7 @@ impl Tensor {
             rhs.layout(),
             c.as_ref().as_ref(),
         )?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }
 
@@ -1907,11 +1791,9 @@ impl Tensor {
             t3.layout(),
             c.as_ref().as_ref(),
         )?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
+            Op::CustomOp3(t1, t2, t3, c.clone())
+        });
         Ok(from_storage(storage, shape, op, false))
     }
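
Note (reviewer sketch): to illustrate what the `BackpropOp` wrapper buys over the raw `Option<Op>` field, here is a minimal, self-contained sketch of the same pattern. The `Tensor`, `requires_grad`, `Op::Neg` and `Op::Add` names are simplified stand-ins introduced only for this illustration (they are not the candle-core definitions), and `bool::then` is used instead of the explicit if/else written in the patch.

```rust
// Standalone sketch of the BackpropOp pattern: constructors inspect their inputs
// and only record an `Op` when at least one argument is tracked, so call sites
// can no longer forget the `track_op()` check.

#[derive(Clone, Debug)]
struct Tensor {
    requires_grad: bool,
}

impl Tensor {
    // Stand-in for candle's `Tensor::track_op` condition.
    fn track_op(&self) -> bool {
        self.requires_grad
    }
}

#[derive(Clone, Debug)]
enum Op {
    Neg(Tensor),
    Add(Tensor, Tensor),
}

#[derive(Clone, Debug)]
struct BackpropOp(Option<Op>);

impl BackpropOp {
    fn none() -> Self {
        BackpropOp(None)
    }

    // The closure only runs (and the input is only cloned) when the argument is tracked.
    fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
        BackpropOp(arg.track_op().then(|| f(arg.clone())))
    }

    fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
        let tracked = arg1.track_op() || arg2.track_op();
        BackpropOp(tracked.then(|| f(arg1.clone(), arg2.clone())))
    }
}

// Deref to `Option<Op>` keeps existing `op.is_some()` / pattern-matching call sites working.
impl std::ops::Deref for BackpropOp {
    type Target = Option<Op>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

fn main() {
    let a = Tensor { requires_grad: true };
    let b = Tensor { requires_grad: false };

    // One tracked input is enough for the op to be recorded.
    assert!(BackpropOp::new2(&a, &b, Op::Add).is_some());
    // Untracked inputs record nothing, mirroring the old `None` branch.
    assert!(BackpropOp::new1(&b, Op::Neg).is_none());
    let _ = BackpropOp::none();
}
```

Routing every constructor through `none`/`new1`/`new2`/`new3`/`new` centralizes the dependency check that was previously repeated by hand at each call site.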