Skip to content

Commit

Permalink
new german and english models, version models
Browse files Browse the repository at this point in the history
  • Loading branch information
bminixhofer committed Oct 20, 2020
1 parent 758bef0 commit cc42ba6
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 19 deletions.
Binary file modified models/de/model.onnx
Binary file not shown.
Binary file modified models/en/model.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion nnsplit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ thiserror = "1.0"
lazy_static = "1.4"
serde = "1.0"
serde_derive = "1.0"
tract-onnx = { version = "0.10.0", optional = true }
tract-onnx = { version = "0.11.1", optional = true }
directories = {version = "3.0.1", optional = true}
minreq = {version = "2.2.1", features = ["https"], optional = true}
url = {version = "2.1.1", optional = true}
Expand Down
14 changes: 7 additions & 7 deletions nnsplit/models.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
de,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/de/
en,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/en/
tr,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/tr/
fr,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/fr/
no,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/no/
sv,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/sv/
zh,https://raw.githubusercontent.com/bminixhofer/nnsplit/master/models/zh/
de,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/de/model.onnx
en,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/en/model.onnx
tr,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/tr/model.onnx
fr,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/fr/model.onnx
no,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/no/model.onnx
sv,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/sv/model.onnx
zh,https://github.com/bminixhofer/nnsplit/raw/0.5.0/models/zh/model.onnx
36 changes: 25 additions & 11 deletions nnsplit/src/tract_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ use tract_onnx::prelude::*;
struct TractBackend {
model: TypedModel,
n_outputs: usize,
length_divisor: usize,
}

impl TractBackend {
fn new(model: TypedModel) -> TractResult<Self> {
let n_outputs = if let TDim::Val(value) = model.outlet_fact(model.outputs[0])?.shape.dim(2)
{
fn new(model: TypedModel, length_divisor: usize) -> TractResult<Self> {
let n_outputs = if let TDim::Val(value) = model.outlet_fact(model.outputs[0])?.shape[2] {
value as usize
} else {
0 // TODO: raise error here
};

Ok(TractBackend { model, n_outputs })
Ok(TractBackend {
model,
n_outputs,
length_divisor,
})
}

fn predict(
Expand All @@ -28,7 +32,10 @@ impl TractBackend {
let input_shape = input.shape();
let opt_model = self
.model
.concretize_stream_dim(input_shape[1])?
.concretize_dims(&SymbolValues::default().with(
's'.into(),
input_shape[1] as i64 / self.length_divisor as i64,
))?
.optimize()?
.into_runnable()?;

Expand Down Expand Up @@ -58,11 +65,14 @@ pub struct NNSplit {
}

impl NNSplit {
fn type_model(model: InferenceModel) -> TractResult<TypedModel> {
fn type_model(model: InferenceModel, length_divisor: usize) -> TractResult<TypedModel> {
model
.with_input_fact(
0,
InferenceFact::dt_shape(u8::datum_type(), tvec!(1.into(), TDim::s())),
InferenceFact::dt_shape(
u8::datum_type(),
tvec!(1.into(), TDim::from('s') * length_divisor),
),
)?
.into_typed()?
.declutter()
Expand All @@ -75,8 +85,9 @@ impl NNSplit {
model_path: P,
options: NNSplitOptions,
) -> Result<Self, Box<dyn Error>> {
let model = NNSplit::type_model(onnx().model_for_path(model_path)?)?;
let backend = TractBackend::new(model)?;
let model =
NNSplit::type_model(onnx().model_for_path(model_path)?, options.length_divisor)?;
let backend = TractBackend::new(model, options.length_divisor)?;

Ok(NNSplit {
backend,
Expand All @@ -88,9 +99,12 @@ impl NNSplit {
#[cfg(feature = "model-loader")]
pub fn load(model_name: &str, options: NNSplitOptions) -> Result<Self, Box<dyn Error>> {
let mut model_data = crate::model_loader::get_resource(model_name, "model.onnx")?.0;
let model = NNSplit::type_model(onnx().model_for_read(&mut model_data)?)?;
let model = NNSplit::type_model(
onnx().model_for_read(&mut model_data)?,
options.length_divisor,
)?;

let backend = TractBackend::new(model)?;
let backend = TractBackend::new(model, options.length_divisor)?;

Ok(NNSplit {
backend,
Expand Down
2 changes: 2 additions & 0 deletions update_version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ update_cargo_toml_version $1 nnsplit/Cargo.toml
update_cargo_toml_version $1 bindings/python/Cargo.toml
update_cargo_toml_version $1-post0 bindings/python/Cargo.build.toml
npm version $1 --prefix bindings/javascript --allow-same-version

$SED -i "s/[0-9]\.[0-9]\.[0-9]/$1/" nnsplit/models.csv

0 comments on commit cc42ba6

Please sign in to comment.