diff --git a/asreviewcontrib/makita/templates/script_get_plot.py.template b/asreviewcontrib/makita/templates/script_get_plot.py.template index c47b0111..52e884d9 100644 --- a/asreviewcontrib/makita/templates/script_get_plot.py.template +++ b/asreviewcontrib/makita/templates/script_get_plot.py.template @@ -20,7 +20,6 @@ Authors import argparse from pathlib import Path -import matplotlib.colors as mcolors import matplotlib.pyplot as plt from asreview import open_state @@ -30,31 +29,39 @@ from asreviewcontrib.insights.plot import plot_recall def get_plot_from_states(states, filename, legend=None): """Generate an ASReview plot from state files.""" + # sort the states alphabetically + states = sorted(states) + fig, ax = plt.subplots() labels = [] - colors = list(mcolors.TABLEAU_COLORS.values()) for state_file in states: with open_state(state_file) as state: # draw the plot plot_recall(ax, state) - # set the label + # settings for legend "filename" if legend == "filename": ax.lines[-2].set_label(state_file.stem) ax.legend(loc=4, prop={'size': 8}) + # settings for legend "settings" elif legend: metadata = state.settings_metadata + # settings for legend "model" if legend == "model": label = " - ".join( [metadata["settings"]["model"], metadata["settings"]["feature_extraction"], metadata["settings"]["balance_strategy"], metadata["settings"]["query_strategy"]]) + + # settings for legend "classifier" elif legend == "classifier": label = metadata["settings"]["model"] + + # settings for legend from metadata else: try: label = metadata["settings"][legend] @@ -63,12 +70,17 @@ def get_plot_from_states(states, filename, legend=None): f"Legend setting '{legend}' " "not found in state file settings." ) from exc + + # add label to legend if not already present + # (multiple states can have the same label) if label not in labels: ax.lines[-2].set_label(label) labels.append(label) - ax.lines[-2].set_color(colors[labels.index(label) % len(colors)]) - ax.legend(loc=4, prop={'size': 8}) + # add legend to plot + ax.legend(loc=4, prop={'size': 8}) + + # save plot fig.savefig(str(filename)) @@ -92,10 +104,10 @@ if __name__ == "__main__": args = parser.parse_args() # load states - states = Path(args.s).glob("*.asreview") + states = list(Path(args.s).glob("*.asreview")) # check if states are found - if len(list(states)) == 0: + if len(states) == 0: raise FileNotFoundError(f"No state files found in {args.s}") # generate plot and save results