diff --git a/seqeval/metrics/sequence_labeling.py b/seqeval/metrics/sequence_labeling.py
index ac56f18..e29d4d6 100644
--- a/seqeval/metrics/sequence_labeling.py
+++ b/seqeval/metrics/sequence_labeling.py
@@ -301,16 +301,17 @@ def performance_measure(y_true, y_pred):
     return performance_dict
 
 
-def classification_report(y_true, y_pred, digits=2, suffix=False):
+def classification_report(y_true, y_pred, digits=2, suffix=False, output_dict=False):
     """Build a text report showing the main classification metrics.
 
     Args:
         y_true : 2d array. Ground truth (correct) target values.
         y_pred : 2d array. Estimated targets as returned by a classifier.
         digits : int. Number of digits for formatting output floating point values.
+        output_dict : bool(default=False). If True, return output as a dict instead of a string.
 
     Returns:
-        report : string. Text summary of the precision, recall, F1 score for each class.
+        report : string/dict. Summary of the precision, recall, F1 score for each class.
 
     Examples:
         >>> from seqeval.metrics import classification_report
@@ -324,6 +325,7 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
 
           micro avg       0.50      0.50      0.50         2
          macro avg       0.50      0.50      0.50         2
+      weighted avg       0.50      0.50      0.50         2
 
     """
     true_entities = set(get_entities(y_true, suffix))
@@ -338,15 +340,19 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
     for e in pred_entities:
         d2[e[0]].add((e[1], e[2]))
 
-    last_line_heading = 'weighted avg'
-    width = max(name_width, len(last_line_heading), digits)
+    avg_types = ['micro avg', 'macro avg', 'weighted avg']
 
-    headers = ["precision", "recall", "f1-score", "support"]
-    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
-    report = head_fmt.format(u'', *headers, width=width)
-    report += u'\n\n'
+    if output_dict:
+        report_dict = dict()
+    else:
+        avg_width = max([len(x) for x in avg_types])
+        width = max(name_width, avg_width, digits)
+        headers = ["precision", "recall", "f1-score", "support"]
+        head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
+        report = head_fmt.format(u'', *headers, width=width)
+        report += u'\n\n'
 
-    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
+        row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
 
     ps, rs, f1s, s = [], [], [], []
     for type_name in sorted(d1.keys()):
@@ -360,33 +366,47 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
         r = nb_correct / nb_true if nb_true > 0 else 0
         f1 = 2 * p * r / (p + r) if p + r > 0 else 0
 
-        report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)
+        if output_dict:
+            report_dict[type_name] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
+        else:
+            report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)
 
         ps.append(p)
         rs.append(r)
         f1s.append(f1)
         s.append(nb_true)
 
-    report += u'\n'
+    if not output_dict:
+        report += u'\n'
 
     # compute averages
-    report += row_fmt.format('micro avg',
-                             precision_score(y_true, y_pred, suffix=suffix),
-                             recall_score(y_true, y_pred, suffix=suffix),
-                             f1_score(y_true, y_pred, suffix=suffix),
-                             np.sum(s),
-                             width=width, digits=digits)
-    report += row_fmt.format('macro avg',
-                             np.average(ps),
-                             np.average(rs),
-                             np.average(f1s),
-                             np.sum(s),
-                             width=width, digits=digits)
-    report += row_fmt.format(last_line_heading,
-                             np.average(ps, weights=s),
-                             np.average(rs, weights=s),
-                             np.average(f1s, weights=s),
-                             np.sum(s),
-                             width=width, digits=digits)
-
-    return report
+    nb_true = np.sum(s)
+
+    for avg_type in avg_types:
+        if avg_type == 'micro avg':
+            # micro average
+            p = precision_score(y_true, y_pred, suffix=suffix)
+            r = recall_score(y_true, y_pred, suffix=suffix)
+            f1 = f1_score(y_true, y_pred, suffix=suffix)
+        elif avg_type == 'macro avg':
+            # macro average
+            p = np.average(ps)
+            r = np.average(rs)
+            f1 = np.average(f1s)
+        elif avg_type == 'weighted avg':
+            # weighted average
+            p = np.average(ps, weights=s)
+            r = np.average(rs, weights=s)
+            f1 = np.average(f1s, weights=s)
+        else:
+            assert False, "unexpected average: {}".format(avg_type)
+
+        if output_dict:
+            report_dict[avg_type] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
+        else:
+            report += row_fmt.format(*[avg_type, p, r, f1, nb_true], width=width, digits=digits)
+
+    if output_dict:
+        return report_dict
+    else:
+        return report
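
For reference, a minimal usage sketch of the new flag. It assumes a seqeval build with this patch applied; the toy tag sequences are invented for illustration, while `output_dict` and the dict keys come straight from the patch above.

from seqeval.metrics import classification_report

y_true = [['O', 'B-PER', 'I-PER', 'O'], ['B-LOC', 'O']]
y_pred = [['O', 'B-PER', 'I-PER', 'O'], ['B-LOC', 'O']]

# Default behaviour is unchanged: a formatted text table.
print(classification_report(y_true, y_pred))

# With output_dict=True the same numbers come back as nested dicts,
# keyed by entity type plus 'micro avg', 'macro avg' and 'weighted avg'.
report = classification_report(y_true, y_pred, output_dict=True)
print(report['PER']['f1-score'])          # per-class metric
print(report['weighted avg']['support'])  # total number of true entities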
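
A quick numeric check of two of the averaging modes the new loop implements, in plain NumPy (the per-class numbers are made up): macro averaging weights every class equally, while weighted averaging scales each class by its support, so a frequent class dominates. The micro average is different in kind: it is not an average of the per-class list at all but is recomputed from global counts, which is why the loop calls precision_score/recall_score/f1_score directly for that branch.

import numpy as np

f1s = [0.90, 0.50]   # per-class F1, e.g. a frequent and a rare class
s = [90, 10]         # per-class support (number of true entities)

macro = np.average(f1s)                # (0.90 + 0.50) / 2        = 0.70
weighted = np.average(f1s, weights=s)  # (0.90*90 + 0.50*10)/100  = 0.86
print(macro, weighted)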