diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index bf0902196e..7de4604456 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { name: &str, hints: B::Hints, ) -> Result { - let path = self.path(name); - self.data - .backend - .get(s.into(), &path, hints, self.data.dtype, &self.data.device) + self.get_with_hints_dtype(s, name, hints, self.data.dtype) } /// Retrieve the tensor associated with the given name at the current path. pub fn get>(&self, s: S, name: &str) -> Result { self.get_with_hints(s, name, Default::default()) } + + /// Retrieve the tensor associated with the given name & dtype at the current path. + pub fn get_with_hints_dtype>( + &self, + s: S, + name: &str, + hints: B::Hints, + dtype: DType, + ) -> Result { + let path = self.path(name); + self.data + .backend + .get(s.into(), &path, hints, dtype, &self.data.device) + } } struct Zeros;