Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Time per epoch and ETA logging when silent=false #64

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions crates/core/src/cpu/backend.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::time::Instant;

use ndarray::{ArrayD, ArrayViewD, IxDyn};
use safetensors::{serialize, SafeTensors};
Expand Down Expand Up @@ -110,7 +111,10 @@ impl Backend {
match layers {
Some(layer_indices) => {
for layer_index in layer_indices {
let layer = self.layers.get_mut(layer_index).expect(&format!("Layer #{} does not exist.", layer_index));
let layer = self
.layers
.get_mut(layer_index)
.expect(&format!("Layer #{} does not exist.", layer_index));
inputs = layer.forward_propagate(inputs, training);
}
}
Expand Down Expand Up @@ -141,6 +145,10 @@ impl Backend {
let mut disappointments = 0;
let mut best_net = self.save();
let mut cost = 0f32;
let mut time: u128;
let mut total_time = 0u128;
let start = Instant::now();
let total_iter = epochs * datasets.len();
while epoch < epochs {
let mut total = 0.0;
for (i, dataset) in datasets.iter().enumerate() {
Expand All @@ -152,7 +160,19 @@ impl Backend {
let minibatch = outputs.dim()[0];
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
cost = total / (batches) as f32;
let msg = format!("Epoch={}, Dataset={}, Cost={}", epoch, i * minibatch, cost);
time = start.elapsed().as_millis() - total_time;
total_time += time;
let current_iter = epoch * datasets.len() + i;
let msg = format!(
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
epoch,
i * minibatch,
cost,
(time as f32) / 1000.0,
(((total_time as f32) / current_iter as f32)
* (total_iter - current_iter) as f32)
/ 1000.0
);
(self.logger.log)(msg);
total = 0.0;
}
Expand All @@ -165,17 +185,28 @@ impl Backend {
disappointments = 0;
best_cost = cost;
best_net = self.save();
} else {
} else {
disappointments += 1;
if !self.silent {
println!("Patience counter: {} disappointing epochs out of {}.", disappointments, self.patience);
println!(
"Patience counter: {} disappointing epochs out of {}.",
disappointments, self.patience
);
}
}
if disappointments >= self.patience {
if !self.silent {
println!("No improvement for {} epochs. Stopping early at cost={}", disappointments, best_cost);
println!(
"No improvement for {} epochs. Stopping early at cost={}",
disappointments, best_cost
);
}
let net = Self::load(&best_net, Logger { log: |x| println!("{}", x) });
let net = Self::load(
&best_net,
Logger {
log: |x| println!("{}", x),
},
);
self.layers = net.layers;
break;
}
Expand Down
4 changes: 4 additions & 0 deletions deno.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 13 additions & 16 deletions examples/classification/spam.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import {
// Import helpers for metrics
import {
ClassificationReport,
CountVectorizer,
SplitTokenizer,
TfIdfTransformer,
TextCleaner,
TextVectorizer,
// Split the dataset
useSplit,
} from "../../packages/utilities/mod.ts";
import { SigmoidLayer } from "../../mod.ts";

// Define classes
const ymap = ["spam", "ham"];
Expand All @@ -32,25 +32,21 @@ const data = parse(_data);
const x = data.map((msg) => msg[1]);

// Get the classes
const y = data.map((msg) => (ymap.indexOf(msg[0]) === 0 ? -1 : 1));
const y = data.map((msg) => (ymap.indexOf(msg[0]) === 0 ? 0 : 1));

// Split the dataset for training and testing
const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y);

// Vectorize the text messages

const tokenizer = new SplitTokenizer({
skipWords: "english",
standardize: { lowercase: true },
}).fit(train[0]);
const textCleaner = new TextCleaner({ lowercase: true });

const vec = new CountVectorizer(tokenizer.vocabulary.size);
train[0] = textCleaner.clean(train[0])

const x_vec = vec.transform(tokenizer.transform(train[0]), "f32")
const vec = new TextVectorizer("tfidf").fit(train[0]);

const tfidf = new TfIdfTransformer();
const x_vec = vec.transform(train[0], "f32");

const x_tfidf = tfidf.fit(x_vec).transform(x_vec)

// Setup the CPU backend for Netsaur
await setupBackend(CPU);
Expand All @@ -73,14 +69,15 @@ const net = new Sequential({
// A dense layer with 1 neuron
DenseLayer({ size: [1] }),
// A sigmoid activation layer
SigmoidLayer()
],

// We are using Log Loss for finding cost
cost: Cost.Hinge,
cost: Cost.BinCrossEntropy,
optimizer: NadamOptimizer(),
});

const inputs = tensor(x_tfidf);
const inputs = tensor(x_vec);

const time = performance.now();
// Train the network
Expand All @@ -99,10 +96,10 @@ net.train(

console.log(`training time: ${performance.now() - time}ms`);

const x_vec_test = tfidf.transform(vec.transform(tokenizer.transform(test[0]), "f32"));
const x_vec_test = vec.transform(test[0], "f32");

// Calculate metrics
const res = await net.predict(tensor(x_vec_test));
const y1 = res.data.map((i) => (i < 0 ? -1 : 1));
const y1 = res.data.map((i) => (i < 0.5 ? 0 : 1));
const cMatrix = new ClassificationReport(test[1], y1);
console.log("Confusion Matrix: ", cMatrix);
3 changes: 2 additions & 1 deletion packages/utilities/src/text/vectorizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export class TextVectorizer {
this.mode = mode;
this.mapper = new DiscreteMapper();
}
fit(document: string | string[]) {
fit(document: string | string[]): TextVectorizer {
this.mapper.fit(
(Array.isArray(document) ? document.join(" ") : document).split(" ")
);
Expand All @@ -27,6 +27,7 @@ export class TextVectorizer {
this.transformer.fit(this.encoder.transform(tokens, "f32"));
}
}
return this;
}
transform<DT extends DataType>(
document: string | string[],
Expand Down
3 changes: 2 additions & 1 deletion packages/utilities/src/utils/array/unique.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
*/
/**
 * Returns the unique elements of an array-like value, preserving
 * first-occurrence order.
 *
 * Uses a `Set` for O(n) deduplication (the previous `filter` +
 * `indexOf` approach was O(n²)).
 *
 * @param arr Array-like input (e.g. an array, `arguments`, a typed array).
 * @returns A new array containing each distinct element exactly once.
 */
export function useUnique<T>(arr: ArrayLike<T>): T[] {
  // Array.from handles non-iterable ArrayLike inputs before Set dedupes.
  return [...new Set(Array.from(arr))];
}
Loading