diff --git a/README.md b/README.md
index a520c0e..b1b7231 100644
--- a/README.md
+++ b/README.md
@@ -366,11 +366,11 @@ For a small tutorial using some concrete examples see the [`tutorial`][tutorial_
 ## Evaluation
 The Laypa repository also contains a few tools used to evaluate the results generated by the model.
-The first tool is a visual comparison between the predictions of the model and the ground truth. This is done as an overlay of the classes over the original image. The overlay class names and colors are taken from the dataset catalog. The tool to do this is [`eval.py`][eval_link]. The visualization has almost the same arguments as the training command ([`main.py`][main_link]).
+The first tool is a visual comparison between the predictions of the model and the ground truth, drawn as an overlay of the classes on the original image. The overlay class names and colors are taken from the dataset catalog. The tool for this is [`visualization.py`][eval_link]. The visualization tool accepts almost the same arguments as the training command ([`main.py`][main_link]).
 
 Required arguments:
 ```sh
-python eval.py \
+python visualization.py \
     -c/--config CONFIG \
     -i/--input INPUT [INPUT ...] \
 ```
@@ -380,7 +380,7 @@ python eval.py \
 Optional arguments:
 ```sh
-python eval.py \
+python visualization.py \
     -c/--config CONFIG \
     -i/--input INPUT [INPUT ...] \
     [-o/--output OUTPUT] \
@@ -394,13 +394,13 @@ python eval.py \
 The optional arguments are shown using square brackets. The `-o/--output` parameter specifies the output directory for the visualization masks. The `--tmp_dir` parameter specifies a folder in which to store temporary files, while the `--keep_tmp_dir` parameter prevents the temporary files from being deleted after a run (mostly for debugging). The `--opts` parameter allows you to override values specified in the config files. For example, `--opts SOLVER.IMS_PER_BATCH 8` sets the batch size to 8. The `--sorted` parameter processes the images in the operating system's natural sort order. The `--save` parameter specifies which visualizations should be saved: "pred" for the prediction, "gt" for the ground truth, "both" for both the prediction and the ground truth, and "all" for all of the previous. If `--save` is given without a value, it defaults to "all".
 
-Example of running [`eval.py`][eval_link]:
+Example of running [`visualization.py`][eval_link]:
 ```sh
-python eval.py -c config.yml -i input_dir
+python visualization.py -c config.yml -i input_dir
 ```
 
-The [`eval.py`][eval_link] will then open a window with both the prediction and the ground truth side by side (if the ground truth exists). Allowing for easier comparison. The visualization masks are created in the same way the preprocessing converts pageXML to masks.
+The [`visualization.py`][eval_link] tool will then open a window with the prediction and the ground truth side by side (if the ground truth exists), allowing for easier comparison. The visualization masks are created in the same way the preprocessing converts pageXML to masks.
 
 The second tool is a program to compare the similarity of two sets of pageXML. This can mean either comparing ground truth to predicted pageXML, or determining the similarity of two annotations made by different people. This tool is [`xml_comparison.py`][xml_comparison_link]. The comparison allows you to specify how regions and baselines should be drawn when creating the pixel masks.
 The pixel masks are then compared based on their Intersection over Union (IoU) and Accuracy (Acc) scores. For the Accuracy metric, one of the two sets needs to be designated as the ground truth: one set is passed via the ground truth directory (`--gt`) argument and the other via the input directory (`--input`) argument.
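For intuition, both scores reduce to simple pixel counting: IoU is the overlap of a class divided by its combined footprint in the two masks, and Accuracy is the fraction of pixels matching the ground truth set. A minimal NumPy sketch (illustrative only, not the actual `xml_comparison.py` implementation; the function name is invented):

```python
import numpy as np


def iou_and_accuracy(gt: np.ndarray, pred: np.ndarray, num_classes: int):
    """Per-class IoU and overall pixel accuracy for two integer label masks."""
    ious = []
    for c in range(num_classes):
        gt_c, pred_c = gt == c, pred == c
        intersection = np.logical_and(gt_c, pred_c).sum()
        union = np.logical_or(gt_c, pred_c).sum()
        # A class absent from both masks has an undefined IoU; report NaN
        ious.append(intersection / union if union > 0 else float("nan"))
    accuracy = (gt == pred).mean()  # fraction of pixels agreeing with the ground truth
    return ious, accuracy
```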
@@ -493,7 +493,7 @@ If you discover a bug or missing feature that you would like to help with please
 [tutorial_link]: tutorial/
 [main_link]: main.py
 [run_link]: run.py
-[eval_link]: eval.py
+[eval_link]: visualization.py
 [xml_comparison_link]: xml_comparison.py
 [xml_viewer_link]: xml_viewer.py
 [start_flask_link]: /api/start_flask.sh
diff --git a/core/trainer.py b/core/trainer.py
index 5bfa449..50511cd 100644
--- a/core/trainer.py
+++ b/core/trainer.py
@@ -208,7 +208,7 @@ class Trainer(DefaultTrainer):
     Trainer class
     """
 
-    def __init__(self, cfg: CfgNode):
+    def __init__(self, cfg: CfgNode, validation: bool = False):
         TrainerBase.__init__(self)
         # logger = logging.getLogger("detectron2")
 
@@ -219,7 +219,7 @@ def __init__(self, cfg: CfgNode):
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
-        data_loader = self.build_train_loader(cfg)
+        data_loader = None if validation else self.build_train_loader(cfg)
         model = create_ddp_model(model, broadcast_buffers=False)
         self._trainer = (AMPTrainer if cfg.MODEL.AMP_TRAIN.ENABLED else SimpleTrainer)(model, data_loader, optimizer)
@@ -319,3 +319,6 @@ def build_optimizer(cls, cfg, model):
     @classmethod
     def build_lr_scheduler(cls, cfg, optimizer):
         return build_lr_scheduler(cfg, optimizer)
+
+    def validate(self):
+        return self.test(self.cfg, self.model)
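The `validation` flag above allows a `Trainer` to be constructed without a training data loader, so a checkpoint can be evaluated on its own. A minimal usage sketch (the helper name is invented; `args` is an `argparse.Namespace` with at least `config` and `opts`, matching the CLI of the new `validation.py` below):

```python
from core.setup import setup_cfg
from core.trainer import Trainer


def evaluate_checkpoint(args):
    """Build an evaluation-only Trainer and run the test loop once."""
    cfg = setup_cfg(args)
    trainer = Trainer(cfg, validation=True)  # build_train_loader is skipped
    trainer.resume_or_load(resume=False)     # load the weights referenced by cfg
    return trainer.validate()                # delegates to Trainer.test(cfg, model)
```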
diff --git a/validation.py b/validation.py
new file mode 100644
index 0000000..f74577e
--- /dev/null
+++ b/validation.py
@@ -0,0 +1,60 @@
+import argparse
+import logging
+
+from detectron2.evaluation import SemSegEvaluator
+
+from core.preprocess import preprocess_datasets
+from core.setup import setup_cfg, setup_logging, setup_saving, setup_seed
+from core.trainer import Trainer
+from utils.logging_utils import get_logger_name
+from utils.tempdir import OptionalTemporaryDirectory
+
+
+def get_arguments() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Validate a trained model against ground truth")
+
+    detectron2_args = parser.add_argument_group("detectron2")
+
+    detectron2_args.add_argument("-c", "--config", help="config file", required=True)
+    detectron2_args.add_argument("--opts", nargs="+", help="optional args to change", action="extend", default=[])
+
+    io_args = parser.add_argument_group("IO")
+    # io_args.add_argument("-t", "--train", help="Train input folder/file",
+    #                      nargs="+", action="extend", type=str, default=None)
+    io_args.add_argument("-i", "--input", help="Input folder/file", nargs="+", action="extend", type=str, default=None)
+    io_args.add_argument("-o", "--output", help="Output folder", type=str)
+
+    tmp_args = parser.add_argument_group("tmp files")
+    tmp_args.add_argument("--tmp_dir", help="Temp files folder", type=str, default=None)
+    tmp_args.add_argument("--keep_tmp_dir", action="store_true", help="Don't remove tmp dir after execution")
+
+    parser.add_argument("--sorted", action="store_true", help="Sorted iteration")
+    parser.add_argument("--save", nargs="?", const="all", default=None, help="Save images instead of displaying")
+
+    args = parser.parse_args()
+
+    return args
+
+
+def main(args):
+
+    cfg = setup_cfg(args)
+    setup_logging(cfg)
+    setup_seed(cfg)
+    setup_saving(cfg)
+
+    logger = logging.getLogger(get_logger_name())
+
+    # Temp dir for preprocessing in case no temporary dir was specified
+    with OptionalTemporaryDirectory(name=args.tmp_dir, cleanup=not args.keep_tmp_dir) as tmp_dir:
+        preprocess_datasets(cfg, None, args.input, tmp_dir)
+
+        trainer = Trainer(cfg, validation=True)
+        trainer.resume_or_load(resume=False)
+
+        trainer.validate()
+
+
+if __name__ == "__main__":
+    args = get_arguments()
+    main(args)
diff --git a/eval.py b/visualization.py
similarity index 98%
rename from eval.py
rename to visualization.py
index c18eebf..99c012d 100644
--- a/eval.py
+++ b/visualization.py
@@ -9,17 +9,14 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from detectron2.data import DatasetCatalog, MetadataCatalog
 from detectron2.utils.visualizer import Visualizer
 from natsort import os_sorted
 from tqdm import tqdm
 
-from core.preprocess import preprocess_datasets
 from core.setup import setup_cfg
 from datasets.dataset import metadata_from_classes
 from datasets.mapper import AugInput
 from page_xml.xml_converter import XMLConverter
-from page_xml.xml_regions import XMLRegions
 from run import Predictor
 from utils.image_utils import load_image_array_from_path, save_image_array_to_path
 from utils.input_utils import get_file_paths, supported_image_formats
@@ -31,7 +28,7 @@
 
 
 def get_arguments() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Eval of prediction of model using visualizer")
+    parser = argparse.ArgumentParser(description="Visualize model predictions and ground truth")
 
     detectron2_args = parser.add_argument_group("detectron2")
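`validation.py` relies on `utils.tempdir.OptionalTemporaryDirectory` to either reuse a caller-supplied folder or create a disposable one for the preprocessed dataset. The pattern it depends on can be approximated as follows (an illustrative sketch, not Laypa's actual implementation):

```python
import tempfile
from contextlib import contextmanager
from typing import Iterator, Optional


@contextmanager
def optional_temporary_directory(name: Optional[str] = None, cleanup: bool = True) -> Iterator[str]:
    """Yield `name` if the caller supplied a directory, otherwise a fresh temp dir.

    With cleanup=False the generated directory is kept on disk,
    which is what --keep_tmp_dir uses for debugging.
    """
    if name is not None:
        yield name  # caller-provided directory; never deleted here
    elif cleanup:
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield tmp_dir  # removed automatically on exit
    else:
        yield tempfile.mkdtemp()  # left in place for inspection
```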