Skip to content

Commit

Permalink
examples: add example for image matting with ModNet (#142)
Browse files Browse the repository at this point in the history
* Add ModNet example

* Add show-image crate for visualizing output

* use parcel model download

---------

Co-authored-by: Carson M <[email protected]>
  • Loading branch information
GitNiko and decahedron1 authored Jan 15, 2024
1 parent 127154b commit e33c7b8
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 2 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ members = [
'ort-sys',
'examples/gpt2',
'examples/model-info',
'examples/yolov8'
'examples/yolov8',
'examples/modnet'
]
default-members = [
'.',
'examples/gpt2',
'examples/model-info',
'examples/yolov8'
'examples/yolov8',
'examples/modnet'
]

[package]
Expand Down
17 changes: 17 additions & 0 deletions examples/modnet/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
publish = false
name = "example-modnet"
version = "0.0.0"
edition = "2021"

[dependencies]
ort = { path = "../../" }
ndarray = "0.15"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
image = "0.24"
tracing = "0.1"
show-image = { version = "0.13", features = [ "image", "raqote" ] }

[features]
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
Binary file added examples/modnet/data/photo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 83 additions & 0 deletions examples/modnet/examples/modnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#![allow(clippy::manual_retain)]

use std::{ops::Mul, path::Path};

use image::{imageops::FilterType, GenericImageView, ImageBuffer, Rgba};
use ndarray::Array;
use ort::{inputs, CUDAExecutionProvider, Session};
use show_image::{event, AsImageView, WindowOptions};

#[show_image::main]
fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();

ort::init()
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;

let model =
Session::builder()?.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;

let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap();
let (img_width, img_height) = (original_img.width(), original_img.height());
let img = original_img.resize_exact(512, 512, FilterType::Triangle);
let mut input = Array::zeros((1, 3, 512, 512));
for pixel in img.pixels() {
let x = pixel.0 as _;
let y = pixel.1 as _;
let [r, g, b, _] = pixel.2.0;
input[[0, 0, y, x]] = (r as f32 - 127.5) / 127.5;
input[[0, 1, y, x]] = (g as f32 - 127.5) / 127.5;
input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5;
}

let outputs = model.run(inputs!["input" => input.view()]?)?;

let binding = outputs["output"].extract_tensor::<f32>().unwrap();
let output = binding.view();

// convert to 8-bit
let output = output.mul(255.0).map(|x| *x as u8);
let output = output.into_raw_vec();

// change rgb to rgba
let output_img = ImageBuffer::from_fn(512, 512, |x, y| {
let i = (x + y * 512) as usize;
Rgba([output[i], output[i], output[i], 255])
});

let mut output = image::imageops::resize(&output_img, img_width, img_height, FilterType::Triangle);
output.enumerate_pixels_mut().for_each(|(x, y, pixel)| {
let origin = original_img.get_pixel(x, y);
pixel.0[3] = pixel.0[0];
pixel.0[0] = origin.0[0];
pixel.0[1] = origin.0[1];
pixel.0[2] = origin.0[2];
});

let window = show_image::context()
.run_function_wait(move |context| -> Result<_, String> {
let mut window = context
.create_window(
"ort + modnet",
WindowOptions {
size: Some([img_width, img_height]),
..WindowOptions::default()
}
)
.map_err(|e| e.to_string())?;
window.set_image("photo", &output.as_image_view().map_err(|e| e.to_string())?);
Ok(window.proxy())
})
.unwrap();

for event in window.event_channel().unwrap() {
if let event::WindowEvent::KeyboardInput(event) = event {
if event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.state.is_pressed() {
break;
}
}
}

Ok(())
}

0 comments on commit e33c7b8

Please sign in to comment.