Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
add bert model
Browse files Browse the repository at this point in the history
Co-authored-by: Lukas Kreussel <[email protected]>
Co-authored-by: Philpax <[email protected]>
  • Loading branch information
3 people committed Aug 7, 2023
1 parent c3b868a commit 7a10cfb
Show file tree
Hide file tree
Showing 6 changed files with 522 additions and 1 deletion.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ impl Context {
pub fn storage(&self) -> &ContextStorage {
self.storage.as_ref().unwrap()
}

/// Set all values of the tensor with the specified value.
pub fn set_f32(&self, a: &Tensor, x: f32) -> Tensor {
let raw = unsafe { sys::ggml_set_f32(a.ptr.as_ptr(), x) };
self.new_tensor_raw(raw)
}
}
// Operations
impl Context {
Expand Down Expand Up @@ -598,6 +604,30 @@ impl Context {
};
self.new_tensor_raw(tensor)
}

/// Creates a new tensor with the square of `a`
pub fn op_sqr(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_sqr(self.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Creates a new tensor with the square-root of `a`
pub fn op_sqrt(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_sqrt(self.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Unknown
pub fn op_sum(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_sum(self.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Unknown
pub fn op_div(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_div(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}
}
// Public to this crate methods
impl Context {
Expand Down
4 changes: 3 additions & 1 deletion crates/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" }
llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" }
llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" }
llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" }
llm-bert = { path = "../models/bert", optional = true, version = "0.2.0-dev" }

serde = { workspace = true }
tracing = { workspace = true }
Expand All @@ -34,13 +35,14 @@ default = ["models", "tokenizers-remote"]

tokenizers-remote = ["llm-base/tokenizers-remote"]

models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt", "bert"]
llama = ["dep:llm-llama"]
gpt2 = ["dep:llm-gpt2"]
gptj = ["dep:llm-gptj"]
bloom = ["dep:llm-bloom"]
gptneox = ["dep:llm-gptneox"]
mpt = ["dep:llm-mpt"]
bert = ["dep:llm-bert"]
# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
falcon = ["dep:llm-falcon"]

Expand Down
1 change: 1 addition & 0 deletions crates/llm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ macro_rules! define_models {
}

define_models!(
(bert, "bert", Bert, llm_bert, "Bert"),
(bloom, "bloom", Bloom, llm_bloom, "BLOOM"),
(gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"),
(gptj, "gptj", GptJ, llm_gptj, "GPT-J"),
Expand Down
14 changes: 14 additions & 0 deletions crates/models/bert/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "llm-bert"
version = "0.2.0-dev"
license = { workspace = true }
repository = { workspace = true }
description = "An implementation of BERT for the `llm` ecosystem."
edition = "2021"
readme = "../../../README.md"

[dependencies]
bytemuck.workspace = true
llm-base = { path = "../../llm-base", version = "0.2.0-dev" }
tracing = { version = "0.1", features = ["log"] }

Loading

0 comments on commit 7a10cfb

Please sign in to comment.