diff --git a/src/common/operation.rs b/src/common/operation.rs index 6bb1706..7c6882e 100644 --- a/src/common/operation.rs +++ b/src/common/operation.rs @@ -1,77 +1,108 @@ -use image::imageops::FilterType; use opencv::{ - core::{self, MatTraitConst}, - imgcodecs, + core::{self, Mat, MatTraitConst, Size, Vector}, + imgcodecs::{self, imread, imwrite}, + imgproc, }; -use std::{fs, io::Cursor, sync::Arc}; -use tokio::{fs::File, io::AsyncWriteExt, sync::Semaphore, task::JoinSet}; +use std::{fs, path::PathBuf, sync::Arc}; +use tokio::{sync::Semaphore, task::JoinSet}; use tracing_unwrap::{OptionExt, ResultExt}; use crate::THREAD_POOL; pub async fn resize_images( dataset_path: &String, - target_height: &u32, - target_width: &u32, + target_height: &i32, + target_width: &i32, filter: &str, ) { + let mut entries: Vec = Vec::new(); + let dataset_path = PathBuf::from(dataset_path); + if dataset_path.is_file() { + entries.push(dataset_path.clone()); + fs::create_dir_all(format!( + "{}\\output\\", + dataset_path.parent().unwrap().to_str().unwrap() + )) + .expect_or_log("Failed to create directory"); + } else { + entries = fs::read_dir(dataset_path.clone()) + .unwrap() + .map(|x| x.unwrap().path()) + .collect(); + + fs::create_dir_all(format!("{}\\output\\", dataset_path.to_str().unwrap())) + .expect_or_log("Failed to create directory"); + } + let filter = match filter.to_lowercase().as_str() { - "nearest" => FilterType::Nearest, - "linear" => FilterType::Triangle, - "cubic" => FilterType::CatmullRom, - "gaussian" => FilterType::Gaussian, - "lanczos" => FilterType::Lanczos3, + "nearest" => imgproc::INTER_NEAREST, + "linear" => imgproc::INTER_LINEAR, + "cubic" => imgproc::INTER_CUBIC, + "lanczos" => imgproc::INTER_LANCZOS4, _ => { tracing::error!("Invalid filter type. Please use one of the following: nearest, linear, cubic, gaussian, lanczos"); return; } }; - let entries = fs::read_dir(dataset_path).unwrap(); let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new( + let semaphore = Arc::new(Semaphore::new( (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), )); - fs::create_dir_all(format!("{}\\..\\resized\\", dataset_path)).unwrap(); + for entry in entries { - let entry = entry.unwrap(); - if entry - .path() - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".png") - { - let permit = Arc::clone(&sem); - let dataset_path = dataset_path.clone(); - let target_height = *target_height; - let target_width = *target_width; - threads.spawn(async move { - let _permit = permit.acquire().await.unwrap(); - - let img = image::open(entry.path()).unwrap(); - let img_resized = img.resize_exact(target_width, target_height, filter); - let mut bytes: Vec = Vec::new(); - img_resized - .write_to(&mut Cursor::new(&mut bytes), image::ImageFormat::Png) - .unwrap(); - let mut file = File::create(format!( - "{}\\..\\resized\\{}", - dataset_path, - entry.path().file_name().unwrap().to_str().unwrap() - )) - .await - .unwrap(); - file.write_all(&bytes).await.unwrap(); - tracing::info!( - "Image {} done\n", - entry.path().file_name().unwrap().to_str().unwrap() + let permit = Arc::clone(&semaphore); + let target_width = target_width.clone(); + let target_height = target_height.clone(); + let filter = filter.clone(); + threads.spawn(async move { + let _permit = permit.acquire().await.unwrap(); + let img = imread(entry.to_str().unwrap(), imgcodecs::IMREAD_COLOR) + .expect_or_log("Failed to read image"); + if img.channels() != 3 { + tracing::error!( + "Image {} is not RGB image", + entry.file_name().unwrap().to_str().unwrap() ); - }); + return Err(()); + } + + let mut result = Mat::default(); + opencv::imgproc::resize( + &img, + &mut result, + Size::new(target_width, target_height), + 0., + 0., + filter, + ) + .expect_or_log("Failed to resize image"); + + imwrite( + format!( + "{}\\output\\{}", + entry.parent().unwrap().to_str().unwrap(), + entry.file_name().unwrap().to_str().unwrap() + ) + .as_str(), + &result, + &Vector::new(), + ) + .unwrap(); + tracing::info!( + "Image {} done", + entry.file_name().unwrap().to_str().unwrap() + ); + Ok(()) + }); + } + while let Some(res) = threads.join_next().await { + if let Err(e) = res { + tracing::error!("{}", e); + continue; } } - while threads.join_next().await.is_some() {} + tracing::info!("All done!"); } #[derive(Clone, Copy)] diff --git a/src/main.rs b/src/main.rs index fb03e11..bd962f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -170,19 +170,23 @@ enum CommonCommands { /// Resize all images in a given folder to a given size with a given filter ResizeImages { - #[arg(short, long, help = "The path for the folder containing images")] + #[arg( + short, + long, + help = "The path of the image / directory containing images" + )] dataset_path: String, #[arg(long, help = "Target height")] - height: u32, + height: i32, #[arg(long, help = "Target width")] - width: u32, + width: i32, #[arg( short, long, - help = "Filter type, could be `nearest`, `linear`, `cubic`, `gaussian` or `lanczos`" + help = "Filter type, could be `nearest`, `linear`, `cubic` or `lanczos`" )] filter: String, },