From afff0915cb1ad24e2872b6c1954b191dc647ee8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Mary=C5=84czak?= Date: Thu, 24 Oct 2024 21:10:35 +0200 Subject: [PATCH] neothesia-ai: Commit random machine learning related experiments (#212) --- Cargo.lock | 389 ++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + neothesia-ai/.gitignore | 2 + neothesia-ai/Cargo.toml | 16 ++ neothesia-ai/src/args.rs | 69 +++++++ neothesia-ai/src/audio.rs | 30 +++ neothesia-ai/src/main.rs | 402 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 909 insertions(+) create mode 100644 neothesia-ai/.gitignore create mode 100644 neothesia-ai/Cargo.toml create mode 100644 neothesia-ai/src/args.rs create mode 100644 neothesia-ai/src/audio.rs create mode 100644 neothesia-ai/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 808e5d69..38e8fc32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1098,6 +1098,15 @@ dependencies = [ "winreg", ] +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "endi" version = "1.1.0" @@ -1220,6 +1229,12 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "extended" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" + [[package]] name = "fast-srgb8" version = "1.0.0" @@ -1255,6 +1270,16 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "flatbuffers" +version = "24.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" +dependencies = [ + "bitflags 1.3.2", + "rustc_version", +] + [[package]] name = "flate2" version = "1.0.33" @@ -1899,6 +1924,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + [[package]] name = "jni" version = "0.21.1" @@ -2105,6 +2136,16 @@ dependencies = [ "libc", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" @@ -2244,6 +2285,21 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk" version = "0.8.0" @@ -2333,6 +2389,21 @@ dependencies = [ "winit", ] +[[package]] +name = "neothesia-ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "midly", + "ndarray", + "rten", + "rten-tensor", + "serde", + "serde_json", + "symphonia", + "symphonium", +] + [[package]] name = "neothesia-cli" version = "0.1.0" @@ -2397,6 +2468,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.4.2" @@ -2408,6 +2488,15 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2983,6 +3072,21 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22686f4785f02a4fcc856d3b3bb19bf6c8160d103f7a99cc258bddd0251dc7f2" +[[package]] +name = "portable-atomic" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" + +[[package]] +name = "portable-atomic-util" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" +dependencies = [ + "portable-atomic", +] + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -2998,6 +3102,15 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" @@ -3147,6 +3260,12 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -3177,6 +3296,15 @@ dependencies = [ "font-types", ] +[[package]] +name = "realfft" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390252372b7f2aac8360fc5e72eba10136b166d6faeed97e6d0c8324eb99b2b1" +dependencies = [ + "rustfft", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -3286,6 +3414,60 @@ version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97" +[[package]] +name = "rten" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52026aa6d9bc40ac0d52bfeb4bc81d4fd5b7866825af1826ed7a4d74bd7574c4" +dependencies = [ + "flatbuffers", + "libm", + "num_cpus", + "rayon", + "rten-simd", + "rten-tensor", + "rten-vecmath", + "rustc-hash 2.0.0", + "smallvec", + "wasm-bindgen", +] + +[[package]] +name = "rten-simd" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f1bb63fc8a157699e42a501cf43512871b20d3bea755f3ffac3ab63f1af10c4" + +[[package]] +name = "rten-tensor" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "575ec5dbc7e7059eb4271bca1c06420d240e8a377593cbac41a0c7227ec8645d" +dependencies = [ + "smallvec", +] + +[[package]] +name = "rten-vecmath" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af98a4e48d69c5aa2167d3adb7a8c1585602486a1aedd1ee8b3d684f98059396" +dependencies = [ + "rten-simd", +] + +[[package]] +name = "rubato" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0fe3acbd4cc7c6d726def76dcfd77164c35a65e034256de2741db8ead9a4ae5" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -3307,6 +3489,21 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43806561bc506d0c5d160643ad742e3161049ac01027b5e6d7524091fd401d86" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", + "version_check", +] + [[package]] name = "rustix" version = "0.38.37" @@ -3337,6 +3534,12 @@ dependencies = [ "unicode-script", ] +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + [[package]] name = "same-file" version = "1.0.6" @@ -3403,6 +3606,18 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "serde_json" +version = "1.0.132" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "serde_repr" version = "0.1.19" @@ -3594,6 +3809,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strict-num" version = "0.1.1" @@ -3617,6 +3838,164 @@ dependencies = [ "zeno", ] +[[package]] +name = "symphonia" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" +dependencies = [ + "lazy_static", + "symphonia-bundle-flac", + "symphonia-bundle-mp3", + "symphonia-codec-adpcm", + "symphonia-codec-pcm", + "symphonia-codec-vorbis", + "symphonia-core", + "symphonia-format-mkv", + "symphonia-format-ogg", + "symphonia-format-riff", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-bundle-flac" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-bundle-mp3" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c01c2aae70f0f1fb096b6f0ff112a930b1fb3626178fba3ae68b09dce71706d4" +dependencies = [ + "lazy_static", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-codec-adpcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c94e1feac3327cd616e973d5be69ad36b3945f16b06f19c6773fc3ac0b426a0f" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-codec-pcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-codec-vorbis" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30" +dependencies = [ + "log", + "symphonia-core", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-core" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "798306779e3dc7d5231bd5691f5a813496dc79d3f56bf82e25789f2094e022c3" +dependencies = [ + "arrayvec", + "bitflags 1.3.2", + "bytemuck", + "lazy_static", + "log", + "rustfft", +] + +[[package]] +name = "symphonia-format-mkv" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bb43471a100f7882dc9937395bd5ebee8329298e766250b15b3875652fe3d6f" +dependencies = [ + "lazy_static", + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-format-ogg" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ada3505789516bcf00fc1157c67729eded428b455c27ca370e41f4d785bfa931" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-format-riff" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" +dependencies = [ + "extended", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-metadata" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc622b9841a10089c5b18e99eb904f4341615d5aa55bbf4eedde1be721a4023c" +dependencies = [ + "encoding_rs", + "lazy_static", + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-utils-xiph" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe" +dependencies = [ + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4fd880864edb2723b8ae56852ce0f7cb98d8246e8cd8ec3d6881344ac1ffcd0" +dependencies = [ + "log", + "rubato", + "symphonia", +] + [[package]] name = "syn" version = "1.0.109" @@ -3826,6 +4205,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "ttf-parser" version = "0.20.0" diff --git a/Cargo.toml b/Cargo.toml index 7eb3f389..30c6e59c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "wgpu-jumpstart", "mpeg_encoder", "neothesia", + "neothesia-ai", "neothesia-cli", "neothesia-core", "midi-file", diff --git a/neothesia-ai/.gitignore b/neothesia-ai/.gitignore new file mode 100644 index 00000000..2468d6e6 --- /dev/null +++ b/neothesia-ai/.gitignore @@ -0,0 +1,2 @@ +/target +*.rten diff --git a/neothesia-ai/Cargo.toml b/neothesia-ai/Cargo.toml new file mode 100644 index 00000000..0e964bce --- /dev/null +++ b/neothesia-ai/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "neothesia-ai" +description = "AI audio to piano transciption interface" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.89" +midly = "0.5.3" +ndarray = "0.16.1" +rten = "0.13.1" +rten-tensor = "0.13.1" +serde = "1.0.210" +serde_json = "1.0.128" +symphonia = { version = "0.5.4", features = ["mp3"] } +symphonium = { version = "0.2.2", features = ["mp3"] } diff --git a/neothesia-ai/src/args.rs b/neothesia-ai/src/args.rs new file mode 100644 index 00000000..c667bf78 --- /dev/null +++ b/neothesia-ai/src/args.rs @@ -0,0 +1,69 @@ +use std::path::PathBuf; + +fn print_help() { + let help = [ + " -i, --input ", + " -o, --output ", + " -m, --model ", + ]; + println!("Options:"); + println!("{}", help.join("\n")); + println!(); +} + +#[derive(Debug)] +pub struct Args { + pub input: PathBuf, + pub output: PathBuf, + pub model: PathBuf, +} + +impl Args { + pub fn get_from_env() -> anyhow::Result { + let mut args = std::env::args().skip(1); + + let mut input = None; + let mut output = None; + let mut model = None; + + loop { + let Some(arg) = args.next() else { + break; + }; + + match arg.as_str() { + "--input" | "-i" => { + input = args.next(); + } + "--output" | "-o" => { + output = args.next(); + } + "--model" | "-m" => { + model = args.next(); + } + "--help" | "-h" => { + print_help(); + } + _ => {} + } + } + + let Some(input) = input else { + anyhow::bail!("`--input` audio file missing"); + }; + + let Some(output) = output else { + anyhow::bail!("`--output` midi file missing"); + }; + + let Some(model) = model else { + anyhow::bail!("`--model` rten model file missing"); + }; + + Ok(Args { + input: PathBuf::from(input), + output: PathBuf::from(output), + model: PathBuf::from(model), + }) + } +} diff --git a/neothesia-ai/src/audio.rs b/neothesia-ai/src/audio.rs new file mode 100644 index 00000000..1eadf11b --- /dev/null +++ b/neothesia-ai/src/audio.rs @@ -0,0 +1,30 @@ +use std::path::Path; + +use symphonium::{ResampleQuality, SymphoniumLoader}; + +use crate::{SAMPLE_RATE, SEGMENT_SAMPLES}; + +pub fn load(path: impl AsRef) -> anyhow::Result> { + // A struct used to load audio files. + let mut loader = SymphoniumLoader::new(); + + let mut audio_data_f32 = loader + .load_f32(path, Some(SAMPLE_RATE), ResampleQuality::Normal, None) + .unwrap(); + + let left = audio_data_f32.data.remove(0); + let right = audio_data_f32.data.remove(0); + + let mut mono: Vec = left + .into_iter() + .zip(right) + .map(|(l, r)| (l + r) / 2.0) + .collect(); + + let pad_len = + (mono.len() as f32 / SEGMENT_SAMPLES as f32).ceil() as usize * SEGMENT_SAMPLES - mono.len(); + + mono.resize(mono.len() + pad_len, 0.0); + + Ok(mono) +} diff --git a/neothesia-ai/src/main.rs b/neothesia-ai/src/main.rs new file mode 100644 index 00000000..42e88bd2 --- /dev/null +++ b/neothesia-ai/src/main.rs @@ -0,0 +1,402 @@ +use ndarray::{concatenate, s, Axis}; +use ndarray::{Array2, Array3, ArrayView1, ArrayView2}; +use rten::{InputOrOutput, NodeId}; +use rten_tensor::prelude::*; +use rten_tensor::*; + +const FRAMES_PER_SECOND: f32 = 100.0; +const SAMPLE_RATE: u32 = 16000; +const SEGMENT_SAMPLES: usize = SAMPLE_RATE as usize * 10; + +mod args; +mod audio; + +fn main() -> anyhow::Result<()> { + let args = args::Args::get_from_env()?; + + let input = audio::load(&args.input)?; + + let input = ArrayView2::from_shape([1, input.len()], &input)?; + let input = enframe(&input, SEGMENT_SAMPLES); + let input = input.as_slice().unwrap().to_vec(); + + let input = Tensor::from_data(&[input.len() / SEGMENT_SAMPLES, SEGMENT_SAMPLES], input); + + let model = rten::Model::load_file(&args.model)?; + + let inputs: Vec<(NodeId, InputOrOutput)> = vec![(model.input_ids()[0], input.view().into())]; + + let [reg_onset_output, reg_offset_output, frame_output, _velocity_output, _reg_pedal_onset_output, _reg_pedal_offset_output, _pedal_frame_output] = + model.run_n::<7>(inputs, model.output_ids().try_into()?, None)?; + + let (onset_output, onset_shift_output) = { + let output = reg_onset_output.into_float().unwrap(); + let shape: [usize; 3] = output.shape().try_into().unwrap(); + let reg_onset_output = Array3::from_shape_vec(shape, output.to_vec()).unwrap(); + let reg_onset_output: Array2<_> = deframe(®_onset_output); + + let onset_threshold = 0.3; + get_binarized_output_from_regression(®_onset_output.view(), onset_threshold, 2) + }; + + let (offset_output, offset_shift_output) = { + let output = reg_offset_output.into_float().unwrap(); + let shape: [usize; 3] = output.shape().try_into().unwrap(); + let reg_offset_output: Array3<_> = Array3::from_shape_vec(shape, output.to_vec()).unwrap(); + let reg_offset_output: Array2<_> = deframe(®_offset_output); + + let offset_threshold = 0.2; + get_binarized_output_from_regression(®_offset_output.view(), offset_threshold, 4) + }; + + let frame_output: Array3<_> = { + let output = frame_output.into_float().unwrap(); + let shape: [usize; 3] = output.shape().try_into().unwrap(); + Array3::from_shape_vec(shape, output.to_vec()).unwrap() + }; + let frame_output: Array2<_> = deframe(&frame_output); + + let frame_threshold = 0.1; + + let file = note_detection_with_onset_offset_regress( + frame_output.view(), + onset_output.view(), + onset_shift_output.view(), + offset_output.view(), + offset_shift_output.view(), + (), // velocity_output, + frame_threshold, + ); + + file.save(args.output)?; + + Ok(()) +} + +fn enframe(x: &ArrayView2, segment_samples: usize) -> Array2 { + // Ensure that the number of audio samples is divisible by segment_samples + assert!(x.shape()[1] % segment_samples == 0); + + let mut batch: Vec> = Vec::new(); + let mut pointer = 0; + + let total_samples = x.shape()[1]; + + // Enframe the sequence into smaller segments + while pointer + segment_samples <= total_samples { + let segment = x + .slice(s![.., pointer..(pointer + segment_samples)]) + .to_owned(); + batch.push(segment); + pointer += segment_samples / 2; + } + + // Concatenate the segments along the first axis (the segment axis) + concatenate(Axis(0), &batch.iter().map(|a| a.view()).collect::>()).unwrap() +} + +// TODO: Rewrite this madness +fn deframe(x: &Array3) -> Array2 { + // Get the shape of the input (N, segment_frames, classes_num) + let (n_segments, segment_frames, _classes_num) = x.dim(); + + // If there is only one segment, return it as is (removing the outer dimension) + if n_segments == 1 { + return x.index_axis(Axis(0), 0).to_owned(); // Equivalent to `x[0]` in Python + } + + // Remove the last frame from each segment + let x = x.slice(s![.., 0..segment_frames - 1, ..]).to_owned(); + + // Ensure that segment_frames is divisible by 4 + let segment_samples = segment_frames - 1; + assert!(segment_samples % 4 == 0); + + // Collect segments into a vector to concatenate them later + let mut y: Vec> = Vec::new(); + + // Append the first 75% of the first segment + y.push(x.slice(s![0, 0..(segment_samples * 3 / 4), ..]).to_owned()); + + // Append the middle part (25% to 75%) of the middle segments + for i in 1..(n_segments - 1) { + y.push( + x.slice(s![i, (segment_samples / 4)..(segment_samples * 3 / 4), ..]) + .to_owned(), + ); + } + + // Append the last 75% of the last segment + y.push( + x.slice(s![n_segments - 1, (segment_samples / 4).., ..]) + .to_owned(), + ); + + // Concatenate all parts along the first axis (frames axis) + concatenate(Axis(0), &y.iter().map(|a| a.view()).collect::>()).unwrap() +} + +fn get_binarized_output_from_regression( + reg_output: &ArrayView2, + threshold: f32, + neighbour: usize, +) -> (Array2, Array2) { + let (frames_num, classes_num) = reg_output.dim(); + + let mut binary_output = Array2::::default((frames_num, classes_num)); + let mut shift_output = Array2::::zeros((frames_num, classes_num)); + + for k in 0..classes_num { + let x: ArrayView1 = reg_output.slice(ndarray::s![.., k]); + + for n in neighbour..(frames_num - neighbour) { + if x[n] > threshold && is_monotonic_neighbour(&x, n, neighbour) { + binary_output[[n, k]] = true; + + // See Section III-D in [1] for deduction. + // [1] Q. Kong, et al., High-resolution Piano Transcription + // with Pedals by Regressing Onsets and Offsets Times, 2020. + let shift = if x[n - 1] > x[n + 1] { + (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2.0 + } else { + (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2.0 + }; + shift_output[[n, k]] = shift; + } + } + } + + (binary_output, shift_output) +} + +fn is_monotonic_neighbour(x: &ArrayView1, n: usize, neighbour: usize) -> bool { + // Ensure the value of 'n' is within a valid range + if n < neighbour || n + neighbour >= x.len() { + todo!(); + } + + for i in 0..neighbour { + if x[n - i] < x[n - i - 1] { + return false; + } + if x[n + i] < x[n + i + 1] { + return false; + } + } + + true +} + +fn note_detection_with_onset_offset_regress( + frame: ArrayView2, + onset: ArrayView2, + onset_shift: ArrayView2, + offset: ArrayView2, + offset_shift: ArrayView2, + velocity: (), + frame_threshold: f32, +) -> midly::Smf<'static> { + let classes_num = frame.dim().1; + + let mut notes = Vec::new(); + for piano_note in 0..classes_num { + let res = note_detection_with_onset_offset_regress_inner( + frame.slice(ndarray::s![.., piano_note]), + onset.slice(ndarray::s![.., piano_note]), + onset_shift.slice(ndarray::s![.., piano_note]), + offset.slice(ndarray::s![.., piano_note]), + offset_shift.slice(ndarray::s![.., piano_note]), + velocity, + frame_threshold, + ); + + for (bgn, fin, bgn_shift, fin_shift) in res { + let onset_time = (bgn as f32 + bgn_shift) / FRAMES_PER_SECOND; + let offset_time = (fin as f32 + fin_shift) / FRAMES_PER_SECOND; + + let labels: [&str; 12] = [ + "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "H", + ]; + + let label = labels[(piano_note + 9) % labels.len()]; + + // 21 is the first note in 88 keys layout + let piano_note = piano_note + 21; + + notes.push((piano_note, onset_time, offset_time)); + println!("{piano_note} {label}: {onset_time} - {offset_time}"); + } + } + + create_midi_file(notes) +} + +fn note_detection_with_onset_offset_regress_inner( + frame: ArrayView1, + onset: ArrayView1, + onset_shift: ArrayView1, + offset: ArrayView1, + offset_shift: ArrayView1, + _velocity: (), + frame_threshold: f32, +) -> Vec<(usize, usize, f32, f32)> { + let iter = frame + .into_iter() + .zip(onset) + .zip(onset_shift) + .zip(offset) + .zip(offset_shift) + .enumerate() + // God forgive my sins + .map(|(id, ((((a, b), c), d), e))| (id, a, b, c, d, e)); + + let mut output_tuples = Vec::new(); + let mut bgn: Option<(usize, f32)> = None; + let mut frame_disappear: Option<(usize, f32)> = None; + let mut offset_occur: Option<(usize, f32)> = None; + + let len = onset.shape()[0]; + + for (i, frame, onset, onset_shift, offset, offset_shift) in iter { + if *onset { + // Onset detected + if let Some((bgn, bgn_offset)) = bgn { + // Consecutive onsets. E.g., pedal is not released, but two + // consecutive notes being played. + let fin = i.saturating_sub(1); + output_tuples.push((bgn, fin, bgn_offset, 0.0)); + + frame_disappear = None; + offset_occur = None; + } + + bgn = Some((i, *onset_shift)); + } + + if let Some((bgn_time, bgn_shift)) = bgn { + if i > bgn_time { + // If onset found, then search offset + + if *frame <= frame_threshold && frame_disappear.is_none() { + // Frame disappear detected + frame_disappear = Some((i, *offset_shift)); + } + + if *offset && offset_occur.is_none() { + // Offset detected + offset_occur = Some((i, *offset_shift)); + } + + if let Some((frame_disappear_time, frame_disappear_shift)) = frame_disappear { + let (fin, fin_shift) = match offset_occur { + Some((offset_occur, shift)) + if offset_occur - bgn_time > frame_disappear_time - offset_occur => + { + // bgn --------- offset_occur --- frame_disappear + (offset_occur, shift) + } + _ => { + // bgn --- offset_occur --------- frame_disappear + (frame_disappear_time, frame_disappear_shift) + } + }; + output_tuples.push((bgn_time, fin, bgn_shift, fin_shift)); + + bgn = None; + frame_disappear = None; + offset_occur = None; + } + + if let Some((bgn_time, bgn_shift)) = bgn { + if i - bgn_time >= 600 || i == len - 1 { + // Offset not detected + let fin = i; + output_tuples.push((bgn_time, fin, bgn_shift, *offset_shift)); + + bgn = None; + frame_disappear = None; + offset_occur = None; + } + } + } + } + } + + output_tuples.sort_by_key(|v| v.0); + + output_tuples +} + +fn create_midi_file(notes: Vec<(usize, f32, f32)>) -> midly::Smf<'static> { + let ticks_per_beat = 384; + let beats_per_second = 2; + let ticks_per_second = ticks_per_beat * beats_per_second; + let microseconds_per_beat = (1_000_000.0 / beats_per_second as f64) as u32; + + let mut track1 = vec![]; + + let mut message_roll = vec![]; + + for (midi_note, start, end) in notes { + message_roll.push((start, midi_note, 100)); + message_roll.push((end, midi_note, 0)); + } + + message_roll.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + let mut previous_ticks = 0; + + let start_time = 0.0; + for message in message_roll { + let this_ticks = ((message.0 - start_time) * ticks_per_second as f32) as i32; + + if this_ticks >= 0 { + let diff_ticks = this_ticks - previous_ticks; + previous_ticks = this_ticks; + + track1.push(midly::TrackEvent { + delta: (diff_ticks as u32).into(), + kind: midly::TrackEventKind::Midi { + channel: 0.into(), + message: midly::MidiMessage::NoteOn { + key: (message.1 as u8).into(), + vel: message.2.into(), + }, + }, + }); + } + } + + track1.push(midly::TrackEvent { + delta: 1.into(), + kind: midly::TrackEventKind::Meta(midly::MetaMessage::EndOfTrack), + }); + + midly::Smf { + header: midly::Header { + format: midly::Format::Parallel, + timing: midly::Timing::Metrical(ticks_per_beat.into()), + }, + tracks: vec![ + vec![ + midly::TrackEvent { + delta: 0.into(), + kind: midly::TrackEventKind::Meta(midly::MetaMessage::Tempo( + microseconds_per_beat.into(), + )), + }, + midly::TrackEvent { + delta: 0.into(), + kind: midly::TrackEventKind::Meta(midly::MetaMessage::TimeSignature( + 4, 2, 24, 8, + )), + }, + midly::TrackEvent { + delta: 1.into(), + kind: midly::TrackEventKind::Meta(midly::MetaMessage::EndOfTrack), + }, + ], + track1, + ], + } +}