diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 60272c9b36..97631b0a19 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -16,6 +16,7 @@ doc = false [dependencies] candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.2.1" } half = { workspace = true } pyo3 = { version = "0.19.0", features = ["extension-module"] } diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py new file mode 100644 index 0000000000..a3638855b5 --- /dev/null +++ b/candle-pyo3/quant-llama.py @@ -0,0 +1,171 @@ +# This example shows how the candle Python api can be used to replicate llama.cpp. +import os +import sys + +# The "import candle" statement below works if there is a "candle.so" file in sys.path. +# Here we check for shared libraries that can be used in the build directory. +BUILD_DIR = "./target/release-with-debug" +so_file = BUILD_DIR + "/candle.so" +if os.path.islink(so_file): os.remove(so_file) +for lib_file in ["libcandle.dylib", "libcandle.so"]: + lib_file_ = BUILD_DIR + "/" + lib_file + if os.path.isfile(lib_file_): + os.symlink(lib_file, so_file) + sys.path.insert(0, BUILD_DIR) + break + +import candle + +MAX_SEQ_LEN = 4096 + +def masked_fill(on_false, mask, on_true): + shape = mask.shape + on_true = candle.tensor(on_true).broadcast_as(shape) + return mask.where_cond(on_true, on_false) + +class RmsNorm: + def __init__(self, qtensor): + self.weight = qtensor.dequantize() + + def __call__(self, x): + b_size, seq_len, hidden_size = x.shape + norm_x = x.sqr().sum_keepdim(2) / hidden_size + x_normed = x.broadcast_div((norm_x + 1e-5).sqrt()) + return x_normed.broadcast_mul(self.weight) + +class QuantizedLayer: + def __init__(self, layer_idx, hparams, all_tensors, cos_sin): + p = f"layers.{layer_idx}" + self.attention_wq = all_tensors[f"{p}.attention.wq.weight"] + self.attention_wk = all_tensors[f"{p}.attention.wk.weight"] + self.attention_wv = all_tensors[f"{p}.attention.wv.weight"] + self.attention_wo = all_tensors[f"{p}.attention.wo.weight"] + self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"] + self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"] + self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"] + self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"]) + self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"]) + + self.n_head = hparams["n_head"] + self.n_kv_head = self.n_head + self.head_dim = hparams["n_embd"] // self.n_head + + self.kv_cache = None + self.cos = cos_sin[0] + self.sin = cos_sin[1] + + def __call__(self, x, mask, index_pos): + residual = x + x = self.attn_norm(x) + attn = self.forward_attn(x, mask, index_pos) + x = attn + residual + + residual = x + x = self.ffn_norm(x) + w1 = self.ffw1.matmul_t(x) + w3 = self.ffw3.matmul_t(x) + mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3) + + return mlp + residual + + def forward_attn(self, x, mask, index_pos): + b_size, seq_len, n_embd = x.shape + q = self.attention_wq.matmul_t(x) + k = self.attention_wk.matmul_t(x) + v = self.attention_wv.matmul_t(x) + + q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2) + k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) + v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) + + q = self.apply_rotary_emb(q, index_pos) + k = self.apply_rotary_emb(k, index_pos) + + if self.kv_cache is not None and index_pos > 0: + prev_k, prev_v = self.kv_cache + k = candle.cat([prev_k, k], 2).contiguous() + v = candle.cat([prev_v, v], 2).contiguous() + + self.kv_cache = (k, v) + + # TODO: maybe repeat k/v here if we start supporting MQA. + + att = q.matmul(k.t()) / self.head_dim**0.5 + mask = mask.broadcast_as(att.shape) + att = masked_fill(att, mask, float("-inf")) + att = candle.nn.softmax(att, -1) + y = att.matmul(v.contiguous()) + y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd)) + return self.attention_wo.matmul_t(y) + + def apply_rotary_emb(self, x, index_pos): + (b_size, n_head, seq_len, n_embd) = x.shape + cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) + sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) + x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2)) + x0 = x.narrow(-1, 0, 1) + x1 = x.narrow(-1, 1, 1) + y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin) + y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos) + rope = candle.cat([y0, y1], -1) + return rope.flatten_from(-2) + +def precompute_freqs_cis(hparams, freq_base): + head_dim = hparams["n_embd"] // hparams["n_head"] + theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)] + theta = candle.tensor(theta) + idx_theta = [float(i) for i in range(MAX_SEQ_LEN)] + idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1)) + m = idx_theta.matmul(theta.unsqueeze(0)) + print(m.shape) + return (m.cos(), m.sin()) + +class QuantizedLlama: + def __init__(self, hparams, all_tensors): + self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize() + self.norm = RmsNorm(all_tensors["norm.weight"]) + self.output = all_tensors["output.weight"] + self.layers = [] + cos_sin = precompute_freqs_cis(hparams, 10000.) + for layer_idx in range(hparams["n_layer"]): + layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) + self.layers.append(layer) + + def __call__(self, token, index_pos): + b_size, seq_len = token.shape + vocab_size, hidden_size = self.tok_embeddings.shape + token = token.reshape((b_size * seq_len,)) + x = self.tok_embeddings.index_select(token, 0) + x = x.reshape((b_size, seq_len, hidden_size)) + + mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)] + mask = candle.tensor(mask).reshape((seq_len, seq_len)) + + for layer in self.layers: + x = layer(x, mask, index_pos) + return x + +def main(): + if len(sys.argv) < 2: + raise ValueError("missing weight file argument") + filename = sys.argv[1] + if filename.endswith("gguf"): + all_tensors = candle.load_gguf(sys.argv[1]) + hparams = None + else: + all_tensors, hparams = candle.load_ggml(sys.argv[1]) + print(hparams) + model = QuantizedLlama(hparams, all_tensors) + + tokens = [1] + for token_idx in range(1): + print(tokens) + last_token = tokens[-1] + lt = candle.tensor([last_token]).unsqueeze(0) + logits = model(lt, len(tokens)) + print(logits) + next_token = "TODO: sample" + tokens.append(next_token) + +if __name__ == '__main__': + main() diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 2673d8437e..43d99c25f9 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -145,6 +145,22 @@ pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result { + let rank = t.rank(); + if 0 <= dim { + let dim = dim as usize; + if rank <= dim { + ::candle::bail!("dimension index {dim} is too large for tensor rank {rank}") + } + Ok(dim) + } else { + if (rank as i64) < -dim { + ::candle::bail!("dimension index {dim} is too low for tensor rank {rank}") + } + Ok((rank as i64 + dim) as usize) + } +} + // TODO: Something similar to this should probably be a part of candle core. trait MapDType { type Output; @@ -182,7 +198,10 @@ impl PyTensor { } else if let Ok(vs) = vs.extract::>(py) { Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? } else { - Err(PyTypeError::new_err("incorrect type for tensor"))? + let ty = vs.as_ref(py).get_type(); + Err(PyTypeError::new_err(format!( + "incorrect type {ty} for tensor" + )))? }; Ok(Self(tensor)) } @@ -295,10 +314,31 @@ impl PyTensor { Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?)) } + fn index_select(&self, rhs: &Self, dim: i64) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) + } + fn matmul(&self, rhs: &Self) -> PyResult { Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?)) } + fn broadcast_add(&self, rhs: &Self) -> PyResult { + Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?)) + } + + fn broadcast_sub(&self, rhs: &Self) -> PyResult { + Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?)) + } + + fn broadcast_mul(&self, rhs: &Self) -> PyResult { + Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?)) + } + + fn broadcast_div(&self, rhs: &Self) -> PyResult { + Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?)) + } + fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult { Ok(PyTensor( self.0.where_cond(on_true, on_false).map_err(wrap_err)?, @@ -346,6 +386,17 @@ impl PyTensor { Ok(Self(tensor)) } + fn __truediv__(&self, rhs: &PyAny) -> PyResult { + let tensor = if let Ok(rhs) = rhs.extract::() { + (&self.0 / &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::() { + (&self.0 / rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported rhs for div"))? + }; + Ok(Self(tensor)) + } + fn reshape(&self, shape: PyShape) -> PyResult { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } @@ -374,7 +425,8 @@ impl PyTensor { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } - fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult { + fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) } @@ -400,6 +452,16 @@ impl PyTensor { Ok(PyTensor(mean)) } + fn flatten_from(&self, dim: i64) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?)) + } + + fn flatten_to(&self, dim: i64) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?)) + } + fn flatten_all(&self) -> PyResult { Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?)) } @@ -467,7 +529,11 @@ impl PyTensor { /// Concatenate the tensors across one axis. #[pyfunction] -fn cat(tensors: Vec, dim: usize) -> PyResult { +fn cat(tensors: Vec, dim: i64) -> PyResult { + if tensors.is_empty() { + return Err(PyErr::new::("empty input to cat")); + } + let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?; let tensors = tensors.into_iter().map(|t| t.0).collect::>(); let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?; Ok(PyTensor(tensor)) @@ -595,16 +661,27 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult { } #[pyfunction] -fn load_ggml(path: &str, py: Python<'_>) -> PyResult { +fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; - let res = ggml + let tensors = ggml .tensors .into_iter() .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))) .collect::<::candle::Result>>() .map_err(wrap_err)?; - Ok(res.into_py_dict(py).to_object(py)) + let tensors = tensors.into_py_dict(py).to_object(py); + let hparams = [ + ("n_vocab", ggml.hparams.n_vocab), + ("n_embd", ggml.hparams.n_embd), + ("n_mult", ggml.hparams.n_mult), + ("n_head", ggml.hparams.n_head), + ("n_layer", ggml.hparams.n_layer), + ("n_rot", ggml.hparams.n_rot), + ("ftype", ggml.hparams.ftype), + ]; + let hparams = hparams.into_py_dict(py).to_object(py); + Ok((tensors, hparams)) } #[pyfunction] @@ -651,11 +728,33 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } +#[pyfunction] +fn softmax(t: PyTensor, dim: i64) -> PyResult { + let dim = actual_dim(&t, dim).map_err(wrap_err)?; + let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?; + Ok(PyTensor(sm)) +} + +#[pyfunction] +fn silu(t: PyTensor) -> PyResult { + let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(silu, m)?)?; + m.add_function(wrap_pyfunction!(softmax, m)?)?; + Ok(()) +} + #[pymodule] fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; candle_utils(py, utils)?; m.add_submodule(utils)?; + let nn = PyModule::new(py, "nn")?; + candle_nn_m(py, nn)?; + m.add_submodule(nn)?; m.add_class::()?; m.add_class::()?; m.add_class::()?;