unified and improved experiment structure
galeone committed Jun 24, 2021
1 parent 64d30a4 commit b82c235
Showing 25 changed files with 885 additions and 608 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
 logs
 surface_cracks/
+results/
 
 # Documentation stub files

…
150 changes: 105 additions & 45 deletions bin/anomaly-box.py
@@ -1,13 +1,15 @@
 """Training & evaluation script for anomaly toolbox."""
 
+import importlib
+import logging
 import sys
-from datetime import datetime
 from pathlib import Path
 
 import click
-import warnings
 
-from anomaly_toolbox.benchmarks import AVAILABLE_BENCHMARKS
-from anomaly_toolbox.experiments import AVAILABLE_EXPERIMENTS
+import anomaly_toolbox.datasets as available_datasets
+import anomaly_toolbox.experiments as available_experiments
+from anomaly_toolbox.datasets.dataset import AnomalyDetectionDataset
 from anomaly_toolbox.hps import grid_search


@@ -16,69 +18,127 @@
"chosen_experiment",
"--experiment",
help="Experiment to run.",
type=click.Choice(list(AVAILABLE_EXPERIMENTS.keys()), case_sensitive=False),
)
@click.option(
"chosen_benchmark",
"--benchmark",
help="Benchmark to run.",
type=click.Choice(list(AVAILABLE_BENCHMARKS.keys()), case_sensitive=False),
)
@click.option(
"test_run",
"--test-run",
help="When running a benchmark, the path of the SavedModel to use.",
type=str,
type=click.Choice(available_experiments.__ALL__, case_sensitive=True),
)
@click.option(
"hparams_file_path",
"hps_path",
"--hps-path",
help="When running a benchmark, the path of the JSON file where all "
help="When running an experiment, the path of the JSON file where all "
"the hyperparameters are located.",
type=Path,
required=True,
)
@click.option(
"hparams_tuning",
"--hps-tuning",
"hps_tuning",
"--tuning",
help="If you want to use hyperparameters tuning, use 'True' here. Default is False.",
type=bool,
default=False,
)
+@click.option(
+    "dataset",
+    "--dataset",
+    help=(
+        "The dataset to use. Can be a ready-to-use dataset, or a .py file "
+        "that implements the AnomalyDetectionDataset interface"
+    ),
+    type=str,
+    required=True,
+)
+@click.option(
+    "run_all",
+    "--run-all",
+    help="Run all the available experiments.",
+    type=bool,
+    default=False,
+)
 def main(
     chosen_experiment: str,
-    chosen_benchmark: str,
-    test_run: str,
-    hparams_file_path: Path,
-    hparams_tuning: bool,
+    hps_path: Path,
+    hps_tuning: bool,
+    dataset: str,
+    run_all: bool,
 ) -> int:

-    # Warning to the user if the hparmas_tuning.json && --hps-tuning==False
-    if "tuning" in str(hparams_file_path) and not hparams_tuning:
-        warnings.warn(
+    # Warn the user if hparams_tuning.json is used but --tuning is False
+    if "tuning" in str(hps_path) and not hps_tuning:
+        logging.warning(
             "You chose to use the tuning JSON but the tuning boolean ("
-            "--hps-tuning) is False. Only one kind of each parameters will be taken "
+            "--tuning) is False. Only one value for each parameter will be taken "
             "into consideration. No tuning will be performed."
         )

"""Console script for anomaly_toolbox."""
if chosen_experiment:
log_dir = (
Path("logs")
/ "experiments"
/ chosen_experiment
/ datetime.now().strftime("%Y%m%d-%H%M%S")
# Instantiate dataset config from dataset name
if dataset.endswith(".py"):
file_path = Path(dataset).absolute()
name = file_path.stem
sys.path.append(str(file_path.parent))

dataset_instance = getattr(__import__(name), name)()
if not isinstance(dataset_instance, AnomalyDetectionDataset):
logging.error(
"Your class %s must implement the "
"anomaly_toolbox.datasets.dataset.AnomalyDetectionDataset"
"interface",
dataset_instance,
)
return 1
else:
try:
dataset_instance = getattr(
importlib.import_module("anomaly_toolbox.datasets"),
dataset,
)()
except (ModuleNotFoundError, AttributeError, TypeError):
logging.error(
"Dataset %s is not among the availables: %s",
dataset,
",".join(available_datasets.__ALL__),
)
return 1

-        experiment = AVAILABLE_EXPERIMENTS[chosen_experiment.lower()](
-            hparams_file_path, log_dir
-        )
-        experiment.run(hparams_tuning, grid_search)
-    elif chosen_benchmark:
-        benchmark = AVAILABLE_BENCHMARKS[chosen_benchmark.lower()](run_path=test_run)
-        benchmark.load_from_savedmodel().run()
-    else:
-        exe = sys.argv[0]
-        print(f"Please choose a valid CLI flag.\n{exe} --help", file=sys.stderr)
-        return 1
+    if run_all and chosen_experiment:
+        logging.error("Only one of --run-all and --experiment can be used.")
+        return 1
+
+    if not (chosen_experiment or run_all):
+        logging.error(
+            "Please choose a valid CLI flag.\n%s --help", Path(sys.argv[0]).name
+        )
+        return 1
+
+    if not hps_path or not hps_path.exists():
+        logging.error(
+            "Check that %s exists and it's a valid JSON containing the hyperparameters.",
+            hps_path,
+        )
+        return 1

+    if chosen_experiment:
+        experiments = [chosen_experiment]
+    else:
+        experiments = available_experiments.__ALL__
+
+    for experiment in experiments:
+        log_dir = Path("logs") / experiment
+
+        try:
+            experiment_instance = getattr(
+                importlib.import_module("anomaly_toolbox.experiments"),
+                experiment,
+            )(hps_path, log_dir)
+        except (ModuleNotFoundError, AttributeError, TypeError):
+            logging.error(
+                "Experiment %s is not among the available ones: %s",
+                experiment,
+                ",".join(available_experiments.__ALL__),
+            )
+            return 1
+
+        print(hps_path)
+        print(hps_tuning, grid_search, dataset_instance)
+        experiment_instance.run(hps_tuning, grid_search, dataset_instance)
 
     return 0


…
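For context on the new --dataset flag: when it points to a .py file, the script resolves the dataset class with getattr(__import__(name), name)(), so the file must define a top-level class named exactly like the file stem. A minimal, self-contained sketch of that convention follows; the helper name load_dataset_from_file and the my_dataset.py example are illustrative, not part of the toolbox:

import sys
from pathlib import Path


def load_dataset_from_file(py_file: str):
    """Load a dataset the way anomaly-box.py does: my_dataset.py must
    define a top-level class named my_dataset, constructible with no
    arguments."""
    file_path = Path(py_file).absolute()
    # Make the module importable by its bare name.
    sys.path.append(str(file_path.parent))
    name = file_path.stem
    # Resolve the class named after the file stem and instantiate it.
    return getattr(__import__(name), name)()

A plausible invocation, with hypothetical experiment and file names: anomaly-box.py --experiment GANomaly --hps-path hparams.json --dataset my_dataset.py, or --run-all to execute every experiment listed in available_experiments.__ALL__.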
(diff for the remaining 23 files not shown)
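One more note on the click pattern used in the decorators above: the undashed string in each @click.option sets the Python parameter name, while the dashed one is the command-line flag, which is how ("hps_tuning", "--tuning") binds --tuning to the hps_tuning argument. A minimal sketch; the demo command is illustrative only:

import click


@click.command()
@click.option("hps_tuning", "--tuning", type=bool, default=False)
def demo(hps_tuning: bool) -> None:
    # The flag is spelled --tuning on the CLI, but the value arrives
    # in the hps_tuning parameter.
    click.echo(f"tuning enabled: {hps_tuning}")


if __name__ == "__main__":
    demo()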
