Skip to content

Commit

Permalink
fix(studio): package data to send in main thread (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Jan 29, 2025
1 parent 40e4b4e commit bf073b9
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 45 deletions.
31 changes: 27 additions & 4 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
inside_notebook,
matplotlib_installed,
open_file_in_browser,
parse_metrics,
)
from .vscode import (
cleanup_dvclive_step_completed,
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
self._save_dvc_exp: bool = save_dvc_exp
self._step: Optional[int] = None
self._metrics: Dict[str, Any] = {}
self._images: Dict[str, Any] = {}
self._images: Dict[str, Image] = {}
self._params: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._artifacts: Dict[str, Dict] = {}
Expand Down Expand Up @@ -901,19 +902,41 @@ def make_dvcyaml(self):
"""
make_dvcyaml(self)

def _get_live_data(self) -> Optional[dict[str, Any]]:
    """Collect the current params, metrics, plots, images and step.

    The result is the payload that ``post_data_to_studio`` puts on the
    studio queue together with ``self``, so the snapshot is built by the
    caller of ``post_data_to_studio`` rather than by the queue worker.

    Returns:
        A dict with the keys ``"params"``, ``"plots"``,
        ``"plots_start_idx"``, ``"metrics"``, ``"images"`` and ``"step"``.
    """
    if os.path.isfile(self.params_file):
        params = load_yaml(self.params_file)
    else:
        params = {}
    plots, metrics = parse_metrics(self)

    # Plots can grow large, so only the tail that has not been recorded as
    # sent is kept in the payload. The offset of that tail is stored next to
    # it in "plots_start_idx" so the slice can be mapped back to positions
    # in the full series.
    sent_counts = self._num_points_sent_to_studio
    plots_start_idx = {name: sent_counts.get(name, 0) for name in plots}
    plots_to_send = {
        name: points[plots_start_idx[name]:] for name, points in plots.items()
    }

    return {
        "params": params,
        "plots": plots_to_send,
        "plots_start_idx": plots_start_idx,
        "metrics": metrics,
        # Copy the image objects into a list so the payload does not hold
        # a live view of self._images.
        "images": list(self._images.values()),
        "step": self.step,
    }

def post_data_to_studio(self):
if not self._studio_queue:
self._studio_queue = queue.Queue()

def worker():
while True:
item = self._studio_queue.get()
post_to_studio(item, "data")
item, data = self._studio_queue.get()
post_to_studio(item, "data", data)
self._studio_queue.task_done()

threading.Thread(target=worker, daemon=True).start()

self._studio_queue.put(self)
self._studio_queue.put((self, self._get_live_data()))

def _wait_for_studio_updates_posted(self):
if self._studio_queue:
Expand Down
57 changes: 34 additions & 23 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import os
from pathlib import PureWindowsPath
from typing import TYPE_CHECKING, Literal, Mapping
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional

from dvc.exceptions import DvcException
from dvc_studio_client.config import get_studio_config
Expand All @@ -14,9 +14,9 @@
from .utils import catch_and_warn

if TYPE_CHECKING:
from dvclive.plots.image import Image
from dvclive.live import Live
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics, rel_path, StrPath
from dvclive.utils import rel_path, StrPath

logger = logging.getLogger("dvclive")

Expand Down Expand Up @@ -50,23 +50,24 @@ def _adapt_image(image_path: StrPath):
return base64.b64encode(fobj.read()).decode("utf-8")


def _adapt_images(live: Live):
def _adapt_images(live: Live, images: list[Image]):
return {
_adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)}
for image in live._images.values()
for image in images
if image.step > live._latest_studio_step
}


def get_studio_updates(live: Live):
if os.path.isfile(live.params_file):
params_file = live.params_file
params_file = _adapt_path(live, params_file)
params = {params_file: load_yaml(live.params_file)}
else:
params = {}
def _get_studio_updates(live: Live, data: dict[str, Any]):
params = data["params"]
plots = data["plots"]
plots_start_idx = data["plots_start_idx"]
metrics = data["metrics"]
images = data["images"]

plots, metrics = parse_metrics(live)
params_file = live.params_file
params_file = _adapt_path(live, params_file)
params = {params_file: params}

metrics_file = live.metrics_file
metrics_file = _adapt_path(live, metrics_file)
Expand All @@ -75,11 +76,12 @@ def get_studio_updates(live: Live):
plots_to_send = {}
for name, plot in plots.items():
path = _adapt_path(live, name)
num_points_sent = live._num_points_sent_to_studio.get(path, 0)
plots_to_send[path] = _cast_to_numbers(plot[num_points_sent:])
start_idx = plots_start_idx.get(name, 0)
num_points_sent = live._num_points_sent_to_studio.get(name, 0)
plots_to_send[path] = _cast_to_numbers(plot[num_points_sent - start_idx :])

plots_to_send = {k: {"data": v} for k, v in plots_to_send.items()}
plots_to_send.update(_adapt_images(live))
plots_to_send.update(_adapt_images(live, images))

return metrics, params, plots_to_send

Expand All @@ -91,16 +93,22 @@ def get_dvc_studio_config(live: Live):
return get_studio_config(dvc_studio_config=config)


def increment_num_points_sent_to_studio(live, plots):
for name, plot in plots.items():
def increment_num_points_sent_to_studio(live, plots_sent, data):
for name, _ in data["plots"].items():
path = _adapt_path(live, name)
plot = plots_sent.get(path, {})
if "data" in plot:
num_points_sent = live._num_points_sent_to_studio.get(name, 0)
live._num_points_sent_to_studio[name] = num_points_sent + len(plot["data"])
return live


@catch_and_warn(DvcException, logger)
def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901
def post_to_studio( # noqa: C901
live: Live,
event: Literal["start", "data", "done"],
data: Optional[dict[str, Any]] = None,
):
if event in live._studio_events_to_skip:
return

Expand All @@ -111,8 +119,9 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
if subdir := live._subdir:
kwargs["subdir"] = subdir
elif event == "data":
metrics, params, plots = get_studio_updates(live)
kwargs["step"] = live.step # type: ignore
assert data is not None # noqa: S101
metrics, params, plots = _get_studio_updates(live, data)
kwargs["step"] = data["step"] # type: ignore
kwargs["metrics"] = metrics
kwargs["params"] = params
kwargs["plots"] = plots
Expand All @@ -128,15 +137,17 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
studio_repo_url=live._repo_url,
**kwargs, # type: ignore
)

if not response:
logger.warning(f"`post_to_studio` `{event}` failed.")
if event == "start":
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("data")
live._studio_events_to_skip.add("done")
elif event == "data":
live = increment_num_points_sent_to_studio(live, plots)
live._latest_studio_step = live.step
assert data is not None # noqa: S101
live = increment_num_points_sent_to_studio(live, plots, data)
live._latest_studio_step = data["step"]

if event == "done":
live._studio_events_to_skip.add("done")
Expand Down
Loading

0 comments on commit bf073b9

Please sign in to comment.