Skip to content

Commit

Permalink
feat: add best-fit to toolkit
Browse files Browse the repository at this point in the history
  • Loading branch information
cdummett committed Apr 4, 2023
1 parent c9e30bc commit 8638279
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions parameter_results/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import json

import numpy as np
import pandas as pd
import plotly.express as px
import ipywidgets as widgets
Expand Down Expand Up @@ -202,6 +203,7 @@ def plot_results(
self,
variables: Optional[list] = None,
iterations: Optional[list] = None,
best_fit: bool = False,
):
"""Plots a variables result for each iteration or averaged across iterations.
Expand Down Expand Up @@ -257,6 +259,7 @@ def plot_results(
xlabel=f"Time [{self.granularity.name}]",
title=f"{keys[i][0]} // (iteration={keys[i][1]})",
labels=labels,
best_fit=best_fit,
)

def plot_comparison(
Expand Down Expand Up @@ -644,6 +647,7 @@ def _add_plot(
title: Optional[str] = None,
labels: Optional[list] = None,
labels_right: Optional[list] = None,
best_fit: bool = False,
):
"""Adds a complete plot to a given subplot axes.
Expand Down Expand Up @@ -687,6 +691,7 @@ def _add_plot(
parameters=parameters,
iterations=iterations,
labels=labels,
best_fit=best_fit,
)

if variables_right is not None:
Expand All @@ -698,6 +703,7 @@ def _add_plot(
parameters=parameters,
iterations=iterations,
labels=labels_right,
best_fit=best_fit,
)
lns = lns + lns_right

Expand All @@ -723,6 +729,7 @@ def _add_data(
parameters: Optional[list] = None,
iterations: list = ["avg"],
labels: Optional[list] = None,
best_fit: bool = False,
):
"""Adds a plot to a specific axes of a subplot.
Expand Down Expand Up @@ -769,6 +776,16 @@ def _add_data(
label = f"{labels[0][i]} {labels[1][j]} {labels[2][k]}"
lns.append(ax.plot(xdata, ydata, fmt, label=label))

if best_fit:
a, b = np.polyfit(xdata, ydata, deg=1)
print(label, a, b)
ax.plot(
xdata,
xdata * a + b,
fmt,
label=None,
)

else:
df = self.data_raw[parameter_index][
self.data_raw[parameter_index]["Iteration"] == iteration
Expand All @@ -784,6 +801,16 @@ def _add_data(
label = f"{labels[0][i]} {labels[1][j]} {labels[2][k]}"
lns.append(ax.plot(xdata, ydata, fmt, label=label))

if best_fit:
a, b = np.polyfit(xdata, ydata, deg=1)
print(label, a, b)
ax.plot(
xdata,
xdata * a + b,
fmt,
label=None,
)

if "Market State" in variables:
self._market_state_ticks(ax)

Expand Down

0 comments on commit 8638279

Please sign in to comment.