Skip to content

Commit

Permalink
feat: Time per epoch and ETA logging when silent=false (#64)
Browse files Browse the repository at this point in the history
* add time logger

* return vec on fit

* use Set

* use logistic reg
  • Loading branch information
retraigo authored Sep 20, 2024
1 parent a3d719e commit 1d5a750
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 24 deletions.
43 changes: 37 additions & 6 deletions crates/core/src/cpu/backend.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::time::Instant;

use ndarray::{ArrayD, ArrayViewD, IxDyn};
use safetensors::{serialize, SafeTensors};
Expand Down Expand Up @@ -110,7 +111,10 @@ impl Backend {
match layers {
Some(layer_indices) => {
for layer_index in layer_indices {
let layer = self.layers.get_mut(layer_index).expect(&format!("Layer #{} does not exist.", layer_index));
let layer = self
.layers
.get_mut(layer_index)
.expect(&format!("Layer #{} does not exist.", layer_index));
inputs = layer.forward_propagate(inputs, training);
}
}
Expand Down Expand Up @@ -141,6 +145,10 @@ impl Backend {
let mut disappointments = 0;
let mut best_net = self.save();
let mut cost = 0f32;
let mut time: u128;
let mut total_time = 0u128;
let start = Instant::now();
let total_iter = epochs * datasets.len();
while epoch < epochs {
let mut total = 0.0;
for (i, dataset) in datasets.iter().enumerate() {
Expand All @@ -152,7 +160,19 @@ impl Backend {
let minibatch = outputs.dim()[0];
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
cost = total / (batches) as f32;
let msg = format!("Epoch={}, Dataset={}, Cost={}", epoch, i * minibatch, cost);
time = start.elapsed().as_millis() - total_time;
total_time += time;
let current_iter = epoch * datasets.len() + i;
let msg = format!(
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
epoch,
i * minibatch,
cost,
(time as f32) / 1000.0,
(((total_time as f32) / current_iter as f32)
* (total_iter - current_iter) as f32)
/ 1000.0
);
(self.logger.log)(msg);
total = 0.0;
}
Expand All @@ -165,17 +185,28 @@ impl Backend {
disappointments = 0;
best_cost = cost;
best_net = self.save();
} else {
} else {
disappointments += 1;
if !self.silent {
println!("Patience counter: {} disappointing epochs out of {}.", disappointments, self.patience);
println!(
"Patience counter: {} disappointing epochs out of {}.",
disappointments, self.patience
);
}
}
if disappointments >= self.patience {
if !self.silent {
println!("No improvement for {} epochs. Stopping early at cost={}", disappointments, best_cost);
println!(
"No improvement for {} epochs. Stopping early at cost={}",
disappointments, best_cost
);
}
let net = Self::load(&best_net, Logger { log: |x| println!("{}", x) });
let net = Self::load(
&best_net,
Logger {
log: |x| println!("{}", x),
},
);
self.layers = net.layers;
break;
}
Expand Down
4 changes: 4 additions & 0 deletions deno.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 13 additions & 16 deletions examples/classification/spam.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import {
// Import helpers for metrics
import {
ClassificationReport,
CountVectorizer,
SplitTokenizer,
TfIdfTransformer,
TextCleaner,
TextVectorizer,
// Split the dataset
useSplit,
} from "../../packages/utilities/mod.ts";
import { SigmoidLayer } from "../../mod.ts";

// Define classes
const ymap = ["spam", "ham"];
Expand All @@ -32,25 +32,21 @@ const data = parse(_data);
const x = data.map((msg) => msg[1]);

// Get the classes
const y = data.map((msg) => (ymap.indexOf(msg[0]) === 0 ? -1 : 1));
const y = data.map((msg) => (ymap.indexOf(msg[0]) === 0 ? 0 : 1));

// Split the dataset for training and testing
const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y);

// Vectorize the text messages

const tokenizer = new SplitTokenizer({
skipWords: "english",
standardize: { lowercase: true },
}).fit(train[0]);
const textCleaner = new TextCleaner({ lowercase: true });

const vec = new CountVectorizer(tokenizer.vocabulary.size);
train[0] = textCleaner.clean(train[0])

const x_vec = vec.transform(tokenizer.transform(train[0]), "f32")
const vec = new TextVectorizer("tfidf").fit(train[0]);

const tfidf = new TfIdfTransformer();
const x_vec = vec.transform(train[0], "f32");

const x_tfidf = tfidf.fit(x_vec).transform(x_vec)

// Setup the CPU backend for Netsaur
await setupBackend(CPU);
Expand All @@ -73,14 +69,15 @@ const net = new Sequential({
// A dense layer with 1 neuron
DenseLayer({ size: [1] }),
// A sigmoid activation layer
SigmoidLayer()
],

// We are using Log Loss for finding cost
cost: Cost.Hinge,
cost: Cost.BinCrossEntropy,
optimizer: NadamOptimizer(),
});

const inputs = tensor(x_tfidf);
const inputs = tensor(x_vec);

const time = performance.now();
// Train the network
Expand All @@ -99,10 +96,10 @@ net.train(

console.log(`training time: ${performance.now() - time}ms`);

const x_vec_test = tfidf.transform(vec.transform(tokenizer.transform(test[0]), "f32"));
const x_vec_test = vec.transform(test[0], "f32");

// Calculate metrics
const res = await net.predict(tensor(x_vec_test));
const y1 = res.data.map((i) => (i < 0 ? -1 : 1));
const y1 = res.data.map((i) => (i < 0.5 ? 0 : 1));
const cMatrix = new ClassificationReport(test[1], y1);
console.log("Confusion Matrix: ", cMatrix);
3 changes: 2 additions & 1 deletion packages/utilities/src/text/vectorizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export class TextVectorizer {
this.mode = mode;
this.mapper = new DiscreteMapper();
}
fit(document: string | string[]) {
fit(document: string | string[]): TextVectorizer {
this.mapper.fit(
(Array.isArray(document) ? document.join(" ") : document).split(" ")
);
Expand All @@ -27,6 +27,7 @@ export class TextVectorizer {
this.transformer.fit(this.encoder.transform(tokens, "f32"));
}
}
return this;
}
transform<DT extends DataType>(
document: string | string[],
Expand Down
3 changes: 2 additions & 1 deletion packages/utilities/src/utils/array/unique.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
*/
/**
 * Returns the unique elements of an array-like collection, preserving
 * first-occurrence order (guaranteed by Set insertion-order iteration).
 * @param arr Array-like collection that may contain duplicates.
 * @returns A new array containing each distinct element exactly once.
 */
export function useUnique<T>(arr: ArrayLike<T>): T[] {
  // Set dedupes by SameValueZero equality; spreading it back out keeps
  // the order in which elements were first seen. O(n) vs. the old
  // indexOf-based filter which was O(n^2).
  return [...new Set(Array.from(arr))];
}

0 comments on commit 1d5a750

Please sign in to comment.