From 904bbdae65d69aac0c54c29eef744ca5e69c6733 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Fri, 6 Oct 2023 20:01:07 +0200 Subject: [PATCH] Make the Python Wrapper more Hackable and simplify Quantization (#1010) * Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation --- .vscode/settings.json | 11 + candle-pyo3/.gitignore | 1 + candle-pyo3/e5.py | 104 +++ candle-pyo3/py_src/candle/__init__.py | 29 +- .../py_src/candle/functional/__init__.py | 8 + .../candle/{nn => functional}/__init__.pyi | 21 + candle-pyo3/py_src/candle/models/bert.py | 194 +++++ candle-pyo3/py_src/candle/models/llama.py | 150 ++++ candle-pyo3/py_src/candle/nn/__init__.py | 10 +- candle-pyo3/py_src/candle/nn/container.py | 483 ++++++++++++ candle-pyo3/py_src/candle/nn/linear.py | 119 +++ candle-pyo3/py_src/candle/nn/module.py | 702 ++++++++++++++++++ candle-pyo3/py_src/candle/nn/normalization.py | 54 ++ candle-pyo3/py_src/candle/nn/sparse.py | 39 + candle-pyo3/py_src/candle/typing/__init__.py | 8 +- candle-pyo3/pyproject.toml | 4 + candle-pyo3/quant-llama.py | 197 +---- candle-pyo3/src/lib.rs | 134 +++- candle-pyo3/stub.py | 14 +- candle-pyo3/test.py | 2 +- candle-pyo3/tests/__init__.py | 0 candle-pyo3/tests/bindings/test_linear.py | 38 + candle-pyo3/tests/bindings/test_module.py | 161 ++++ candle-pyo3/tests/native/test_tensor.py | 74 ++ candle-pyo3/tests/native/test_utils.py | 51 ++ 25 files changed, 2426 insertions(+), 182 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 candle-pyo3/e5.py create mode 100644 candle-pyo3/py_src/candle/functional/__init__.py rename candle-pyo3/py_src/candle/{nn => functional}/__init__.pyi (54%) create mode 100644 candle-pyo3/py_src/candle/models/bert.py create mode 100644 candle-pyo3/py_src/candle/models/llama.py create mode 100644 candle-pyo3/py_src/candle/nn/container.py create mode 100644 candle-pyo3/py_src/candle/nn/linear.py create mode 100644 candle-pyo3/py_src/candle/nn/module.py create mode 100644 candle-pyo3/py_src/candle/nn/normalization.py create mode 100644 candle-pyo3/py_src/candle/nn/sparse.py create mode 100644 candle-pyo3/tests/__init__.py create mode 100644 candle-pyo3/tests/bindings/test_linear.py create mode 100644 candle-pyo3/tests/bindings/test_module.py create mode 100644 candle-pyo3/tests/native/test_tensor.py create mode 100644 candle-pyo3/tests/native/test_utils.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..b2dbd68012 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.formatting.provider": "none", + "python.testing.pytestArgs": [ + "candle-pyo3" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/candle-pyo3/.gitignore b/candle-pyo3/.gitignore index 68bc17f9ff..3d1f96fbbb 100644 --- a/candle-pyo3/.gitignore +++ b/candle-pyo3/.gitignore @@ -1,3 +1,4 @@ +tests/_workdir # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/candle-pyo3/e5.py b/candle-pyo3/e5.py new file mode 100644 index 0000000000..a0af0c5608 --- /dev/null +++ b/candle-pyo3/e5.py @@ -0,0 +1,104 @@ +from candle.utils import load_safetensors, save_gguf, load_gguf +from candle.models.bert import BertModel, Config +import json +from candle import Tensor +from tqdm import tqdm +from dataclasses import fields +import os +import time + +from huggingface_hub import hf_hub_download +from transformers import BertTokenizer, AutoModel +import torch + +if __name__ == "__main__": + model_name = "intfloat/e5-small-v2" + model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors") + config_file = hf_hub_download(repo_id=model_name, filename="config.json") + + tensors = load_safetensors(model_file) + config = Config() + with open(config_file, "r") as f: + raw_config = json.load(f) + for field in fields(config): + if field.name in raw_config: + setattr(config, field.name, raw_config[field.name]) + + # Load the model + model = BertModel(config) + model.load_state_dict(tensors) + + hf_model = AutoModel.from_pretrained(model_name) + tokenizer = BertTokenizer.from_pretrained(model_name) + + sentences = [ + "The cat sits outside", + "A man is playing guitar", + "I love pasta", + "The new movie is awesome", + "The cat plays in the garden", + "A woman watches TV", + "The new movie is so great", + "Do you like pizza?", + ] + + def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor): + """Average the hidden states according to the attention mask""" + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + tokenized = tokenizer(sentences, padding=True) + tokens = Tensor(tokenized["input_ids"]) + token_type_ids = Tensor(tokenized["token_type_ids"]) + encoder_out, _ = model.forward(tokens, token_type_ids) + + hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt") + hf_result = hf_model(**hf_tokenized)["last_hidden_state"] + + hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"]) + candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"]) + + loss = torch.nn.L1Loss() + error = loss(hf_pooled, candle_pooled).mean().item() + print(f"Mean error between torch-referenze and candle: {error}") + + # Quantize all attention 'weights' + quantized_tensors = {} + for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"): + if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name): + # check if the tensor is k-quantizable + if tensor.shape[-1] % 256 == 0: + new_tensor = tensor.quantize("q4k") + else: + new_tensor = tensor.quantize("q5_0") + quantized_tensors[name] = new_tensor + else: + quantized_tensors[name] = tensor.quantize("q8_0") + + print(f"Saving quantized tensors") + # Remove all None values from the config + config_to_save = {k: v for k, v in config.__dict__.items() if v is not None} + # Save the model + quantized_model_file = "e5_small.gguf" + save_gguf(quantized_model_file, quantized_tensors, config_to_save) + + file_size_mb = os.path.getsize(model_file) / 1024 / 1024 + file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024 + print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB") + # Load the model from the gguf + tensors, raw_config = load_gguf(quantized_model_file) + config = Config() + for field in fields(config): + if field.name in raw_config: + setattr(config, field.name, raw_config[field.name]) + model = BertModel(config) + # "embeddings.position_ids" is missing in the gguf as it is i64 + model.load_state_dict(tensors, strict=False) + + # Run the model again + encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids) + encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu") + + candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"]) + error = loss(hf_pooled, candle_pooled_2).mean().item() + print(f"Mean error between torch-referenze and quantized-candle: {error}") diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py index 951609cce6..dc97b775d5 100644 --- a/candle-pyo3/py_src/candle/__init__.py +++ b/candle-pyo3/py_src/candle/__init__.py @@ -1,5 +1,30 @@ -from .candle import * +import logging + +try: + from .candle import * +except ImportError as e: + # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here + logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...") + import os + import platform + + # Try to locate CUDA_PATH environment variable + cuda_path = os.environ.get("CUDA_PATH", None) + if cuda_path: + logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}") + if platform.system() == "Windows": + cuda_path = os.path.join(cuda_path, "bin") + else: + cuda_path = os.path.join(cuda_path, "lib64") + + logging.warning(f"Adding {cuda_path} to DLL search path...") + os.add_dll_directory(cuda_path) + + try: + from .candle import * + except ImportError as inner_e: + raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.") __doc__ = candle.__doc__ if hasattr(candle, "__all__"): - __all__ = candle.__all__ \ No newline at end of file + __all__ = candle.__all__ diff --git a/candle-pyo3/py_src/candle/functional/__init__.py b/candle-pyo3/py_src/candle/functional/__init__.py new file mode 100644 index 0000000000..efb246f066 --- /dev/null +++ b/candle-pyo3/py_src/candle/functional/__init__.py @@ -0,0 +1,8 @@ +# Generated content DO NOT EDIT +from .. import functional + +gelu = functional.gelu +relu = functional.relu +silu = functional.silu +softmax = functional.softmax +tanh = functional.tanh diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi similarity index 54% rename from candle-pyo3/py_src/candle/nn/__init__.pyi rename to candle-pyo3/py_src/candle/functional/__init__.pyi index 01b30fcedf..a46b6137a9 100644 --- a/candle-pyo3/py_src/candle/nn/__init__.pyi +++ b/candle-pyo3/py_src/candle/functional/__init__.pyi @@ -4,6 +4,20 @@ from os import PathLike from candle.typing import _ArrayLike, Device from candle import Tensor, DType, QTensor +@staticmethod +def gelu(tensor: Tensor) -> Tensor: + """ + Applies the Gaussian Error Linear Unit (GELU) function to a given tensor. + """ + pass + +@staticmethod +def relu(tensor: Tensor) -> Tensor: + """ + Applies the Rectified Linear Unit (ReLU) function to a given tensor. + """ + pass + @staticmethod def silu(tensor: Tensor) -> Tensor: """ @@ -17,3 +31,10 @@ def softmax(tensor: Tensor, dim: int) -> Tensor: Applies the Softmax function to a given tensor.# """ pass + +@staticmethod +def tanh(tensor: Tensor) -> Tensor: + """ + Applies the tanh function to a given tensor. + """ + pass diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py new file mode 100644 index 0000000000..0a773f939d --- /dev/null +++ b/candle-pyo3/py_src/candle/models/bert.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Optional +from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList +from candle import Tensor +import candle +import candle.functional as F +from typing import Tuple, Optional + + +@dataclass +class Config: + vocab_size: int = 30522 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + max_position_embeddings: int = 512 + type_vocab_size: int = 2 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + pad_token_id: int = 0 + position_embedding_type: str = "absolute" + use_cache: bool = True + classifier_dropout: Optional[float] = None + model_type: Optional[str] = "bert" + + +class BertSelfAttention(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + all_head_size = int(config.num_attention_heads * self.attention_head_size) + hidden_size = config.hidden_size + self.query = Linear(hidden_size, all_head_size) + self.key = Linear(hidden_size, all_head_size) + self.value = Linear(hidden_size, all_head_size) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + new_x_shape = x.shape[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.reshape(new_x_shape).transpose(1, 2) + return x.contiguous() + + def forward(self, hidden_states: Tensor) -> Tensor: + query = self.query.forward(hidden_states) + key = self.key.forward(hidden_states) + value = self.value.forward(hidden_states) + + query = self.transpose_for_scores(query) + key = self.transpose_for_scores(key) + value = self.transpose_for_scores(value) + + attention_scores = query.matmul(key.t()) + attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5) + attention_probs = F.softmax(attention_scores, dim=-1) + + context_layer = attention_probs.matmul(value) + context_layer = context_layer.transpose(1, 2).contiguous() + context_layer = context_layer.flatten_from(-2) + return context_layer + + +class BertSelfOutput(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.dense = Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: + hidden_states = self.dense.forward(hidden_states) + return self.LayerNorm.forward(hidden_states + input_tensor) + + +class BertAttention(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, hidden_states: Tensor) -> Tensor: + self_outputs = self.self.forward(hidden_states) + attention_output = self.output.forward(self_outputs, hidden_states) + return attention_output + + +class BertIntermediate(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.dense = Linear(config.hidden_size, config.intermediate_size) + self.act = F.gelu if config.hidden_act == "gelu" else F.relu + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense.forward(hidden_states) + return self.act(hidden_states) + + +class BertOutput(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.dense = Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: + hidden_states = self.dense.forward(hidden_states) + return self.LayerNorm.forward(hidden_states + input_tensor) + + +class BertLayer(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states: Tensor) -> Tensor: + attention_output = self.attention.forward(hidden_states) + # TODO: Support cross-attention? + # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + # TODO: Support something similar to `apply_chunking_to_forward`? + intermediate_output = self.intermediate.forward(attention_output) + layer_output = self.output.forward(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.layer = ModuleList() + for _ in range(config.num_hidden_layers): + self.layer.append(BertLayer(config)) + + def forward(self, hidden_states: Tensor) -> Tensor: + for l in self.layer: + hidden_states = l.forward(hidden_states) + return hidden_states + + +class BertEmbeddings(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape( + (1, config.max_position_embeddings) + ) + + def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor: + (_batch_size, seq_len) = input_ids.shape + input_embeddings = self.word_embeddings.forward(input_ids) + token_type_embeddings = self.token_type_embeddings.forward(token_type_ids) + embeddings: Tensor = input_embeddings + token_type_embeddings + + position_ids = list(range(seq_len)) + position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device) + + embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids)) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertPooler(Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.dense = Linear(config.hidden_size, config.hidden_size) + self.activation = F.tanh + + def forward(self, hidden_states: Tensor) -> Tensor: + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense.forward(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 +class BertModel(Module): + def __init__(self, config: Config, add_pooling_layer=True) -> None: + super().__init__() + self.config = config + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]: + embeddings = self.embeddings.forward(input_ids, token_type_ids) + encoder_out = self.encoder.forward(embeddings) + pooled_output = self.pooler(encoder_out) if self.pooler is not None else None + return encoder_out, pooled_output diff --git a/candle-pyo3/py_src/candle/models/llama.py b/candle-pyo3/py_src/candle/models/llama.py new file mode 100644 index 0000000000..fd9b30af3d --- /dev/null +++ b/candle-pyo3/py_src/candle/models/llama.py @@ -0,0 +1,150 @@ +import candle +from typing import Dict, Tuple, Any +from candle import Tensor, QTensor, utils, nn +from candle.nn import Module, ModuleList + + +def masked_fill(on_false: Tensor, mask: Tensor, on_true: Tensor): + shape = mask.shape + on_true = candle.tensor(on_true).broadcast_as(shape) + return mask.where_cond(on_true, on_false) + + +def precompute_freqs_cis(hparams: Dict[str, Any], freq_base: float, max_seq_len: int): + head_dim = hparams["n_embd"] // hparams["n_head"] + theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)] + theta = candle.tensor(theta) + idx_theta = [float(i) for i in range(max_seq_len)] + idx_theta = candle.tensor(idx_theta).reshape((max_seq_len, 1)) + m = idx_theta.matmul(theta.unsqueeze(0)) + return (m.cos(), m.sin()) + + +class RmsNorm(Module): + def __init__(self, qtensor: QTensor): + super().__init__() + self.weight = qtensor.dequantize() + + def forward(self, x: Tensor) -> Tensor: + b_size, seq_len, hidden_size = x.shape + norm_x = x.sqr().sum_keepdim(2) / hidden_size + x_normed = x.broadcast_div((norm_x + 1e-5).sqrt()) + return x_normed.broadcast_mul(self.weight) + + +class QuantizedLayer(Module): + def __init__( + self, + layer_idx: int, + hparams: Dict[str, Any], + all_tensors: Dict[str, QTensor], + cos_sin: Tuple[Tensor, Tensor], + ): + super().__init__() + p = f"layers.{layer_idx}" + self.attention_wq = all_tensors[f"{p}.attention.wq.weight"] + self.attention_wk = all_tensors[f"{p}.attention.wk.weight"] + self.attention_wv = all_tensors[f"{p}.attention.wv.weight"] + self.attention_wo = all_tensors[f"{p}.attention.wo.weight"] + self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"] + self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"] + self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"] + self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"]) + self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"]) + + self.n_head = hparams["n_head"] + self.n_kv_head = self.n_head + self.head_dim = hparams["n_embd"] // self.n_head + + self.kv_cache = None + self.cos = cos_sin[0] + self.sin = cos_sin[1] + self._non_persistent_buffers_set.add("cos") + self._non_persistent_buffers_set.add("sin") + + def forward(self, x: Tensor, mask: Tensor, index_pos: int) -> Tensor: + residual = x + x = self.attn_norm(x) + attn = self.forward_attn(x, mask, index_pos) + x = attn + residual + + residual = x + x = self.ffn_norm(x) + w1 = self.ffw1.matmul_t(x) + w3 = self.ffw3.matmul_t(x) + mlp = self.ffw2.matmul_t(nn.silu(w1) * w3) + + return mlp + residual + + def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int): + b_size, seq_len, n_embd = x.shape + q = self.attention_wq.matmul_t(x) + k = self.attention_wk.matmul_t(x) + v = self.attention_wv.matmul_t(x) + + q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2) + k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) + v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) + + q = self.apply_rotary_emb(q, index_pos) + k = self.apply_rotary_emb(k, index_pos) + + if self.kv_cache is not None and index_pos > 0: + prev_k, prev_v = self.kv_cache + k = candle.cat([prev_k, k], 2).contiguous() + v = candle.cat([prev_v, v], 2).contiguous() + + self.kv_cache = (k, v) + + # TODO: maybe repeat k/v here if we start supporting MQA. + + att = q.matmul(k.t()) / self.head_dim**0.5 + mask = mask.broadcast_as(att.shape) + att = masked_fill(att, mask, float("-inf")) + att = nn.softmax(att, -1) + y = att.matmul(v.contiguous()) + y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd)) + return self.attention_wo.matmul_t(y) + + def apply_rotary_emb(self, x: Tensor, index_pos: int): + b_size, n_head, seq_len, n_embd = x.shape + cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1)) + sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1)) + x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2)) + x0 = x.narrow(-1, 0, 1) + x1 = x.narrow(-1, 1, 1) + y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin) + y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos) + rope = candle.cat([y0, y1], -1) + return rope.flatten_from(-2) + + +class QuantizedLlama(Module): + def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]): + super().__init__() + self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize() + self.norm = RmsNorm(all_tensors["norm.weight"]) + self.output = all_tensors["output.weight"] + self.layers = ModuleList() + rope_freq = hparams.get("rope_freq", 10000.0) + cos_sin = precompute_freqs_cis(hparams, rope_freq, hparams["context_length"]) + for layer_idx in range(hparams["n_layer"]): + layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) + self.layers.append(layer) + + def forward(self, token: Tensor, index_pos: int) -> Tensor: + b_size, seq_len = token.shape + vocab_size, hidden_size = self.tok_embeddings.shape + token = token.reshape((b_size * seq_len,)) + x = self.tok_embeddings.index_select(token, 0) + x = x.reshape((b_size, seq_len, hidden_size)) + + mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)] + mask = candle.tensor(mask).reshape((seq_len, seq_len)) + + for layer in self.layers: + x = layer(x, mask, index_pos) + x = self.norm(x) + x = x.narrow(1, -1, 1).squeeze(1) + x = self.output.matmul_t(x) + return x diff --git a/candle-pyo3/py_src/candle/nn/__init__.py b/candle-pyo3/py_src/candle/nn/__init__.py index b8c5cfb773..8da0e8aa81 100644 --- a/candle-pyo3/py_src/candle/nn/__init__.py +++ b/candle-pyo3/py_src/candle/nn/__init__.py @@ -1,5 +1,5 @@ -# Generated content DO NOT EDIT -from .. import nn - -silu = nn.silu -softmax = nn.softmax +from .module import Module +from .container import Sequential, ModuleList, ModuleDict +from .sparse import Embedding +from .normalization import LayerNorm +from .linear import Linear diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py new file mode 100644 index 0000000000..15ed8dd236 --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/container.py @@ -0,0 +1,483 @@ +# see https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py +from .module import Module +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + overload, + Tuple, + TypeVar, + Union, +) +from collections import OrderedDict, abc as container_abcs +import operator +from itertools import chain, islice + +__all__ = ["Sequential", "ModuleList", "ModuleDict"] + +T = TypeVar("T", bound=Module) + + +def _addindent(s_: str, numSpaces: int): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +class Sequential(Module): + r"""A sequential container. + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``OrderedDict`` of modules can be + passed in. The ``forward()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. + + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :class:`candle.nn.ModuleList`? A ``ModuleList`` is exactly what it + sounds like--a list for storing ``Module`` s! On the other hand, + the layers in a ``Sequential`` are connected in a cascading way. + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + @overload + def __init__(self, *args: Module) -> None: + ... + + @overload + def __init__(self, arg: "OrderedDict[str, Module]") -> None: + ... + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator, idx) -> T: + """Get the idx-th item of the iterator""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError("index {} is out of range".format(idx)) + idx %= size + return next(islice(iterator, idx, None)) + + def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> "Sequential": + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other))) + ) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> "Sequential": + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other))) + ) + + def __mul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif other <= 0: + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> "Sequential": + return self.__mul__(other) + + def __imul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif other <= 0: + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> "Sequential": + r"""Appends a given module to the end. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> "Sequential": + if not isinstance(module, Module): + raise AssertionError("module should be of type: {}".format(Module)) + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError("Index out of range: {}".format(index)) + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential) -> "Sequential": + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~candle.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~candle.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + return str(idx) + + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> "ModuleList": + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> "ModuleList": + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self): + """A custom repr for ModuleList that compresses repeated module representations""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + "()" + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + "(" + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + r"""Insert a given module before a given index in the list. + + Args: + index (int): index to insert. + module (nn.Module): module to insert + """ + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> "ModuleList": + r"""Appends a given module to the end of the list. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> "ModuleList": + r"""Appends modules from a Python iterable to the end of the list. + + Args: + modules (iterable): iterable of modules to append + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ModuleDict(Module): + r"""Holds submodules in a dictionary. + + :class:`~candle.nn.ModuleDict` can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all + :class:`~candle.nn.Module` methods. + + :class:`~candle.nn.ModuleDict` is an **ordered** dictionary that respects + + * the order of insertion, and + + * in :meth:`~candle.nn.ModuleDict.update`, the order of the merged + ``OrderedDict``, ``dict`` (started from Python 3.6) or another + :class:`~candle.nn.ModuleDict` (the argument to + :meth:`~candle.nn.ModuleDict.update`). + + Note that :meth:`~candle.nn.ModuleDict.update` with other unordered mapping + types (e.g., Python's plain ``dict`` before Python version 3.6) does not + preserve the order of the merged mapping. + + Args: + modules (iterable, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module) + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + def items(self) -> Iterable[Tuple[str, Module]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + def values(self) -> Iterable[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~candle.nn.ModuleDict` with the key-value pairs from a + mapping or an iterable, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~candle.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~candle.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~candle.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) + if not len(m) == 2: + raise ValueError( + "ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented diff --git a/candle-pyo3/py_src/candle/nn/linear.py b/candle-pyo3/py_src/candle/nn/linear.py new file mode 100644 index 0000000000..d275eb1ec8 --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/linear.py @@ -0,0 +1,119 @@ +import math +from typing import Any + +import candle +from candle import Tensor +from .module import Module + +# See https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py + + +class Identity(Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = candle.randn(128, 20) + >>> output = m(input) + >>> print(output.shape) + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +class Linear(Module): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + Args: + in_features: size of each input sample + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input: :math:`(*, H_{in})` where :math:`*` means any number of + dimensions including none and :math:`H_{in} = \text{in\_features}`. + - Output: :math:`(*, H_{out})` where all but the last dimension + are the same shape as the input and :math:`H_{out} = \text{out\_features}`. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + # Allow 'weight' to be quantized + self._quantizable_buffers.add("weight") + + self.in_features = in_features + self.out_features = out_features + # TODO: Do actual initialization here: e.g. kaiming_uniform or xavier_uniform + self.weight = candle.ones((out_features, in_features), **factory_kwargs) + if bias: + self.bias = candle.zeros((out_features,), **factory_kwargs) + else: + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + dims = x.shape + last_dim = dims[-1] + + if isinstance(self.weight, candle.QTensor): + if len(dims) < 3: + matmul_result = self.weight.matmul_t(x).broadcast_add(self.bias) + elif len(dims) == 3: + b, n, m = dims + output_shape = (b, n, self.out_features) + re = x.reshape((b * n, m)) + matmul_result = self.weight.matmul_t(re).reshape((output_shape)) + else: + raise NotImplementedError("'QTensor.matmul_t' is not implemented for more than 3 dimensions") + + if self.bias: + return matmul_result.broadcast_add(self.bias) + else: + if self.weight.shape[-1] == last_dim and len(dims) < 3: + w = self.weight.t() + else: + batch_size = dims[0] + w = self.weight.broadcast_left((batch_size,)).t() + + x = x.matmul(w) + if self.bias is not None: + x = x.broadcast_add(self.bias) + return x + + def extra_repr(self) -> str: + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" diff --git a/candle-pyo3/py_src/candle/nn/module.py b/candle-pyo3/py_src/candle/nn/module.py new file mode 100644 index 0000000000..514d92b86e --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/module.py @@ -0,0 +1,702 @@ +from candle import Tensor, QTensor, DType +from typing import ( + Dict, + Tuple, + Any, + Optional, + Union, + Iterator, + Set, + overload, + Mapping, + TypeVar, + List, +) +from collections import OrderedDict, namedtuple + +TensorLike = Union[Tensor, QTensor] +T = TypeVar("T", bound="Module") + + +class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])): + def __repr__(self): + if not self.missing_keys and not self.unexpected_keys: + return "" + return super().__repr__() + + __str__ = __repr__ + + +# see: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py +class Module: + """ + Pytorch like Module. + + Base class for all neural network modules. + + Your models should also subclass this class. + """ + + _modules: Dict[str, Optional["Module"]] + _buffers: Dict[str, Optional[TensorLike]] + _non_persistent_buffers_set: Set[str] + _quantizable_buffers: Set[str] + _version: int = 1 + + def __init__(self, *args, **kwargs) -> None: + """ + Initializes internal Module state + """ + super().__setattr__("_modules", OrderedDict()) + super().__setattr__("_buffers", OrderedDict()) + super().__setattr__("_non_persistent_buffers_set", set()) + super().__setattr__("_quantizable_buffers", set()) + + def __call__(self, *input): + """ + Call self as a function. + """ + return self.forward(*input) + + def forward(self, *input): + """ + Defines the computation performed at every call. + Should be overridden by all subclasses. + """ + pass + + def children(self) -> Iterator["Module"]: + r"""Returns an iterator over immediate children modules. + + Yields: + Module: a child module + """ + for name, module in self.named_children(): + yield module + + def named_children(self) -> Iterator[Tuple[str, "Module"]]: + r"""Returns an iterator over immediate children modules, yielding both + the name of the module as well as the module itself. + + Yields: + (str, Module): Tuple containing a name and child module + + Example:: + + >>> for name, module in model.named_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(module) + + """ + memo = set() + for name, module in self._modules.items(): + if module is not None and module not in memo: + memo.add(module) + yield name, module + + def add_module(self, name: str, module: Optional["Module"]) -> None: + r"""Adds a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (str): name of the child module. The child module can be + accessed from this module using the given name + module (Module): child module to be added to the module. + """ + if not isinstance(module, Module) and module is not None: + raise TypeError(f"{str(module)} is not a Module subclass") + elif not isinstance(name, str): + raise TypeError(f"module name should be a string. Got {name}") + elif hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + elif "." in name: + raise KeyError(f'module name can\'t contain ".", got: {name}') + elif name == "": + raise KeyError('module name can\'t be empty string ""') + self._modules[name] = module + + def register_module(self, name: str, module: Optional["Module"]) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def modules(self) -> Iterator["Module"]: + r"""Returns an iterator over all modules in the network.""" + for _, module in self.named_modules(): + yield module + + def named_modules( + self, + memo: Optional[Set["Module"]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ): + r"""Returns an iterator over all modules in the network, yielding + both the name of the module as well as the module itself. + + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + + Yields: + (str, Module): Tuple of name and module + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + """ + + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + for name, module in self._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + for m in module.named_modules(memo, submodule_prefix, remove_duplicate): + yield m + + def buffers(self, recurse: bool = True) -> Iterator[TensorLike]: + """ + Returns an iterator over module buffers. + """ + for name, buf in self.named_buffers(recurse=recurse): + yield buf + + def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, TensorLike]]: + r"""Returns an iterator over module buffers, yielding both the + name of the buffer as well as the buffer itself. + + Args: + prefix (str): prefix to prepend to all buffer names. + recurse (bool, optional): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. Defaults to True. + remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. + + Yields: + (str, Tensor): Tuple containing the name and buffer + + Example:: + + >>> for name, buf in self.named_buffers(): + >>> if name in ['running_var']: + >>> print(buf.size()) + + """ + gen = self._named_members( + lambda module: module._buffers.items(), + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) + yield from gen + + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns + # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. + T_destination = TypeVar("T_destination", bound=Dict[str, Any]) + + @overload + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: + ... + + @overload + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: + ... + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + r"""Returns a dictionary containing references to the whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + .. note:: + The returned object is a shallow copy. It contains references + to the module's parameters and buffers. + + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. + + .. warning:: + Please avoid the use of argument ``destination`` as it is not + designed for end-users. + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~candle.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == "": + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict( + destination=destination, + prefix=prefix + name + ".", + keep_vars=keep_vars, + ) + return destination + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~candle.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + if isinstance(buf, Tensor): + destination[prefix + name] = buf if keep_vars else buf.detach() + else: + destination[prefix + name] = buf + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~candle.nn.Module.state_dict` function. + + .. warning:: + If :attr:`assign` is ``True`` the optimizer must be created after + the call to :attr:`load_state_dict`. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~candle.nn.Module.state_dict` function. Default: ``True`` + assign (bool, optional): whether to assign items in the state + dictionary to their corresponding keys in the module instead + of copying them inplace into the module's current parameters and buffers. + When ``False``, the properties of the tensors in the current + module are preserved while when ``True``, the properties of the + Tensors in the state dict are preserved. + Default: ``False`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + if not isinstance(state_dict, Mapping): + raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, local_state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata["assign_to_params_buffers"] = assign + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + "." + child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} + load(child, child_state_dict, child_prefix) + + load(self, state_dict) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, + "Unexpected key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in unexpected_keys)), + ) + if len(missing_keys) > 0: + error_msgs.insert( + 0, + "Missing key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in missing_keys)), + ) + + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs)) + ) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~candle.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + Additionally, :attr:`local_metadata` can also contain the key + `assign_to_params_buffers` that indicates whether keys should be + assigned their corresponding tensor in the state_dict. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~candle.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~candle.nn.Module.load_state_dict` + """ + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = persistent_buffers.items() + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not isinstance(input_param, (Tensor, QTensor)): + error_msgs.append( + f'While copying the parameter named "{key}", ' + "expected Tensor-like object from checkpoint but " + f"received {type(input_param)}" + ) + continue + + if input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) + continue + + try: + # Shape checks are already done above -> Just assign tensor + setattr(self, name, input_param) + except Exception as ex: + error_msgs.append( + f'While copying the parameter named "{key}", ' + f"whose dimensions in the model are {param.shape} and " + f"whose dimensions in the checkpoint are {input_param.shape}, " + f"an exception occurred : {ex.args}." + ) + elif strict: + missing_keys.append(key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix): + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True): + r"""Helper method for yielding various names + members of modules.""" + memo = set() + modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)] + for module_prefix, module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + if remove_duplicate: + memo.add(v) + name = module_prefix + ("." if module_prefix else "") + k + yield name, v + + def _get_name(self): + return self.__class__.__name__ + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + return self + + def __move_tensor_to_device(self, tensor: TensorLike, device: str): + if isinstance(tensor, Tensor): + return tensor.to_device(device) + else: + raise NotImplementedError("Cannot offload QTensor to cuda, yet!") + + def device(self) -> str: + """ + Gets the device of the module, by inspecting its tensors. + """ + tensor = next(self.buffers()) + if isinstance(tensor, Tensor): + return tensor.device + else: + # QTensors can only be on the CPU + return "cpu" + + def cuda(self: T) -> T: + r"""Moves all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + + def to_cuda(t: TensorLike): + return self.__move_tensor_to_device(t, "cuda") + + return self._apply(to_cuda) + + def cpu(self: T) -> T: + r"""Moves all model parameters and buffers to the CPU. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + + def to_cpu(t: TensorLike): + return self.__move_tensor_to_device(t, "cpu") + + return self._apply(to_cpu) + + def __cast_tensor(self, tensor: TensorLike, dtype: Union[DType, str]): + if isinstance(tensor, Tensor): + return tensor.to_dtype(dtype) + else: + raise TypeError("candle.Module.to only accepts Tensor dtypes, but got desired dtype={}".format(dtype)) + + def type(self: T, dst_type: Union[DType, str]) -> T: + r"""Casts all parameters and buffers to :attr:`dst_type`. + + .. note:: + This method modifies the module in-place. + + Args: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + + def cast(t: TensorLike): + return self.__cast_tensor(t, dst_type) + + return self._apply(cast) + + @overload + def to( + self: T, + device: str = ..., + dtype: Optional[Union[DType, str]] = ..., + ) -> T: + ... + + @overload + def to(self: T, dtype: Union[DType, str]) -> T: + ... + + def to(self, *args, **kwargs): + r"""Moves and/or casts the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None) + :noindex: + + .. function:: to(dtype) + :noindex: + + See below for examples. + + .. note:: + This method modifies the module in-place. + + Args: + device (:class:`candle.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`candle.dtype`): the desired floating point dtype of + the parameters and buffers in this module + + Returns: + Module: self + """ + + device = None + dtype = None + + if args: + for arg in args: + # Assuming arg can be a string representing a device or a dtype + + if isinstance(arg, str): + lower_arg = str(arg).lower() + if lower_arg.startswith("cuda") or lower_arg == "cpu": + device = lower_arg + else: + dtype = arg + elif isinstance(arg, DType): + dtype = str(arg) + else: + raise TypeError("Module.to() received an invalid combination of arguments. Got: {}".format(args)) + + if kwargs: + device = kwargs.get("device", device) + dtype = str(kwargs.get("dtype", dtype)) + + if device: + device = device.lower() + + if dtype: + dtype = dtype.lower() + if dtype not in ["f32", "f16", "f64"]: + raise TypeError( + "candle.Module.to only accepts floating point" "dtypes, but got desired dtype={}".format(dtype) + ) + + def convert(t): + if dtype: + t = self.__cast_tensor(t, dtype) + if device: + t = self.__move_tensor_to_device(t, device) + return t + + return self._apply(convert) + + def __setattr__(self, __name: str, __value: Any) -> None: + if isinstance(__value, Module): + self._modules[__name] = __value + elif isinstance(__value, QTensor): + if __name in self._quantizable_buffers: + type = __value.ggml_dtype.lower() + if type in ["f32", "f16"]: + # It is faster to just dequantize the tensor here and use the normal tensor operations + dequant = __value.dequantize() + if type == "f16": + dequant = dequant.to_dtype("f16") + self._buffers[__name] = dequant + else: + self._buffers[__name] = __value + else: + # We expect a normal tensor here => dequantize it + self._buffers[__name] = __value.dequantize() + elif isinstance(__value, Tensor): + self._buffers[__name] = __value + else: + super().__setattr__(__name, __value) + + def __getattr__(self, __name: str) -> Any: + if "_modules" in self.__dict__: + modules = self.__dict__["_modules"] + if __name in modules: + return modules[__name] + if "_buffers" in self.__dict__: + tensors = self.__dict__["_buffers"] + if __name in tensors: + return tensors[__name] + return super().__getattribute__(__name) + + def __delattr__(self, name): + if name in self._buffers: + del self._buffers[name] + elif name in self._modules: + del self._modules[name] + else: + super().__delattr__(name) diff --git a/candle-pyo3/py_src/candle/nn/normalization.py b/candle-pyo3/py_src/candle/nn/normalization.py new file mode 100644 index 0000000000..67510a24bb --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/normalization.py @@ -0,0 +1,54 @@ +import candle +from candle import Tensor +from .module import Module +from typing import Union, List, Tuple, Optional, Any + +_shape_t = Union[int, List[int]] +import numbers + + +class LayerNorm(Module): + r"""Applies Layer Normalization over a mini-batch of inputs as described in + the paper `Layer Normalization ` + + math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + """ + __constants__ = ["normalized_shape", "eps"] + normalized_shape: Tuple[int, ...] + eps: float + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + + self.weight = candle.ones(normalized_shape, **factory_kwargs) + if bias: + self.bias = candle.zeros(normalized_shape, **factory_kwargs) + else: + self.bias = None + + def forward(self, input: Tensor) -> Tensor: + mean_x = input.sum_keepdim(2) / float(self.normalized_shape[-1]) + x = input.broadcast_sub(mean_x) + norm_x = x.sqr().sum_keepdim(2) / float(self.normalized_shape[-1]) + x_normed = x.broadcast_div((norm_x + self.eps).sqrt()) + x = x_normed.broadcast_mul(self.weight) + + if self.bias: + x = x.broadcast_add(self.bias) + return x + + def extra_repr(self) -> str: + return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) diff --git a/candle-pyo3/py_src/candle/nn/sparse.py b/candle-pyo3/py_src/candle/nn/sparse.py new file mode 100644 index 0000000000..386f80817d --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/sparse.py @@ -0,0 +1,39 @@ +from .module import Module +from typing import Optional, Tuple, Any +from candle import Tensor +import candle + + +class Embedding(Module): + """A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None: + factory_kwargs = {"device": device} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs) + + def forward(self, indexes: Tensor) -> Tensor: + final_dims = list(indexes.shape) + final_dims.append(self.embedding_dim) + indexes = indexes.flatten_all() + values = self.weight.index_select(indexes, 0) + return values.reshape(final_dims) diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py index ea85d2a36a..ccdb623850 100644 --- a/candle-pyo3/py_src/candle/typing/__init__.py +++ b/candle-pyo3/py_src/candle/typing/__init__.py @@ -2,7 +2,7 @@ _T = TypeVar("_T") -_ArrayLike = Union[ +_ArrayLike = Union[ _T, Sequence[_T], Sequence[Sequence[_T]], @@ -10,7 +10,7 @@ Sequence[Sequence[Sequence[Sequence[_T]]]], ] -CPU:str = "cpu" -CUDA:str = "cuda" +CPU: str = "cpu" +CUDA: str = "cuda" -Device = TypeVar("Device", CPU, CUDA) \ No newline at end of file +Device = TypeVar("Device", CPU, CUDA) diff --git a/candle-pyo3/pyproject.toml b/candle-pyo3/pyproject.toml index 88793493b1..e375796c63 100644 --- a/candle-pyo3/pyproject.toml +++ b/candle-pyo3/pyproject.toml @@ -28,3 +28,7 @@ features = ["pyo3/extension-module"] [tool.black] line-length = 119 target-version = ['py35'] + +[project.optional-dependencies] +testing = ["pytest", "black==22.3"] +huggingface = ["transformers>=4.33.3", "huggingface-hub>=0.17.3"] \ No newline at end of file diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index 46d9ff62dd..1cb39e4ff2 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -2,181 +2,59 @@ import sys from typing import Dict, Tuple, Any import candle -from candle import Tensor, QTensor, utils, nn +from candle.models.llama import QuantizedLlama +from candle import utils MAX_SEQ_LEN = 4096 -def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor): - shape = mask.shape - on_true = candle.tensor(on_true).broadcast_as(shape) - return mask.where_cond(on_true, on_false) -class RmsNorm: - def __init__(self, qtensor:QTensor): - self.weight = qtensor.dequantize() - - def __call__(self, x:Tensor): - b_size, seq_len, hidden_size = x.shape - norm_x = x.sqr().sum_keepdim(2) / hidden_size - x_normed = x.broadcast_div((norm_x + 1e-5).sqrt()) - return x_normed.broadcast_mul(self.weight) - -class QuantizedLayer: - def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]): - p = f"layers.{layer_idx}" - self.attention_wq = all_tensors[f"{p}.attention.wq.weight"] - self.attention_wk = all_tensors[f"{p}.attention.wk.weight"] - self.attention_wv = all_tensors[f"{p}.attention.wv.weight"] - self.attention_wo = all_tensors[f"{p}.attention.wo.weight"] - self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"] - self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"] - self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"] - self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"]) - self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"]) - - self.n_head = hparams["n_head"] - self.n_kv_head = self.n_head - self.head_dim = hparams["n_embd"] // self.n_head - - self.kv_cache = None - self.cos = cos_sin[0] - self.sin = cos_sin[1] - - def __call__(self, x:Tensor, mask:Tensor, index_pos:int): - residual = x - x = self.attn_norm(x) - attn = self.forward_attn(x, mask, index_pos) - x = attn + residual - - residual = x - x = self.ffn_norm(x) - w1 = self.ffw1.matmul_t(x) - w3 = self.ffw3.matmul_t(x) - mlp = self.ffw2.matmul_t(nn.silu(w1) * w3) - - return mlp + residual - - def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int): - b_size, seq_len, n_embd = x.shape - q = self.attention_wq.matmul_t(x) - k = self.attention_wk.matmul_t(x) - v = self.attention_wv.matmul_t(x) - - q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2) - k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) - v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) - - q = self.apply_rotary_emb(q, index_pos) - k = self.apply_rotary_emb(k, index_pos) - - if self.kv_cache is not None and index_pos > 0: - prev_k, prev_v = self.kv_cache - k = candle.cat([prev_k, k], 2).contiguous() - v = candle.cat([prev_v, v], 2).contiguous() - - self.kv_cache = (k, v) - - # TODO: maybe repeat k/v here if we start supporting MQA. - - att = q.matmul(k.t()) / self.head_dim**0.5 - mask = mask.broadcast_as(att.shape) - att = masked_fill(att, mask, float("-inf")) - att = nn.softmax(att, -1) - y = att.matmul(v.contiguous()) - y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd)) - return self.attention_wo.matmul_t(y) - - def apply_rotary_emb(self, x:Tensor, index_pos:int): - (b_size, n_head, seq_len, n_embd) = x.shape - cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) - sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) - x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2)) - x0 = x.narrow(-1, 0, 1) - x1 = x.narrow(-1, 1, 1) - y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin) - y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos) - rope = candle.cat([y0, y1], -1) - return rope.flatten_from(-2) - -def precompute_freqs_cis(hparams, freq_base): - head_dim = hparams["n_embd"] // hparams["n_head"] - theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)] - theta = candle.tensor(theta) - idx_theta = [float(i) for i in range(MAX_SEQ_LEN)] - idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1)) - m = idx_theta.matmul(theta.unsqueeze(0)) - return (m.cos(), m.sin()) - -class QuantizedLlama: - def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]): - self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize() - self.norm = RmsNorm(all_tensors["norm.weight"]) - self.output = all_tensors["output.weight"] - self.layers = [] - rope_freq = hparams.get("rope_freq", 10000.) - cos_sin = precompute_freqs_cis(hparams, rope_freq) - for layer_idx in range(hparams["n_layer"]): - layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) - self.layers.append(layer) - - def __call__(self, token:Tensor, index_pos:int): - b_size, seq_len = token.shape - vocab_size, hidden_size = self.tok_embeddings.shape - token = token.reshape((b_size * seq_len,)) - x = self.tok_embeddings.index_select(token, 0) - x = x.reshape((b_size, seq_len, hidden_size)) - - mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)] - mask = candle.tensor(mask).reshape((seq_len, seq_len)) - - for layer in self.layers: - x = layer(x, mask, index_pos) - x = self.norm(x) - x = x.narrow(1, -1, 1).squeeze(1) - x = self.output.matmul_t(x) - return x - -def gguf_rename(tensor_name:str): - if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight' - if tensor_name == 'output_norm.weight': return 'norm.weight' - tensor_name = tensor_name.replace('blk.', 'layers.') - tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.') - tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.') - tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.') - tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.') - tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.') - tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.') - tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.') - tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.') +def gguf_rename(tensor_name: str): + if tensor_name == "token_embd.weight": + return "tok_embeddings.weight" + if tensor_name == "output_norm.weight": + return "norm.weight" + tensor_name = tensor_name.replace("blk.", "layers.") + tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.") + tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.") + tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.") + tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.") + tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.") + tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.") + tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.") + tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.") return tensor_name + def main(): if len(sys.argv) < 2: raise ValueError("missing weight file argument") + filename = sys.argv[1] print(f"reading model file {filename}") if filename.endswith("gguf"): - all_tensors, metadata = utils.load_gguf(sys.argv[1]) + all_tensors, metadata = utils.load_gguf(filename) vocab = metadata["tokenizer.ggml.tokens"] for i, v in enumerate(vocab): - vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ') + vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ") hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")} print(hparams) hparams = { - 'n_vocab': len(vocab), - 'n_embd': metadata['llama.embedding_length'], - 'n_mult': 256, - 'n_head': metadata['llama.attention.head_count'], - 'n_head_kv': metadata['llama.attention.head_count_kv'], - 'n_layer': metadata['llama.block_count'], - 'n_rot': metadata['llama.rope.dimension_count'], - 'rope_freq': metadata.get('llama.rope.freq_base', 10000.), - 'ftype': metadata['general.file_type'], + "n_vocab": len(vocab), + "n_embd": metadata["llama.embedding_length"], + "n_mult": 256, + "n_head": metadata["llama.attention.head_count"], + "n_head_kv": metadata["llama.attention.head_count_kv"], + "n_layer": metadata["llama.block_count"], + "n_rot": metadata["llama.rope.dimension_count"], + "rope_freq": metadata.get("llama.rope.freq_base", 10000.0), + "ftype": metadata["general.file_type"], + "context_length": metadata["llama.context_length"], } - all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() } - + all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()} else: - all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1]) + all_tensors, hparams, vocab = utils.load_ggml(filename) + hparams["context_length"] = 2048 + print(hparams) model = QuantizedLlama(hparams, all_tensors) print("model built, starting inference") @@ -185,13 +63,14 @@ def main(): for token_idx in range(500): last_token = tokens[-1] lt = candle.tensor([last_token]).unsqueeze(0) - logits = model(lt, len(tokens)) + logits = model.forward(lt, len(tokens)) # Greedy sampling for now # pr = candle.nn.softmax(logits, -1) m = logits.get(0).argmax_keepdim(-1) next_token = m.values()[0] - print(vocab[next_token], end='', flush=True) + print(vocab[next_token], end="", flush=True) tokens.append(next_token) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 64b6dd2c78..4d4b520006 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; +use std::os::raw::c_long; use std::sync::Arc; use half::{bf16, f16}; @@ -196,6 +197,12 @@ trait MapDType { } } +enum Indexer { + Index(usize), + Slice(usize, usize), + Elipsis, +} + #[pymethods] impl PyTensor { #[new] @@ -436,6 +443,95 @@ impl PyTensor { )) } + #[getter] + /// Index a tensor. + /// &RETURNS&: Tensor + fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult { + let mut indexers: Vec = vec![]; + let dims = self.0.shape().dims(); + + let to_absolute_index = |index: isize, current_dim: usize| { + // Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0] + let actual_index = if index < 0 { + dims[current_dim] as isize + index + } else { + index + }; + + // Check that the index is in range + if actual_index < 0 || actual_index >= dims[current_dim] as isize { + return Err(PyTypeError::new_err(format!( + "index out of range for dimension '{i}' with indexer '{value}'", + i = current_dim, + value = index + ))); + } + Ok(actual_index as usize) + }; + if let Ok(index) = idx.extract(py) { + // Handle a single index e.g. tensor[0] or tensor[-1] + indexers.push(Indexer::Index(to_absolute_index(index, 0)?)); + } else if let Ok(slice) = idx.downcast::(py) { + // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] + let index = slice.indices(dims[0] as c_long)?; + indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); + } else if let Ok(tuple) = idx.downcast::(py) { + // Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1] + + if tuple.len() > dims.len() { + return Err(PyTypeError::new_err("provided too many indices")); + } + + for (i, item) in tuple.iter().enumerate() { + if item.is_ellipsis() { + // Handle '...' e.g. tensor[..., 0] + + if i > 0 { + return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation")); + } + indexers.push(Indexer::Elipsis); + } else if let Ok(slice) = item.downcast::() { + // Handle slice + let index = slice.indices(dims[i] as c_long)?; + indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); + } else if let Ok(index) = item.extract::() { + indexers.push(Indexer::Index(to_absolute_index(index, i)?)); + } else { + return Err(PyTypeError::new_err("unsupported index")); + } + } + } else { + return Err(PyTypeError::new_err("unsupported index")); + } + + let mut x = self.0.clone(); + let mut current_dim = 0; + // Apply the indexers + for indexer in indexers.iter() { + x = match indexer { + Indexer::Index(n) => x + .narrow(current_dim, *n, 1) + .map_err(wrap_err)? + .squeeze(current_dim) + .map_err(wrap_err)?, + Indexer::Slice(start, stop) => { + let out = x + .narrow(current_dim, *start, stop.saturating_sub(*start)) + .map_err(wrap_err)?; + current_dim += 1; + out + } + Indexer::Elipsis => { + // Elipsis is a special case, it means that all remaining dimensions should be selected => advance the current_dim to the last dimension we have indexers for + current_dim += dims.len() - (indexers.len() - 1); + x + } + } + } + + Ok(Self(x)) + } + /// Add two tensors. /// &RETURNS&: Tensor fn __add__(&self, rhs: &PyAny) -> PyResult { @@ -697,7 +793,7 @@ impl PyTensor { /// &RETURNS&: QTensor fn quantize(&self, quantized_dtype: &str) -> PyResult { use ::candle::quantized; - let res = match quantized_dtype { + let res = match quantized_dtype.to_lowercase().as_str() { "q2k" => quantized::QTensor::quantize::(self), "q3k" => quantized::QTensor::quantize::(self), "q4_0" => quantized::QTensor::quantize::(self), @@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult { Ok(PyTensor(s)) } -fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +#[pyfunction] +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor. +/// &RETURNS&: Tensor +fn gelu(tensor: PyTensor) -> PyResult { + let s = tensor.0.gelu_erf().map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +#[pyfunction] +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the Rectified Linear Unit (ReLU) function to a given tensor. +/// &RETURNS&: Tensor +fn relu(tensor: PyTensor) -> PyResult { + let s = tensor.0.relu().map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +#[pyfunction] +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the tanh function to a given tensor. +/// &RETURNS&: Tensor +fn tanh(tensor: PyTensor) -> PyResult { + let s = tensor.0.tanh().map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(silu, m)?)?; m.add_function(wrap_pyfunction!(softmax, m)?)?; + m.add_function(wrap_pyfunction!(gelu, m)?)?; + m.add_function(wrap_pyfunction!(relu, m)?)?; + m.add_function(wrap_pyfunction!(tanh, m)?)?; Ok(()) } @@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; candle_utils(py, utils)?; m.add_submodule(utils)?; - let nn = PyModule::new(py, "nn")?; - candle_nn_m(py, nn)?; + let nn = PyModule::new(py, "functional")?; + candle_functional_m(py, nn)?; m.add_submodule(nn)?; m.add_class::()?; m.add_class::()?; diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index 149715c275..3100a10c10 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -1,4 +1,4 @@ -#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py +# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py import argparse import inspect import os @@ -23,7 +23,7 @@ def do_indent(text: Optional[str], indent: str): return text.replace("\n", f"\n{indent}") -def function(obj, indent:str, text_signature:str=None): +def function(obj, indent: str, text_signature: str = None): if text_signature is None: text_signature = obj.__text_signature__ @@ -32,12 +32,12 @@ def function(obj, indent:str, text_signature:str=None): if doc_string is None: doc_string = "" - # Check if we have a return type annotation in the docstring + # Check if we have a return type annotation in the docstring return_type = None doc_lines = doc_string.split("\n") if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER): # Extract the return type and remove it from the docstring - return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip() + return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER) :].strip() doc_string = "\n".join(doc_lines[:-1]) string = "" @@ -115,7 +115,7 @@ def pyi_file(obj, indent=""): body += f"{indent+INDENT}pass\n" body += "\n" - for (name, fn) in fns: + for name, fn in fns: body += pyi_file(fn, indent=indent) if not body: @@ -221,12 +221,12 @@ def write(module, directory, origin, check=False): args = parser.parse_args() - #Enable execution from the candle and candle-pyo3 directories + # Enable execution from the candle and candle-pyo3 directories cwd = Path.cwd() directory = "py_src/candle/" if cwd.name != "candle-pyo3": directory = f"candle-pyo3/{directory}" - + import candle write(candle.candle, directory, "candle", check=args.check) diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 7f24b49d7e..a56ed22c3c 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -7,7 +7,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) -print(t+t) +print(t + t) t = t.reshape([2, 4]) print(t.matmul(t.t())) diff --git a/candle-pyo3/tests/__init__.py b/candle-pyo3/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/candle-pyo3/tests/bindings/test_linear.py b/candle-pyo3/tests/bindings/test_linear.py new file mode 100644 index 0000000000..5936aac469 --- /dev/null +++ b/candle-pyo3/tests/bindings/test_linear.py @@ -0,0 +1,38 @@ +import candle +from candle import Tensor +from candle.nn import Linear + + +def test_linear_layer_can_be_constructed(): + linear = Linear(10, 10) + assert linear is not None + + +def test_linear_layer_can_forward_a_singular_input(): + linear = Linear(384, 1536) + input_tensor = candle.randn((8, 384)) + output = linear.forward(input_tensor) + assert output.shape == (8, 1536) + + +def test_linear_layer_can_forward_a_batched_input(): + linear = Linear(384, 1536) + input_tensor = candle.randn((16, 8, 384)) + output = linear.forward(input_tensor) + assert output.shape == (16, 8, 1536) + + +def test_quantized_linear_layer_can_forward_a_singular_input(): + linear = Linear(384, 1536) + linear.weight = linear.weight.quantize("q4_0") + input_tensor = candle.randn((8, 384)) + output = linear.forward(input_tensor) + assert output.shape == (8, 1536) + + +def test_quantized_linear_layer_can_forward_a_batched_input(): + linear = Linear(384, 1536) + linear.weight = linear.weight.quantize("q4_0") + input_tensor = candle.randn((16, 8, 384)) + output = linear.forward(input_tensor) + assert output.shape == (16, 8, 1536) diff --git a/candle-pyo3/tests/bindings/test_module.py b/candle-pyo3/tests/bindings/test_module.py new file mode 100644 index 0000000000..819dae5be1 --- /dev/null +++ b/candle-pyo3/tests/bindings/test_module.py @@ -0,0 +1,161 @@ +import candle +from candle import Tensor, QTensor +from candle.nn import Module, Linear +from candle.utils import cuda_is_available + +import pytest + + +def test_module_can_be_constructed(): + class A(Module): + pass + + a = A() + assert a is not None + assert len(list(a.buffers())) == 0 + + +def test_module_registers_tensors(): + class A(Module): + def __init__(self): + super().__init__() + self.t = Tensor(42.0) + + a = A() + named_buffers = dict(a.named_buffers()) + assert len(named_buffers) == 1 + assert "t" in named_buffers + + +def test_module_registers_submodules(): + class A(Module): + def __init__(self): + super().__init__() + self.linear = Linear(10, 20) + + a = A() + named_modules = dict(a.named_modules()) + named_buffers = dict(a.named_buffers()) + assert len(named_buffers) == 2 + assert "linear" in named_modules + assert "linear.weight" in named_buffers + assert "linear.bias" in named_buffers + + +def test_module_can_dump_statedict(): + class A(Module): + def __init__(self): + super().__init__() + self.linear = Linear(10, 20) + self.t = Tensor(42.0) + + a = A() + state_dict = a.state_dict() + assert hasattr(state_dict, "_metadata") + assert "t" in state_dict + assert "linear.weight" in state_dict + assert "linear.bias" in state_dict + assert len(state_dict) == 3 + + +def test_module_can_load_statedict(): + class A(Module): + def __init__(self): + super().__init__() + self.linear = Linear(10, 20) + self.t = Tensor(42.0) + + statedict = { + "linear.weight": candle.ones((20, 10)), + "linear.bias": candle.zeros((20,)), + "t": Tensor(42.0), + } + a = A() + a.load_state_dict(statedict) + + +def test_module_throws_on_shape_missmatch(): + class A(Module): + def __init__(self): + super().__init__() + self.t = Tensor(42.0) + + statedict = { + "t": candle.ones((20,)), + } + a = A() + with pytest.raises(RuntimeError) as excinfo: + a.load_state_dict(statedict) + assert "size mismatch" in str(excinfo.value) + + +def test_module_throws_on_missing_key(): + class A(Module): + def __init__(self): + super().__init__() + self.t = Tensor(42.0) + + statedict = { + "not_t": Tensor(42.0), + } + + a = A() + with pytest.raises(RuntimeError) as excinfo: + a.load_state_dict(statedict) + assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value) + + +def test_module_can_load_quantized_tensors(): + class A(Module): + def __init__(self): + super().__init__() + self.t = candle.randn((16, 256)) + self._quantizable_buffers.add("t") + + statedict = { + "t": candle.ones((16, 256)).quantize("q4_0"), + } + a = A() + a.load_state_dict(statedict) + assert isinstance(a.t, QTensor) + assert a.t.ggml_dtype == "Q4_0" + + +def test_module_dequantizes_tensors_automaticaly(): + class A(Module): + def __init__(self): + super().__init__() + self.t = candle.randn((16, 256)) + + statedict = { + "t": candle.ones((16, 256)).quantize("q4_0"), + } + a = A() + a.load_state_dict(statedict) + assert isinstance(a.t, Tensor) + + +@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available") +def test_module_can_be_moved_to_cuda(): + class A(Module): + def __init__(self): + super().__init__() + self.t = candle.randn((16, 256)) + + a = A() + a.cuda() + assert a.t.device == "cuda" + + +@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available") +def test_module_can_be_moved_from_cuda_to_cpu(): + class A(Module): + def __init__(self): + super().__init__() + self.t = candle.randn((16, 256)) + + a = A() + a.cuda() + assert a.t.device == "cuda" + a.cpu() + assert a.t.device == "cpu" diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py new file mode 100644 index 0000000000..1f5b74f677 --- /dev/null +++ b/candle-pyo3/tests/native/test_tensor.py @@ -0,0 +1,74 @@ +import candle +from candle import Tensor + + +def test_tensor_can_be_constructed(): + t = Tensor(42.0) + assert t.values() == 42.0 + + +def test_tensor_can_be_constructed_from_list(): + t = Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) + assert t.values() == [3.0, 1, 4, 1, 5, 9, 2, 6] + + +def test_tensor_can_be_constructed_from_list_of_lists(): + t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]]) + assert t.values() == [[3.0, 1, 4, 1], [5, 9, 2, 6]] + + +def test_tensor_can_be_quantized(): + t = candle.randn((16, 256)) + for format in [ + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2k", + "q3k", + "q4k", + "q5k", + "q8k", + ]: + for formatted_format in [format.upper(), format.lower()]: + quant_t = t.quantize(formatted_format) + assert quant_t.ggml_dtype.lower() == format.lower() + assert quant_t.shape == t.shape + + +def test_tensor_can_be_indexed(): + t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]]) + assert t[0].values() == [3.0, 1.0, 4.0, 1.0] + assert t[1].values() == [5.0, 9.0, 2.0, 6.0] + assert t[-1].values() == [5.0, 9.0, 2.0, 6.0] + assert t[-2].values() == [3.0, 1.0, 4.0, 1.0] + + +def test_tensor_can_be_sliced(): + t = Tensor([3.0, 1, 4, 10, 5, 9, 2, 6]) + + assert t[0:4].values() == [3.0, 1.0, 4.0, 10.0] + assert t[4:8].values() == [5.0, 9.0, 2.0, 6.0] + assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0] + assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0] + assert t[-4:-2].values() == [5.0, 9.0] + + +def test_tensor_can_be_sliced_2d(): + t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]]) + assert t[:, 0].values() == [3.0, 5] + assert t[:, 1].values() == [1.0, 9.0] + assert t[0, 0].values() == 3.0 + assert t[:, -1].values() == [1.0, 6.0] + assert t[:, -4].values() == [3.0, 5] + assert t[..., 0].values() == [3.0, 5] + + +def test_tensor_can_be_scliced_3d(): + t = Tensor([[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]]) + assert t[:, :, 0].values() == [[1, 5], [9, 13]] + assert t[:, :, 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] + assert t[:, 0, 0].values() == [1, 9] + assert t[..., 0].values() == [[1, 5], [9, 13]] + assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] diff --git a/candle-pyo3/tests/native/test_utils.py b/candle-pyo3/tests/native/test_utils.py new file mode 100644 index 0000000000..f5f5312250 --- /dev/null +++ b/candle-pyo3/tests/native/test_utils.py @@ -0,0 +1,51 @@ +import candle +from candle import Tensor, QTensor +from candle.utils import load_safetensors, save_gguf, load_gguf, save_safetensors +from pathlib import Path + +TEST_DIR = Path(__file__).parent.parent / "_workdir" +TEST_DIR.mkdir(exist_ok=True) + + +def test_can_roundtrip_safetensors(): + tensors = { + "a": candle.randn((16, 256)), + "b": candle.randn((16, 16)), + } + + file = str(TEST_DIR / "test.safetensors") + save_safetensors(file, tensors) + loaded_tensors = load_safetensors(file) + assert set(tensors.keys()) == set(loaded_tensors.keys()) + for key in tensors.keys(): + assert tensors[key].values() == loaded_tensors[key].values(), "Values are not equal" + assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal" + assert str(tensors[key].dtype) == str(loaded_tensors[key].dtype), "Dtypes are not equal" + + +def test_can_roundtrip_gguf(): + metadata = { + "a": 1, + "b": "foo", + "c": [1, 2, 3], + "d": [[1, 2], [3, 4]], + } + + tensors = { + "a": candle.randn((16, 256)).quantize("q4_0"), + "b": candle.randn((16, 16)).quantize("f32"), + } + + file = str(TEST_DIR / "test.gguf") + save_gguf(file, tensors, metadata) + loaded_tensors, loaded_metadata = load_gguf(file) + + assert set(metadata.keys()) == set(loaded_metadata.keys()) + for key in metadata.keys(): + assert metadata[key] == loaded_metadata[key] + + assert set(tensors.keys()) == set(loaded_tensors.keys()) + for key in tensors.keys(): + assert tensors[key].dequantize().values() == loaded_tensors[key].dequantize().values(), "Values are not equal" + assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal" + assert str(tensors[key].ggml_dtype) == str(loaded_tensors[key].ggml_dtype), "Dtypes are not equal"