Skip to content

Commit

Permalink
feat: add calc iou ability
Browse files Browse the repository at this point in the history
  • Loading branch information
4o3F committed Oct 5, 2024
1 parent 6b2487e commit b96e944
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 8 deletions.
18 changes: 11 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,18 @@ jobs:
- uses: actions/checkout@v3

# Install rust
- name: Install rust
- name: Install Rust
run: |
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
rustup target add x86_64-unknown-linux-musl
sudo apt-get update && sudo apt-get install libstdc++-12-dev musl-tools
- name: Install MUSL
run: |
sudo apt-get update && sudo apt-get install libstdc++-12-dev musl-tools p7zip-full
wget -O musl.7z https://github.com/benjaminwan/musl-cross-builder/releases/download/13.2.0/x86_64-linux-musl-13.2.0.7z
7za x -y musl.7z -o/opt
# Install opencv
#- name: Install opencv
Expand Down Expand Up @@ -68,16 +75,13 @@ jobs:
uses: mxschmitt/action-tmate@v3
if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }}

# Note: OPENCV_LINK_LIBS ordering matters for linux. Put lower level deps after higher level. See https://answers.opencv.org/question/186124/undefined-reference-to-cvsoftdoubleoperator/
# libclang files are in /usr/lib/llvm-##/lib. We symlink it to one of the opencv_link_paths
# OpenCV-rust looks for "opencv2/core/version.hpp" for the OpenCV version: https://github.com/twistedfall/opencv-rust/issues/368
# which is under /<install-prefix>/include/opencv4 for linux
# Build
- name: Build
run: |
export OPENCV_LINK_LIBS="opencv_videoio,opencv_imgcodecs,opencv_imgproc,opencv_ximgproc,opencv_core"
export OPENCV_LINK_PATHS=/opt/opencv/lib,/opt/opencv/lib/opencv4/3rdparty,/usr/lib/x86_64-linux-gnu
export OPENCV_INCLUDE_PATHS=/opt/opencv/include,/opt/opencv/include/opencv4
export CC="/opt/x86_64-linux-musl/bin/x86_64-linux-musl-gcc"
export CXX="/opt/x86_64-linux-musl/bin/x86_64-linux-musl-g++"
sudo ln -s /usr/lib/llvm-15/lib/libclang.so.1 /usr/lib/x86_64-linux-gnu/libclang.so
cargo build --release --target x86_64-unknown-linux-musl
Expand Down
3 changes: 2 additions & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod augment;
pub mod convert;
pub mod operation;
pub mod remap;
pub mod dataset;
pub mod dataset;
pub mod metric;
83 changes: 83 additions & 0 deletions src/common/metric.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::collections::HashMap;

use opencv::{
core::{self, MatTraitConst, CV_8U},
imgcodecs,
};
use tracing_unwrap::ResultExt;

pub fn calc_iou(target_img: &String, gt_img: &String) {
let target_img = imgcodecs::imread(&target_img, imgcodecs::IMREAD_COLOR)
.expect_or_log("Open output image error");
let gt_img = imgcodecs::imread(&gt_img, imgcodecs::IMREAD_COLOR)
.expect_or_log("Open ground truth image error");

tracing::info!("Image loaded");
if target_img.depth() != CV_8U || gt_img.depth() != CV_8U {
tracing::error!("Output image and ground truth image must be 8-bit 3-channel images");
return;
}

let mut intersection: HashMap<(u8, u8, u8), usize> = HashMap::new();
let mut union: HashMap<(u8, u8, u8), usize> = HashMap::new();

let rows = gt_img.rows();
let cols = gt_img.cols();

for i in 0..rows {
for j in 0..cols {
let pixel1 = target_img
.at_2d::<core::Vec3b>(i, j)
.expect_or_log("Get output pixel error");
let pixel2 = gt_img
.at_2d::<core::Vec3b>(i, j)
.expect_or_log("Get ground truth pixel error");

let color1 = (pixel1[0], pixel1[1], pixel1[2]);
let color2 = (pixel2[0], pixel2[1], pixel2[2]);

// 更新交集和并集
if pixel1[0] > 0 || pixel1[1] > 0 || pixel1[2] > 0 {
*union.entry(color1).or_insert(0) += 1;
}
if pixel2[0] > 0 || pixel2[1] > 0 || pixel2[2] > 0 {
*union.entry(color2).or_insert(0) += 1;
}

if pixel1 == pixel2 {
*intersection.entry(color1).or_insert(0) += 1;
}
}
}

let mut iou = HashMap::new();
let mut total_iou = 0.0;
let mut num_categories = 0;

for (&color, &inter) in &intersection {
let uni = union.get(&color).unwrap_or(&0);
if *uni > 0 {
let iou_value = inter as f64 / *uni as f64;
iou.insert(color, iou_value);
total_iou += iou_value;
num_categories += 1;
}
}

let mean_iou = if num_categories > 0 {
total_iou / num_categories as f64
} else {
0.0
};

for (color, &iou_value) in &iou {
tracing::info!(
"IoU for color RGB({},{},{}): {}",
color.0,
color.1,
color.2,
iou_value
);
}
tracing::info!("Mean IoU: {}", mean_iou);
}
63 changes: 63 additions & 0 deletions src/common/remap.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use image::{GrayImage, Luma, Rgb, RgbImage};
use opencv::{
core::{MatTrait, MatTraitConst, Vec3b, Vector, CV_8U},
imgcodecs::{self, imread, imwrite},
};
use std::{
collections::HashMap,
fs,
ops::Deref,
sync::{Arc, RwLock},
};
use tokio::{sync::Semaphore, task::JoinSet};
use tracing_unwrap::ResultExt;

