Make the Python Wrapper more Hackable and simplify Quantization (huggingface#1010)

* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality (see the usage sketch after this list)

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
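
Taken together, these bullets add a small PyTorch-flavored layer on top of the bindings. A minimal, untested sketch of how the pieces might combine (the layer size, state-dict key, and f32 cast are illustrative assumptions, not taken from the diff):

from candle import Tensor
from candle.nn import Linear

linear = Linear(768, 768)                  # hypothetical layer size
x = Tensor([[0.0] * 768]).to_dtype("f32")  # assumed cast so the input dtype matches the weights
y = linear.forward(x)

row = y[:, 0]                              # pytorch-like slicing (new in this PR)

weights = linear.state_dict()              # state_dict round-trip (new in this PR)
linear.load_state_dict(weights)

# Quantization: QTensors are dequantized automatically where a Tensor is expected.
q = weights["weight"].quantize("q4k")      # 768 % 256 == 0, so a k-quant applies
# Device/dtype moves: linear.cuda(), linear.cpu(), linear.to(...), linear.type(...)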
LLukas22 authored Oct 6, 2023
1 parent b0442ef commit 904bbda
Showing 25 changed files with 2,426 additions and 182 deletions.
11 changes: 11 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,11 @@
{
  "[python]": {
    "editor.defaultFormatter": "ms-python.black-formatter"
  },
  "python.formatting.provider": "none",
  "python.testing.pytestArgs": [
    "candle-pyo3"
  ],
  "python.testing.unittestEnabled": false,
  "python.testing.pytestEnabled": true
}
1 change: 1 addition & 0 deletions candle-pyo3/.gitignore
@@ -1,3 +1,4 @@
tests/_workdir
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
104 changes: 104 additions & 0 deletions candle-pyo3/e5.py
@@ -0,0 +1,104 @@
from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os
import time

from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch

if __name__ == "__main__":
    model_name = "intfloat/e5-small-v2"
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
    config_file = hf_hub_download(repo_id=model_name, filename="config.json")

    tensors = load_safetensors(model_file)
    config = Config()
    with open(config_file, "r") as f:
        raw_config = json.load(f)
        for field in fields(config):
            if field.name in raw_config:
                setattr(config, field.name, raw_config[field.name])

    # Load the model
    model = BertModel(config)
    model.load_state_dict(tensors)

    hf_model = AutoModel.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)

    sentences = [
        "The cat sits outside",
        "A man is playing guitar",
        "I love pasta",
        "The new movie is awesome",
        "The cat plays in the garden",
        "A woman watches TV",
        "The new movie is so great",
        "Do you like pizza?",
    ]

    def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """Average the hidden states according to the attention mask"""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    tokenized = tokenizer(sentences, padding=True)
    tokens = Tensor(tokenized["input_ids"])
    token_type_ids = Tensor(tokenized["token_type_ids"])
    encoder_out, _ = model.forward(tokens, token_type_ids)

    hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
    hf_result = hf_model(**hf_tokenized)["last_hidden_state"]

    hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
    candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])

    loss = torch.nn.L1Loss()
    error = loss(hf_pooled, candle_pooled).mean().item()
    print(f"Mean error between torch-reference and candle: {error}")

    # Quantize all attention 'weights'
    quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"):
        if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
            # check if the tensor is k-quantizable
            if tensor.shape[-1] % 256 == 0:
                new_tensor = tensor.quantize("q4k")
            else:
                new_tensor = tensor.quantize("q5_0")
            quantized_tensors[name] = new_tensor
        else:
            quantized_tensors[name] = tensor.quantize("q8_0")

    print("Saving quantized tensors")
    # Remove all None values from the config
    config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
    # Save the model
    quantized_model_file = "e5_small.gguf"
    save_gguf(quantized_model_file, quantized_tensors, config_to_save)

    file_size_mb = os.path.getsize(model_file) / 1024 / 1024
    file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
    print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")

    # Load the model from the gguf
    tensors, raw_config = load_gguf(quantized_model_file)
    config = Config()
    for field in fields(config):
        if field.name in raw_config:
            setattr(config, field.name, raw_config[field.name])
    model = BertModel(config)
    # "embeddings.position_ids" is missing in the gguf as it is i64
    model.load_state_dict(tensors, strict=False)

    # Run the model again
    encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids)
    encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")

    candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
    error = loss(hf_pooled, candle_pooled_2).mean().item()
    print(f"Mean error between torch-reference and quantized-candle: {error}")
29 changes: 27 additions & 2 deletions candle-pyo3/py_src/candle/__init__.py
@@ -1,5 +1,30 @@
import logging

try:
    from .candle import *
except ImportError as e:
    # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
    logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
    import os
    import platform

    # Try to locate CUDA_PATH environment variable
    cuda_path = os.environ.get("CUDA_PATH", None)
    if cuda_path:
        logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
        if platform.system() == "Windows":
            cuda_path = os.path.join(cuda_path, "bin")
        else:
            cuda_path = os.path.join(cuda_path, "lib64")

        logging.warning(f"Adding {cuda_path} to DLL search path...")
        os.add_dll_directory(cuda_path)

    try:
        from .candle import *
    except ImportError as inner_e:
        raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")

__doc__ = candle.__doc__
if hasattr(candle, "__all__"):
    __all__ = candle.__all__
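
For users who hit this fallback, the practical fix is usually to expose a CUDA installation through `CUDA_PATH` before the first import. A hypothetical Windows session (the install path is an assumption, adjust it to your machine):

import os

# Assumed install location; must contain bin/ (Windows) or lib64/ (Linux).
os.environ["CUDA_PATH"] = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2"

import candle  # the fallback above now finds and registers the DLL directory
print(candle.Tensor([1.0, 2.0]))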
8 changes: 8 additions & 0 deletions candle-pyo3/py_src/candle/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Generated content DO NOT EDIT
from .. import functional

gelu = functional.gelu
relu = functional.relu
silu = functional.silu
softmax = functional.softmax
tanh = functional.tanh
@@ -4,6 +4,20 @@ from os import PathLike
from candle.typing import _ArrayLike, Device
from candle import Tensor, DType, QTensor

@staticmethod
def gelu(tensor: Tensor) -> Tensor:
    """
    Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
    """
    pass

@staticmethod
def relu(tensor: Tensor) -> Tensor:
    """
    Applies the Rectified Linear Unit (ReLU) function to a given tensor.
    """
    pass

@staticmethod
def silu(tensor: Tensor) -> Tensor:
    """
@@ -17,3 +31,10 @@ def softmax(tensor: Tensor, dim: int) -> Tensor:
    Applies the Softmax function to a given tensor.
    """
    pass

@staticmethod
def tanh(tensor: Tensor) -> Tensor:
    """
    Applies the tanh function to a given tensor.
    """
    pass
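
A quick smoke test of the ops declared above, as they might be called from Python (values are illustrative):

from candle import Tensor
import candle.functional as F

t = Tensor([[1.0, 2.0, 3.0]])
probs = F.softmax(t, dim=-1)  # each row should sum to 1
activated = F.gelu(t)         # GELU applied elementwise
print(probs, activated)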
194 changes: 194 additions & 0 deletions candle-pyo3/py_src/candle/models/bert.py
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
from candle import Tensor
import candle
import candle.functional as F
from typing import Tuple, Optional


@dataclass
class Config:
    vocab_size: int = 30522
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    max_position_embeddings: int = 512
    type_vocab_size: int = 2
    initializer_range: float = 0.02
    layer_norm_eps: float = 1e-12
    pad_token_id: int = 0
    position_embedding_type: str = "absolute"
    use_cache: bool = True
    classifier_dropout: Optional[float] = None
    model_type: Optional[str] = "bert"


class BertSelfAttention(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        all_head_size = int(config.num_attention_heads * self.attention_head_size)
        hidden_size = config.hidden_size
        self.query = Linear(hidden_size, all_head_size)
        self.key = Linear(hidden_size, all_head_size)
        self.value = Linear(hidden_size, all_head_size)

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.reshape(new_x_shape).transpose(1, 2)
        return x.contiguous()

    def forward(self, hidden_states: Tensor) -> Tensor:
        query = self.query.forward(hidden_states)
        key = self.key.forward(hidden_states)
        value = self.value.forward(hidden_states)

        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        attention_scores = query.matmul(key.t())
        attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
        attention_probs = F.softmax(attention_scores, dim=-1)

        context_layer = attention_probs.matmul(value)
        context_layer = context_layer.transpose(1, 2).contiguous()
        context_layer = context_layer.flatten_from(-2)
        return context_layer


class BertSelfOutput(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        hidden_states = self.dense.forward(hidden_states)
        return self.LayerNorm.forward(hidden_states + input_tensor)


class BertAttention(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, hidden_states: Tensor) -> Tensor:
        self_outputs = self.self.forward(hidden_states)
        attention_output = self.output.forward(self_outputs, hidden_states)
        return attention_output


class BertIntermediate(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        self.act = F.gelu if config.hidden_act == "gelu" else F.relu

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.dense.forward(hidden_states)
        return self.act(hidden_states)


class BertOutput(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        hidden_states = self.dense.forward(hidden_states)
        return self.LayerNorm.forward(hidden_states + input_tensor)


class BertLayer(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states: Tensor) -> Tensor:
        attention_output = self.attention.forward(hidden_states)
        # TODO: Support cross-attention?
        # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        # TODO: Support something similar to `apply_chunking_to_forward`?
        intermediate_output = self.intermediate.forward(attention_output)
        layer_output = self.output.forward(intermediate_output, attention_output)
        return layer_output


class BertEncoder(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.layer = ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layer.append(BertLayer(config))

    def forward(self, hidden_states: Tensor) -> Tensor:
        for l in self.layer:
            hidden_states = l.forward(hidden_states)
        return hidden_states


class BertEmbeddings(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape(
            (1, config.max_position_embeddings)
        )

    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
        (_batch_size, seq_len) = input_ids.shape
        input_embeddings = self.word_embeddings.forward(input_ids)
        token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)
        embeddings: Tensor = input_embeddings + token_type_embeddings

        position_ids = list(range(seq_len))
        position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device)

        embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids))
        embeddings = self.LayerNorm(embeddings)
        return embeddings


class BertPooler(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.activation = F.tanh

    def forward(self, hidden_states: Tensor) -> Tensor:
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense.forward(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
class BertModel(Module):
    def __init__(self, config: Config, add_pooling_layer=True) -> None:
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        embeddings = self.embeddings.forward(input_ids, token_type_ids)
        encoder_out = self.encoder.forward(embeddings)
        pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
        return encoder_out, pooled_output
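
To see the whole file in motion, a minimal smoke test with a deliberately tiny config and dummy token ids (all values are illustrative assumptions; the freshly constructed layers are randomly initialized, so only the output shapes are meaningful):

from candle import Tensor
from candle.models.bert import BertModel, Config

config = Config(vocab_size=128, hidden_size=64, num_hidden_layers=2,
                num_attention_heads=4, intermediate_size=128)
model = BertModel(config)

input_ids = Tensor([[7, 11, 13, 17]])  # (batch=1, seq=4) dummy ids, all < vocab_size
token_type_ids = Tensor([[0, 0, 0, 0]])
encoder_out, pooled = model.forward(input_ids, token_type_ids)
print(encoder_out.shape)  # expected (1, 4, 64)
print(pooled.shape)       # expected (1, 64)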