Skip to content

Commit

Permalink
feat: Add postprocessing (and fix WASM) (#65)
Browse files Browse the repository at this point in the history
* add postprocessing for sign and step

* fix wasm
  • Loading branch information
retraigo authored Sep 25, 2024
1 parent 1d5a750 commit a811fd1
Show file tree
Hide file tree
Showing 22 changed files with 273 additions and 93 deletions.
2 changes: 1 addition & 1 deletion crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ safetensors = { workspace = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen = "0.2.92"
getrandom = { version = "0.2", features = ["js"] }
js-sys = "0.3.69"
js-sys = "0.3.69"
54 changes: 31 additions & 23 deletions crates/core/src/cpu/backend.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use std::collections::HashMap;
use std::time::Instant;

use ndarray::{ArrayD, ArrayViewD, IxDyn};
use safetensors::{serialize, SafeTensors};

use crate::{
to_arr, ActivationCPULayer, BackendConfig, BatchNorm1DCPULayer, BatchNorm2DCPULayer,
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUScheduler, Conv2DCPULayer, ConvTensors,
ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors, Dropout1DCPULayer,
Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger, Pool2DCPULayer, SoftmaxCPULayer,
Tensor, Tensors,
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUPostProcessor, CPUScheduler,
Conv2DCPULayer, ConvTensors, ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors,
Dropout1DCPULayer, Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger,
Pool2DCPULayer, PostProcessor, SoftmaxCPULayer, Tensor, Tensors, Timer,
};

pub struct Backend {
Expand All @@ -23,10 +22,16 @@ pub struct Backend {
pub optimizer: CPUOptimizer,
pub scheduler: CPUScheduler,
pub logger: Logger,
pub timer: Timer,
}

impl Backend {
pub fn new(config: BackendConfig, logger: Logger, mut tensors: Option<Vec<Tensors>>) -> Self {
pub fn new(
config: BackendConfig,
logger: Logger,
timer: Timer,
mut tensors: Option<Vec<Tensors>>,
) -> Self {
let mut layers = Vec::new();
let mut size = config.size.clone();
for layer in config.layers.iter() {
Expand Down Expand Up @@ -99,6 +104,7 @@ impl Backend {
optimizer,
scheduler,
size,
timer,
}
}

Expand Down Expand Up @@ -147,7 +153,7 @@ impl Backend {
let mut cost = 0f32;
let mut time: u128;
let mut total_time = 0u128;
let start = Instant::now();
let start = (self.timer.now)();
let total_iter = epochs * datasets.len();
while epoch < epochs {
let mut total = 0.0;
Expand All @@ -160,11 +166,11 @@ impl Backend {
let minibatch = outputs.dim()[0];
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
cost = total / (batches) as f32;
time = start.elapsed().as_millis() - total_time;
time = ((self.timer.now)() - start) - total_time;
total_time += time;
let current_iter = epoch * datasets.len() + i;
let msg = format!(
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
"Epoch={}, Dataset={}, Cost={}, Time={:.3}s, ETA={:.3}s",
epoch,
i * minibatch,
cost,
Expand All @@ -188,25 +194,20 @@ impl Backend {
} else {
disappointments += 1;
if !self.silent {
println!(
(self.logger.log)(format!(
"Patience counter: {} disappointing epochs out of {}.",
disappointments, self.patience
);
));
}
}
if disappointments >= self.patience {
if !self.silent {
println!(
(self.logger.log)(format!(
"No improvement for {} epochs. Stopping early at cost={}",
disappointments, best_cost
);
));
}
let net = Self::load(
&best_net,
Logger {
log: |x| println!("{}", x),
},
);
let net = Self::load(&best_net, self.logger.clone(), self.timer.clone());
self.layers = net.layers;
break;
}
Expand All @@ -215,11 +216,18 @@ impl Backend {
}
}

pub fn predict(&mut self, data: ArrayD<f32>, layers: Option<Vec<usize>>) -> ArrayD<f32> {
pub fn predict(
&mut self,
data: ArrayD<f32>,
postprocess: PostProcessor,
layers: Option<Vec<usize>>,
) -> ArrayD<f32> {
let processor = CPUPostProcessor::from(&postprocess);
for layer in &mut self.layers {
layer.reset(1);
}
self.forward_propagate(data, false, layers)
let res = self.forward_propagate(data, false, layers);
processor.process(res)
}

pub fn save(&self) -> Vec<u8> {
Expand Down Expand Up @@ -272,7 +280,7 @@ impl Backend {
serialize(tensors, &Some(metadata)).unwrap()
}

pub fn load(buffer: &[u8], logger: Logger) -> Self {
pub fn load(buffer: &[u8], logger: Logger, timer: Timer) -> Self {
let tensors = SafeTensors::deserialize(buffer).unwrap();
let (_, metadata) = SafeTensors::read_metadata(buffer).unwrap();
let data = metadata.metadata().as_ref().unwrap();
Expand Down Expand Up @@ -304,6 +312,6 @@ impl Backend {
};
}

Backend::new(config, logger, Some(layers))
Backend::new(config, logger, timer, Some(layers))
}
}
4 changes: 3 additions & 1 deletion crates/core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod layers;
mod optimizers;
mod schedulers;
mod regularizer;
mod postprocessing;

pub use activation::*;
pub use backend::*;
Expand All @@ -14,4 +15,5 @@ pub use init::*;
pub use layers::*;
pub use optimizers::*;
pub use schedulers::*;
pub use regularizer::*;
pub use regularizer::*;
pub use postprocessing::*;
28 changes: 28 additions & 0 deletions crates/core/src/cpu/postprocessing/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use ndarray::ArrayD;
use crate::PostProcessor;

mod step;
use step::CPUStepFunction;

pub enum CPUPostProcessor {
None,
Sign,
Step(CPUStepFunction),
}

impl CPUPostProcessor {
pub fn from(processor: &PostProcessor) -> Self {
match processor {
PostProcessor::None => CPUPostProcessor::None,
PostProcessor::Sign => CPUPostProcessor::Sign,
PostProcessor::Step(config) => CPUPostProcessor::Step(CPUStepFunction::new(config)),
}
}
pub fn process(&self, x: ArrayD<f32>) -> ArrayD<f32> {
match self {
CPUPostProcessor::None => x,
CPUPostProcessor::Sign => x.map(|y| y.signum()),
CPUPostProcessor::Step(processor) => x.map(|y| processor.step(*y)),
}
}
}
22 changes: 22 additions & 0 deletions crates/core/src/cpu/postprocessing/step.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::StepFunctionConfig;

pub struct CPUStepFunction {
thresholds: Vec<f32>,
values: Vec<f32>
}
impl CPUStepFunction {
pub fn new(config: &StepFunctionConfig) -> Self {
return Self {
thresholds: config.thresholds.clone(),
values: config.values.clone()
}
}
pub fn step(&self, x: f32) -> f32 {
for (i, &threshold) in self.thresholds.iter().enumerate() {
if x < threshold {
return self.values[i];
}
}
return self.values.last().unwrap().clone()
}
}
18 changes: 13 additions & 5 deletions crates/core/src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::slice::{from_raw_parts, from_raw_parts_mut};
use std::time::{SystemTime, UNIX_EPOCH};

use crate::{
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, TrainOptions,
RESOURCES,
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, Timer,
TrainOptions, RESOURCES,
};

type AllocBufferFn = extern "C" fn(usize) -> *mut u8;
Expand All @@ -11,10 +12,17 @@ fn log(string: String) {
println!("{}", string)
}

fn now() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Your system is behind the Unix Epoch")
.as_millis()
}

#[no_mangle]
pub extern "C" fn ffi_backend_create(ptr: *const u8, len: usize, alloc: AllocBufferFn) -> usize {
let config = decode_json(ptr, len);
let net_backend = Backend::new(config, Logger { log }, None);
let net_backend = Backend::new(config, Logger { log }, Timer { now }, None);
let buf: Vec<u8> = net_backend
.size
.iter()
Expand Down Expand Up @@ -75,7 +83,7 @@ pub extern "C" fn ffi_backend_predict(

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
let res = backend[id].predict(inputs, options.layers);
let res = backend[id].predict(inputs, options.post_process, options.layers);
outputs.copy_from_slice(res.as_slice().unwrap());
});
}
Expand All @@ -98,7 +106,7 @@ pub extern "C" fn ffi_backend_load(
alloc: AllocBufferFn,
) -> usize {
let buffer = unsafe { from_raw_parts(file_ptr, file_len) };
let net_backend = Backend::load(buffer, Logger { log });
let net_backend = Backend::load(buffer, Logger { log }, Timer { now });
let buf: Vec<u8> = net_backend.size.iter().map(|x| *x as u8).collect();
let size_ptr = alloc(buf.len());
let output_shape = unsafe { from_raw_parts_mut(size_ptr, buf.len()) };
Expand Down
16 changes: 16 additions & 0 deletions crates/core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ pub enum Scheduler {
OneCycle(OneCycleScheduler),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StepFunctionConfig {
pub thresholds: Vec<f32>,
pub values: Vec<f32>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", content = "config")]
#[serde(rename_all = "lowercase")]
pub enum PostProcessor {
None,
Sign,
Step(StepFunctionConfig),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct TrainOptions {
Expand All @@ -212,6 +227,7 @@ pub struct PredictOptions {
pub input_shape: Vec<usize>,
pub output_shape: Vec<usize>,
pub layers: Option<Vec<usize>>,
pub post_process: PostProcessor,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down
6 changes: 6 additions & 0 deletions crates/core/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ use ndarray::ArrayD;
use safetensors::tensor::TensorView;
use serde::Deserialize;

#[derive(Clone)]
pub struct Logger {
pub log: fn(string: String) -> (),
}

#[derive(Clone)]
pub struct Timer {
pub now: fn() -> u128,
}

pub fn length(shape: Vec<usize>) -> usize {
return shape.iter().fold(1, |i, x| i * x);
}
Expand Down
31 changes: 23 additions & 8 deletions crates/core/src/wasm.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
use js_sys::{Array, Float32Array, Uint8Array};
use ndarray::ArrayD;

use wasm_bindgen::{prelude::wasm_bindgen, JsValue};

use crate::{Backend, Dataset, Logger, PredictOptions, TrainOptions, RESOURCES};
use crate::{Backend, Dataset, Logger, PredictOptions, Timer, TrainOptions, RESOURCES};

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
#[wasm_bindgen(js_namespace = Date)]
fn now() -> f64;

}

fn console_log(string: String) {
log(string.as_str())
}

fn performance_now() -> u128 {
now() as u128
}

#[wasm_bindgen]
pub fn wasm_backend_create(config: String, shape: Array) -> usize {
let config = serde_json::from_str(&config).unwrap();
let mut len = 0;
let logger = Logger { log: console_log };
let net_backend = Backend::new(config, logger, None);
let net_backend = Backend::new(
config,
logger,
Timer {
now: performance_now,
},
None,
);
shape.set_length(net_backend.size.len() as u32);
for (i, s) in net_backend.size.iter().enumerate() {
shape.set(i as u32, JsValue::from(*s))
Expand All @@ -37,7 +50,6 @@ pub fn wasm_backend_create(config: String, shape: Array) -> usize {
#[wasm_bindgen]
pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String) {
let options: TrainOptions = serde_json::from_str(&options).unwrap();

let mut datasets = Vec::new();
for i in 0..options.datasets {
let input = buffers[i * 2].to_vec();
Expand All @@ -47,7 +59,6 @@ pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String
outputs: ArrayD::from_shape_vec(options.output_shape.clone(), output).unwrap(),
});
}

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
backend[id].train(datasets, options.epochs, options.batches, options.rate)
Expand All @@ -59,11 +70,12 @@ pub fn wasm_backend_predict(id: usize, buffer: Float32Array, options: String) ->
let options: PredictOptions = serde_json::from_str(&options).unwrap();
let inputs = ArrayD::from_shape_vec(options.input_shape, buffer.to_vec()).unwrap();

let res = ArrayD::zeros(options.output_shape);
let mut res = ArrayD::zeros(options.output_shape.clone());

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
let _res = backend[id].predict(inputs, options.layers);
let _res = backend[id].predict(inputs, options.post_process, options.layers);
res.assign(&ArrayD::from_shape_vec(options.output_shape, _res.as_slice().unwrap().to_vec()).unwrap());
});
Float32Array::from(res.as_slice().unwrap())
}
Expand All @@ -82,7 +94,10 @@ pub fn wasm_backend_save(id: usize) -> Uint8Array {
pub fn wasm_backend_load(buffer: Uint8Array, shape: Array) -> usize {
let mut len = 0;
let logger = Logger { log: console_log };
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger);
let timer = Timer {
now: performance_now,
};
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger, timer);
shape.set_length(net_backend.size.len() as u32);
for (i, s) in net_backend.size.iter().enumerate() {
shape.set(i as u32, JsValue::from(*s))
Expand Down
Loading

0 comments on commit a811fd1

Please sign in to comment.