forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make the Python Wrapper more Hackable and simplify Quantization (hugg…
…ingface#1010) * Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation
- Loading branch information
Showing
25 changed files
with
2,426 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"[python]": { | ||
"editor.defaultFormatter": "ms-python.black-formatter" | ||
}, | ||
"python.formatting.provider": "none", | ||
"python.testing.pytestArgs": [ | ||
"candle-pyo3" | ||
], | ||
"python.testing.unittestEnabled": false, | ||
"python.testing.pytestEnabled": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
tests/_workdir | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from candle.utils import load_safetensors, save_gguf, load_gguf | ||
from candle.models.bert import BertModel, Config | ||
import json | ||
from candle import Tensor | ||
from tqdm import tqdm | ||
from dataclasses import fields | ||
import os | ||
import time | ||
|
||
from huggingface_hub import hf_hub_download | ||
from transformers import BertTokenizer, AutoModel | ||
import torch | ||
|
||
if __name__ == "__main__": | ||
model_name = "intfloat/e5-small-v2" | ||
model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors") | ||
config_file = hf_hub_download(repo_id=model_name, filename="config.json") | ||
|
||
tensors = load_safetensors(model_file) | ||
config = Config() | ||
with open(config_file, "r") as f: | ||
raw_config = json.load(f) | ||
for field in fields(config): | ||
if field.name in raw_config: | ||
setattr(config, field.name, raw_config[field.name]) | ||
|
||
# Load the model | ||
model = BertModel(config) | ||
model.load_state_dict(tensors) | ||
|
||
hf_model = AutoModel.from_pretrained(model_name) | ||
tokenizer = BertTokenizer.from_pretrained(model_name) | ||
|
||
sentences = [ | ||
"The cat sits outside", | ||
"A man is playing guitar", | ||
"I love pasta", | ||
"The new movie is awesome", | ||
"The cat plays in the garden", | ||
"A woman watches TV", | ||
"The new movie is so great", | ||
"Do you like pizza?", | ||
] | ||
|
||
def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor): | ||
"""Average the hidden states according to the attention mask""" | ||
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) | ||
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | ||
|
||
tokenized = tokenizer(sentences, padding=True) | ||
tokens = Tensor(tokenized["input_ids"]) | ||
token_type_ids = Tensor(tokenized["token_type_ids"]) | ||
encoder_out, _ = model.forward(tokens, token_type_ids) | ||
|
||
hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt") | ||
hf_result = hf_model(**hf_tokenized)["last_hidden_state"] | ||
|
||
hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"]) | ||
candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"]) | ||
|
||
loss = torch.nn.L1Loss() | ||
error = loss(hf_pooled, candle_pooled).mean().item() | ||
print(f"Mean error between torch-referenze and candle: {error}") | ||
|
||
# Quantize all attention 'weights' | ||
quantized_tensors = {} | ||
for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"): | ||
if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name): | ||
# check if the tensor is k-quantizable | ||
if tensor.shape[-1] % 256 == 0: | ||
new_tensor = tensor.quantize("q4k") | ||
else: | ||
new_tensor = tensor.quantize("q5_0") | ||
quantized_tensors[name] = new_tensor | ||
else: | ||
quantized_tensors[name] = tensor.quantize("q8_0") | ||
|
||
print(f"Saving quantized tensors") | ||
# Remove all None values from the config | ||
config_to_save = {k: v for k, v in config.__dict__.items() if v is not None} | ||
# Save the model | ||
quantized_model_file = "e5_small.gguf" | ||
save_gguf(quantized_model_file, quantized_tensors, config_to_save) | ||
|
||
file_size_mb = os.path.getsize(model_file) / 1024 / 1024 | ||
file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024 | ||
print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB") | ||
# Load the model from the gguf | ||
tensors, raw_config = load_gguf(quantized_model_file) | ||
config = Config() | ||
for field in fields(config): | ||
if field.name in raw_config: | ||
setattr(config, field.name, raw_config[field.name]) | ||
model = BertModel(config) | ||
# "embeddings.position_ids" is missing in the gguf as it is i64 | ||
model.load_state_dict(tensors, strict=False) | ||
|
||
# Run the model again | ||
encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids) | ||
encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu") | ||
|
||
candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"]) | ||
error = loss(hf_pooled, candle_pooled_2).mean().item() | ||
print(f"Mean error between torch-referenze and quantized-candle: {error}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,30 @@ | ||
from .candle import * | ||
import logging | ||
|
||
try: | ||
from .candle import * | ||
except ImportError as e: | ||
# If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here | ||
logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...") | ||
import os | ||
import platform | ||
|
||
# Try to locate CUDA_PATH environment variable | ||
cuda_path = os.environ.get("CUDA_PATH", None) | ||
if cuda_path: | ||
logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}") | ||
if platform.system() == "Windows": | ||
cuda_path = os.path.join(cuda_path, "bin") | ||
else: | ||
cuda_path = os.path.join(cuda_path, "lib64") | ||
|
||
logging.warning(f"Adding {cuda_path} to DLL search path...") | ||
os.add_dll_directory(cuda_path) | ||
|
||
try: | ||
from .candle import * | ||
except ImportError as inner_e: | ||
raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.") | ||
|
||
__doc__ = candle.__doc__ | ||
if hasattr(candle, "__all__"): | ||
__all__ = candle.__all__ | ||
__all__ = candle.__all__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Generated content DO NOT EDIT | ||
from .. import functional | ||
|
||
gelu = functional.gelu | ||
relu = functional.relu | ||
silu = functional.silu | ||
softmax = functional.softmax | ||
tanh = functional.tanh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList | ||
from candle import Tensor | ||
import candle | ||
import candle.functional as F | ||
from typing import Tuple, Optional | ||
|
||
|
||
@dataclass | ||
class Config: | ||
vocab_size: int = 30522 | ||
hidden_size: int = 768 | ||
num_hidden_layers: int = 12 | ||
num_attention_heads: int = 12 | ||
intermediate_size: int = 3072 | ||
hidden_act: str = "gelu" | ||
hidden_dropout_prob: float = 0.1 | ||
max_position_embeddings: int = 512 | ||
type_vocab_size: int = 2 | ||
initializer_range: float = 0.02 | ||
layer_norm_eps: float = 1e-12 | ||
pad_token_id: int = 0 | ||
position_embedding_type: str = "absolute" | ||
use_cache: bool = True | ||
classifier_dropout: Optional[float] = None | ||
model_type: Optional[str] = "bert" | ||
|
||
|
||
class BertSelfAttention(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.num_attention_heads = config.num_attention_heads | ||
self.attention_head_size = int(config.hidden_size / self.num_attention_heads) | ||
all_head_size = int(config.num_attention_heads * self.attention_head_size) | ||
hidden_size = config.hidden_size | ||
self.query = Linear(hidden_size, all_head_size) | ||
self.key = Linear(hidden_size, all_head_size) | ||
self.value = Linear(hidden_size, all_head_size) | ||
|
||
def transpose_for_scores(self, x: Tensor) -> Tensor: | ||
new_x_shape = x.shape[:-1] + ( | ||
self.num_attention_heads, | ||
self.attention_head_size, | ||
) | ||
x = x.reshape(new_x_shape).transpose(1, 2) | ||
return x.contiguous() | ||
|
||
def forward(self, hidden_states: Tensor) -> Tensor: | ||
query = self.query.forward(hidden_states) | ||
key = self.key.forward(hidden_states) | ||
value = self.value.forward(hidden_states) | ||
|
||
query = self.transpose_for_scores(query) | ||
key = self.transpose_for_scores(key) | ||
value = self.transpose_for_scores(value) | ||
|
||
attention_scores = query.matmul(key.t()) | ||
attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5) | ||
attention_probs = F.softmax(attention_scores, dim=-1) | ||
|
||
context_layer = attention_probs.matmul(value) | ||
context_layer = context_layer.transpose(1, 2).contiguous() | ||
context_layer = context_layer.flatten_from(-2) | ||
return context_layer | ||
|
||
|
||
class BertSelfOutput(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.dense = Linear(config.hidden_size, config.hidden_size) | ||
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||
|
||
def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: | ||
hidden_states = self.dense.forward(hidden_states) | ||
return self.LayerNorm.forward(hidden_states + input_tensor) | ||
|
||
|
||
class BertAttention(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.self = BertSelfAttention(config) | ||
self.output = BertSelfOutput(config) | ||
|
||
def forward(self, hidden_states: Tensor) -> Tensor: | ||
self_outputs = self.self.forward(hidden_states) | ||
attention_output = self.output.forward(self_outputs, hidden_states) | ||
return attention_output | ||
|
||
|
||
class BertIntermediate(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.dense = Linear(config.hidden_size, config.intermediate_size) | ||
self.act = F.gelu if config.hidden_act == "gelu" else F.relu | ||
|
||
def forward(self, hidden_states: Tensor) -> Tensor: | ||
hidden_states = self.dense.forward(hidden_states) | ||
return self.act(hidden_states) | ||
|
||
|
||
class BertOutput(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.dense = Linear(config.intermediate_size, config.hidden_size) | ||
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||
|
||
def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: | ||
hidden_states = self.dense.forward(hidden_states) | ||
return self.LayerNorm.forward(hidden_states + input_tensor) | ||
|
||
|
||
class BertLayer(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.attention = BertAttention(config) | ||
self.intermediate = BertIntermediate(config) | ||
self.output = BertOutput(config) | ||
|
||
def forward(self, hidden_states: Tensor) -> Tensor: | ||
attention_output = self.attention.forward(hidden_states) | ||
# TODO: Support cross-attention? | ||
# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 | ||
# TODO: Support something similar to `apply_chunking_to_forward`? | ||
intermediate_output = self.intermediate.forward(attention_output) | ||
layer_output = self.output.forward(intermediate_output, attention_output) | ||
return layer_output | ||
|
||
|
||
class BertEncoder(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.layer = ModuleList() | ||
for _ in range(config.num_hidden_layers): | ||
self.layer.append(BertLayer(config)) | ||
|
||
def forward(self, hidden_states: Tensor) -> Tensor: | ||
for l in self.layer: | ||
hidden_states = l.forward(hidden_states) | ||
return hidden_states | ||
|
||
|
||
class BertEmbeddings(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) | ||
self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size) | ||
self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size) | ||
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||
self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape( | ||
(1, config.max_position_embeddings) | ||
) | ||
|
||
def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor: | ||
(_batch_size, seq_len) = input_ids.shape | ||
input_embeddings = self.word_embeddings.forward(input_ids) | ||
token_type_embeddings = self.token_type_embeddings.forward(token_type_ids) | ||
embeddings: Tensor = input_embeddings + token_type_embeddings | ||
|
||
position_ids = list(range(seq_len)) | ||
position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device) | ||
|
||
embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids)) | ||
embeddings = self.LayerNorm(embeddings) | ||
return embeddings | ||
|
||
|
||
class BertPooler(Module): | ||
def __init__(self, config: Config) -> None: | ||
super().__init__() | ||
self.dense = Linear(config.hidden_size, config.hidden_size) | ||
self.activation = F.tanh | ||
|
||
def forward(self, hidden_states: Tensor) -> Tensor: | ||
first_token_tensor = hidden_states[:, 0] | ||
pooled_output = self.dense.forward(first_token_tensor) | ||
pooled_output = self.activation(pooled_output) | ||
return pooled_output | ||
|
||
|
||
# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 | ||
class BertModel(Module): | ||
def __init__(self, config: Config, add_pooling_layer=True) -> None: | ||
super().__init__() | ||
self.config = config | ||
self.embeddings = BertEmbeddings(config) | ||
self.encoder = BertEncoder(config) | ||
self.pooler = BertPooler(config) if add_pooling_layer else None | ||
|
||
def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]: | ||
embeddings = self.embeddings.forward(input_ids, token_type_ids) | ||
encoder_out = self.encoder.forward(embeddings) | ||
pooled_output = self.pooler(encoder_out) if self.pooler is not None else None | ||
return encoder_out, pooled_output |
Oops, something went wrong.