Skip to content

Commit

Permalink
Log benchmark name
Browse files Browse the repository at this point in the history
  • Loading branch information
m30m committed Feb 6, 2021
1 parent f3f4ccc commit 2b42109
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def is_trained_model_valid(self, test_acc):

def run(self):
print(f"Using device {self.device}")
benchmark_name = self.__class__.__name__
all_explanations = defaultdict(list)
all_runtimes = defaultdict(list)
for experiment_i in tq(range(self.sample_count)):
Expand Down Expand Up @@ -142,7 +143,7 @@ def time_wrapper(*args, **kwargs):

time_wrapper.explain_function = explain_function
accs = self.evaluate_explanation(time_wrapper, model, test_dataset, explain_name)
print(f'Run #{experiment_i + 1}, Explain Method: {explain_name}, Accuracy: {np.mean(accs)}')
print(f'Benchmark:{benchmark_name} Run #{experiment_i + 1}, Explain Method: {explain_name}, Accuracy: {np.mean(accs)}')
all_explanations[explain_name].append(list(accs))
all_runtimes[explain_name].extend(duration_samples)
metrics = {
Expand All @@ -154,7 +155,7 @@ def time_wrapper(*args, **kwargs):
json.dump(all_explanations, open(file_path, 'w'), indent=2)
mlflow.log_artifact(file_path)
mlflow.log_metrics(metrics, step=experiment_i)
print(f'Run #{experiment_i + 1} finished. Average Explanation Accuracies for each method:')
print(f'Benchmark:{benchmark_name} Run #{experiment_i + 1} finished. Average Explanation Accuracies for each method:')
accuracies_summary = {}
for name, run_accs in all_explanations.items():
run_accs = [np.mean(single_run_acc) for single_run_acc in run_accs]
Expand Down

0 comments on commit 2b42109

Please sign in to comment.