Skip to content

Commit

Permalink
remove imported types and fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreKempf committed Feb 6, 2024
1 parent 1f4c362 commit 06234cf
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from .plots import PLOT_TYPES, SKLEARN_PLOTS, CustomPlot, Image, Metric, NumpyEncoder
from .report import BLANK_NOTEBOOK_REPORT, make_report
from .serialize import dump_json, dump_yaml, load_yaml
from .studio import StudioEventKind, get_dvc_studio_config, post_to_studio
from .studio import get_dvc_studio_config, post_to_studio
from .utils import (
StrPath,
catch_and_warn,
Expand All @@ -64,7 +64,7 @@
logger.addHandler(handler)

ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]]
SkleanPlotKind = [*SKLEARN_PLOTS.keys()]
SklearnPlotKind = [*SKLEARN_PLOTS.keys()]


class Live:
Expand Down Expand Up @@ -443,7 +443,7 @@ def log_plot(

def log_sklearn_plot(
self,
kind: SkleanPlotKind,
kind: SklearnPlotKind,
labels: Union[List, np.ndarray],
predictions: Union[List, Tuple, np.ndarray],
name: Optional[str] = None,
Expand Down Expand Up @@ -588,7 +588,7 @@ def make_dvcyaml(self):
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event: StudioEventKind):
def post_to_studio(self, event: str):
post_to_studio(self, event)

def end(self):
Expand Down
3 changes: 0 additions & 3 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@

from dvc_studio_client.config import get_studio_config
from dvc_studio_client.post_live_metrics import post_live_metrics
from dvc_studio_client.schema import BASE_SCHEMA

from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics, rel_path

logger = logging.getLogger("dvclive")

StudioEventKind = [*BASE_SCHEMA.schema["type"].validators]


def _get_unsent_datapoints(plot, latest_step):
return [x for x in plot if int(x["step"]) > latest_step]
Expand Down

0 comments on commit 06234cf

Please sign in to comment.