From 8638279386aa84af000ff269c7021982e17f8b74 Mon Sep 17 00:00:00 2001 From: Charlie Date: Wed, 22 Mar 2023 18:46:02 +0000 Subject: [PATCH] feat: add best-fit to toolkit --- parameter_results/toolkit.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/parameter_results/toolkit.py b/parameter_results/toolkit.py index 757e162e1..6b7d803f7 100644 --- a/parameter_results/toolkit.py +++ b/parameter_results/toolkit.py @@ -14,6 +14,7 @@ import os import json +import numpy as np import pandas as pd import plotly.express as px import ipywidgets as widgets @@ -202,6 +203,7 @@ def plot_results( self, variables: Optional[list] = None, iterations: Optional[list] = None, + best_fit: bool = False, ): """Plots a variables result for each iteration or averaged across iterations. @@ -257,6 +259,7 @@ def plot_results( xlabel=f"Time [{self.granularity.name}]", title=f"{keys[i][0]} // (iteration={keys[i][1]})", labels=labels, + best_fit=best_fit, ) def plot_comparison( @@ -644,6 +647,7 @@ def _add_plot( title: Optional[str] = None, labels: Optional[list] = None, labels_right: Optional[list] = None, + best_fit: bool = False, ): """Adds a complete plot to a given subplot axes. @@ -687,6 +691,7 @@ def _add_plot( parameters=parameters, iterations=iterations, labels=labels, + best_fit=best_fit, ) if variables_right is not None: @@ -698,6 +703,7 @@ def _add_plot( parameters=parameters, iterations=iterations, labels=labels_right, + best_fit=best_fit, ) lns = lns + lns_right @@ -723,6 +729,7 @@ def _add_data( parameters: Optional[list] = None, iterations: list = ["avg"], labels: Optional[list] = None, + best_fit: bool = False, ): """Adds a plot to a specific axes of a subplot. @@ -769,6 +776,16 @@ def _add_data( label = f"{labels[0][i]} {labels[1][j]} {labels[2][k]}" lns.append(ax.plot(xdata, ydata, fmt, label=label)) + if best_fit: + a, b = np.polyfit(xdata, ydata, deg=1) + print(label, a, b) + ax.plot( + xdata, + xdata * a + b, + fmt, + label=None, + ) + else: df = self.data_raw[parameter_index][ self.data_raw[parameter_index]["Iteration"] == iteration @@ -784,6 +801,16 @@ def _add_data( label = f"{labels[0][i]} {labels[1][j]} {labels[2][k]}" lns.append(ax.plot(xdata, ydata, fmt, label=label)) + if best_fit: + a, b = np.polyfit(xdata, ydata, deg=1) + print(label, a, b) + ax.plot( + xdata, + xdata * a + b, + fmt, + label=None, + ) + if "Market State" in variables: self._market_state_ticks(ax)