Skip to content

Commit

Permalink
Add an abstract backprop op type (huggingface#240)
Browse files Browse the repository at this point in the history
* Start adding the backprop op type.

* More backprop ops.

* Finish the backprop op.
  • Loading branch information
LaurentMazare authored Jul 25, 2023
1 parent be9c261 commit c97d512
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 178 deletions.
62 changes: 61 additions & 1 deletion candle-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub enum UnaryOp {
}

#[derive(Clone)]
pub(crate) enum Op {
pub enum Op {
Binary(Tensor, Tensor, BinaryOp),
Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
Expand Down Expand Up @@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
v
}
}

/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);

impl BackpropOp {
pub(crate) fn none() -> Self {
BackpropOp(None)
}

pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
let op = if arg.track_op() {
Some(f(arg.clone()))
} else {
None
};
Self(op)
}

pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
let op = if arg1.track_op() || arg2.track_op() {
Some(f(arg1.clone(), arg2.clone()))
} else {
None
};
Self(op)
}

pub(crate) fn new3(
arg1: &Tensor,
arg2: &Tensor,
arg3: &Tensor,
f: impl Fn(Tensor, Tensor, Tensor) -> Op,
) -> Self {
let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
} else {
None
};
Self(op)
}

pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
Some(f(args))
} else {
None
};
Self(op)
}
}

impl std::ops::Deref for BackpropOp {
type Target = Option<Op>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
Loading

0 comments on commit c97d512

Please sign in to comment.