From ad796eb4be9877712c0034d291a082cee1fd2dec Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Sat, 2 Sep 2023 14:41:48 +0200
Subject: [PATCH] More quantized llama in python. (#716)

* More quantized llama in python.

* Expose a couple more functions.

* Apply the last layer.

* Use the vocab from the ggml files.
---
 candle-pyo3/quant-llama.py | 19 +++++++++----
 candle-pyo3/src/lib.rs     | 56 ++++++++++++++++++++++++++++++++++----
 2 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index a3638855b5..092c1faa4d 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base):
     idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
     idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
     m = idx_theta.matmul(theta.unsqueeze(0))
-    print(m.shape)
     return (m.cos(), m.sin())
 
 class QuantizedLlama:
@@ -143,28 +142,36 @@ def __call__(self, token, index_pos):
 
         for layer in self.layers:
             x = layer(x, mask, index_pos)
+        x = self.norm(x)
+        x = x.narrow(1, -1, 1).squeeze(1)
+        x = self.output.matmul_t(x)
         return x
 
 def main():
     if len(sys.argv) < 2:
         raise ValueError("missing weight file argument")
     filename = sys.argv[1]
+    print(f"reading model file {filename}")
     if filename.endswith("gguf"):
         all_tensors = candle.load_gguf(sys.argv[1])
         hparams = None
+        vocab = None
     else:
-        all_tensors, hparams = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
     print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
+    print("model built, starting inference")
 
     tokens = [1]
-    for token_idx in range(1):
-        print(tokens)
+    for token_idx in range(500):
         last_token = tokens[-1]
         lt = candle.tensor([last_token]).unsqueeze(0)
         logits = model(lt, len(tokens))
-        print(logits)
-        next_token = "TODO: sample"
+        # Greedy sampling for now
+        # pr = candle.nn.softmax(logits, -1)
+        m = logits.get(0).argmax_keepdim(-1)
+        next_token = m.values()[0]
+        print(vocab[next_token], end='', flush=True)
         tokens.append(next_token)
 
 if __name__ == '__main__':
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 43d99c25f9..5e6f48ea84 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -145,6 +145,22 @@ pydtype!(bf16, f32::from);
 pydtype!(f32, |v| v);
 pydtype!(f64, |v| v);
 
+fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> {
+    let dim = t.dim(dim)?;
+    if 0 <= index {
+        let index = index as usize;
+        if dim <= index {
+            ::candle::bail!("index {index} is too large for tensor dimension {dim}")
+        }
+        Ok(index)
+    } else {
+        if (dim as i64) < -index {
+            ::candle::bail!("index {index} is too low for tensor dimension {dim}")
+        }
+        Ok((dim as i64 + index) as usize)
+    }
+}
+
 fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
     let rank = t.rank();
     if 0 <= dim {
@@ -409,7 +425,8 @@ impl PyTensor {
         Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
     }
 
-    fn squeeze(&self, dim: usize) -> PyResult<Self> {
+    fn squeeze(&self, dim: i64) -> PyResult<Self> {
+        let dim = actual_dim(self, dim).map_err(wrap_err)?;
         Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
     }
 
@@ -417,7 +434,8 @@ impl PyTensor {
         Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
     }
 
-    fn get(&self, index: usize) -> PyResult<Self> {
+    fn get(&self, index: i64) -> PyResult<Self> {
+        let index = actual_index(self, 0, index).map_err(wrap_err)?;
         Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
     }
 
@@ -425,11 +443,32 @@ impl PyTensor {
         Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
     }
 
-    fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> {
+    fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
         let dim = actual_dim(self, dim).map_err(wrap_err)?;
+        let start = actual_index(self, dim, start).map_err(wrap_err)?;
         Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
     }
 
+    fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
+        let dim = actual_dim(self, dim).map_err(wrap_err)?;
+        Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
+    }
+
+    fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
+        let dim = actual_dim(self, dim).map_err(wrap_err)?;
+        Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
+    }
+
+    fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
+        let dim = actual_dim(self, dim).map_err(wrap_err)?;
+        Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
+    }
+
+    fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
+        let dim = actual_dim(self, dim).map_err(wrap_err)?;
+        Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
+    }
+
     fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
         let dims = if let Ok(dim) = dims.extract::<usize>(py) {
             vec![dim]
@@ -661,7 +700,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
 }
 
 #[pyfunction]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
     let mut file = std::fs::File::open(path)?;
     let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
     let tensors = ggml
@@ -681,7 +720,14 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
         ("ftype", ggml.hparams.ftype),
     ];
     let hparams = hparams.into_py_dict(py).to_object(py);
-    Ok((tensors, hparams))
+    let vocab = ggml
+        .vocab
+        .token_score_pairs
+        .iter()
+        .map(|(bytes, _)| String::from_utf8_lossy(bytes.as_slice()).to_string())
+        .collect::<Vec<String>>()
+        .to_object(py);
+    Ok((tensors, hparams, vocab))
 }
 
 #[pyfunction]
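
Usage note (not part of the patch): the actual_index/actual_dim helpers above make
negative indices resolve from the end of a dimension, which is what quant-llama.py
relies on in x.narrow(1, -1, 1). A minimal Python-side sketch, assuming the candle
extension module is built and importable; the tensor values are arbitrary:

    import candle

    t = candle.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape((2, 3))
    last_col = t.narrow(1, -1, 1)  # start=-1 resolves to column 2 via actual_index
    row = t.get(-2)                # index -2 on a dim of size 2 resolves to row 0
    am = t.argmax_keepdim(-1)      # dim=-1 resolves to the last dim via actual_dim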
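
Similarly, the three-value return from load_ggml can be consumed as in main() above.
A sketch, where "model.bin" stands in for a real ggml weight file path:

    import candle

    tensors, hparams, vocab = candle.load_ggml("model.bin")
    print(hparams["ftype"], len(vocab))  # hparams is a dict, vocab a list of strings
    # Token ids index straight into vocab, e.g. print(vocab[next_token], end='').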