Skip to content

Commit

Permalink
Avoid keeping track of the copy ops when not necessary. (huggingface#239
Browse files Browse the repository at this point in the history
)
  • Loading branch information
LaurentMazare authored Jul 25, 2023
1 parent 944d70b commit be9c261
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1434,11 +1434,16 @@ impl Tensor {
/// Compared to clone, this copies the actual storage but may fail because of running out of
/// memory.
pub fn copy(&self) -> Result<Tensor> {
let op = if self.track_op() {
Some(Op::Copy(self.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
layout: self.layout.clone(),
op: Some(Op::Copy(self.clone())),
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
Expand Down Expand Up @@ -1571,12 +1576,12 @@ impl Tensor {
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(
storage,
shape.clone(),
Some(Op::Copy(self.clone())),
false,
))
let op = if self.track_op() {
Some(Op::Copy(self.clone()))
} else {
None
};
Ok(from_storage(storage, shape.clone(), op, false))
}
}

Expand Down

0 comments on commit be9c261

Please sign in to comment.