pub fn remap_color(original_color: &str, new_color: &str, image_path: &String, save_path: &String) {
let mut original_color_vec: Vec<u8> = vec![];
Expand Down Expand Up @@ -105,6 +110,64 @@ pub async fn remap_color_dir(
tracing::info!("All color remap done!");
}

pub fn remap_background_color(
valid_colors: &String,
new_color: &str,
image_path: &String,
save_path: &String,
) {
let mut original_color_vec: Vec<Vec<u8>> = vec![];
for valid_color in valid_colors.split(";") {
let mut color_vec = vec![];
for splited in valid_color.split(',') {
let splited = splited
.parse::<u8>()
.expect_or_log("Malformed original color RGB, please use R,G,B format");
color_vec.push(splited);
}

original_color_vec.push(color_vec);
}

let mut new_color_vec: Vec<u8> = vec![];
for splited in new_color.split(',') {
let splited = splited
.parse::<u8>()
.expect_or_log("Malformed new color RGB, please use R,G,B format");
new_color_vec.push(splited);
}

if original_color_vec.iter().any(|x| x.len() != 3) || new_color_vec.len() != 3 {
tracing::error!("Malformed color RGB, please use R,G,B format");
return;
}

let new_color_rgb = [new_color_vec[0], new_color_vec[1], new_color_vec[2]];

let mut img = imread(&image_path, imgcodecs::IMREAD_COLOR).expect_or_log("Open image error");
tracing::info!("Image loaded");
if img.depth() != CV_8U {
tracing::error!("Image depth is not 8U, not supported");
return;
}
let cols = img.cols();
let rows = img.rows();

for y in 0..rows {
for x in 0..cols {
let pixel = img.at_2d_mut::<Vec3b>(y, x).unwrap();
if !original_color_vec.contains(&vec![pixel.0[0], pixel.0[1], pixel.0[2]]) {
pixel.0 = new_color_rgb;
}
}
tracing::trace!("Y {} done", y);
}

imwrite(save_path, &img, &Vector::new()).expect_or_log("Save image error");

tracing::info!("{} background color remap done!", image_path);
}

pub async fn class2rgb(dataset_path: &String, rgb_list: &str) {
let entries = fs::read_dir(dataset_path).unwrap();
fs::create_dir_all(format!("{}\\..\\output\\", dataset_path)).unwrap();
Expand Down
43 changes: 43 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ enum CommonCommands {
save_path: String,
},

MapBackgroundColor {
#[arg(short, long, help = "In R1,G1,B1;R2,G2,B2 format")]
valid_colors: String,

#[arg(short, long, help = "In R,G,B format")]
new_color: String,

#[arg(short, long, help = "The path for the original image")]
image_path: String,

#[arg(short, long, help = "The path for the mapped new image")]
save_path: String,
},

/// Split large images to small pieces for augmentation purposes
SplitImages {
#[arg(short, long, help = "The path for the folder containing images")]
Expand Down Expand Up @@ -261,6 +275,16 @@ enum CommonCommands {
#[arg(short, long, help = "The path for the folder containing images")]
dataset_path: String,
},

/// Calc the IoU of two images
#[command(name = "calc-iou")]
CalcIoU {
#[arg(short, long, help = "The path for the target image")]
target_image: String,

#[arg(short, long, help = "The path for the ground truth image")]
gt_image: String,
},
}

#[derive(Subcommand)]
Expand Down Expand Up @@ -342,6 +366,19 @@ async fn main() {
} => {
common::remap::remap_color(original_color, new_color, image_path, save_path);
}
CommonCommands::MapBackgroundColor {
valid_colors,
new_color,
image_path,
save_path,
} => {
common::remap::remap_background_color(
valid_colors,
new_color,
image_path,
save_path,
);
}
CommonCommands::MapColorDir {
original_color,
new_color,
Expand Down Expand Up @@ -456,6 +493,12 @@ async fn main() {
CommonCommands::CalcMeanStd { dataset_path } => {
common::dataset::calc_mean_std(dataset_path).await;
}
CommonCommands::CalcIoU {
target_image,
gt_image,
} => {
common::metric::calc_iou(target_image, gt_image);
}
},
Some(Commands::Yolo { command }) => match command {
YoloCommands::SplitDataset {
Expand Down

0 comments on commit b96e944

Please sign in to comment.