Skip to content

Commit

Permalink
fix(studio): package data to send in main thread
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Jan 25, 2025
1 parent 4b9c726 commit d34234a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 40 deletions.
21 changes: 17 additions & 4 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
inside_notebook,
matplotlib_installed,
open_file_in_browser,
parse_metrics,
)
from .vscode import (
cleanup_dvclive_step_completed,
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
self._save_dvc_exp: bool = save_dvc_exp
self._step: Optional[int] = None
self._metrics: Dict[str, Any] = {}
self._images: Dict[str, Any] = {}
self._images: Dict[str, Image] = {}
self._params: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._artifacts: Dict[str, Dict] = {}
Expand Down Expand Up @@ -901,19 +902,31 @@ def make_dvcyaml(self):
"""
make_dvcyaml(self)

def _get_live_data(self) -> Optional[dict[str, Any]]:
    """Collect the current params, plots, metrics, images and step.

    Params are read from ``self.params_file`` when that file exists,
    otherwise an empty dict is used. Plots and metrics come from
    ``parse_metrics(self)``. Images are returned as a new list built from
    the values of ``self._images``.

    Returns:
        A dict with the keys ``"params"``, ``"plots"``, ``"metrics"``,
        ``"images"`` and ``"step"``.
    """
    if os.path.isfile(self.params_file):
        params = load_yaml(self.params_file)
    else:
        params = {}

    plots, metrics = parse_metrics(self)
    # Copy the values into a list so the result does not alias the dict view.
    images = list(self._images.values())

    return {
        "params": params,
        "plots": plots,
        "metrics": metrics,
        "images": images,
        "step": self.step,
    }

def post_data_to_studio(self):
    """Queue the current live data to be posted to Studio by a worker thread.

    On the first call this creates ``self._studio_queue`` and starts a
    daemon thread that consumes ``(live, data)`` pairs from it and calls
    ``post_to_studio(live, "data", data)`` for each one.

    The payload is built by ``self._get_live_data()`` in the calling
    thread, before the item is enqueued, so the worker posts the data as
    it was at enqueue time.
    """
    if not self._studio_queue:
        self._studio_queue = queue.Queue()

        def worker():
            while True:
                item, data = self._studio_queue.get()
                try:
                    post_to_studio(item, "data", data)
                finally:
                    # Always balance the get() above so that a join() on
                    # the queue is not left waiting on this item if the
                    # post raises.
                    self._studio_queue.task_done()

        threading.Thread(target=worker, daemon=True).start()

    self._studio_queue.put((self, self._get_live_data()))

def _wait_for_studio_updates_posted(self):
if self._studio_queue:
Expand Down
43 changes: 25 additions & 18 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import os
from pathlib import PureWindowsPath
from typing import TYPE_CHECKING, Literal, Mapping
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional

from dvc.exceptions import DvcException
from dvc_studio_client.config import get_studio_config
Expand All @@ -14,9 +14,9 @@
from .utils import catch_and_warn

if TYPE_CHECKING:
from dvclive.plots.image import Image
from dvclive.live import Live
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics, rel_path, StrPath
from dvclive.utils import rel_path, StrPath

logger = logging.getLogger("dvclive")

Expand Down Expand Up @@ -50,23 +50,23 @@ def _adapt_image(image_path: StrPath):
return base64.b64encode(fobj.read()).decode("utf-8")


def _adapt_images(live: Live, images: list[Image]):
    """Build the Studio plots payload for images not yet sent.

    Args:
        live: The ``Live`` instance; used to adapt output paths and to read
            ``_latest_studio_step``.
        images: The images to consider, passed in explicitly rather than
            read from ``live._images``.

    Returns:
        A dict mapping each adapted image path to ``{"image": <encoded>}``,
        where the encoded value is produced by ``_adapt_image``. Only images
        whose ``step`` is greater than ``live._latest_studio_step`` are
        included.
    """
    return {
        _adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)}
        for image in images
        if image.step > live._latest_studio_step
    }


def get_studio_updates(live: Live):
if os.path.isfile(live.params_file):
params_file = live.params_file
params_file = _adapt_path(live, params_file)
params = {params_file: load_yaml(live.params_file)}
else:
params = {}
def _get_studio_updates(live: Live, data: dict[str, Any]):
params = data["params"]
plots = data["plots"]
metrics = data["metrics"]
images = data["images"]

plots, metrics = parse_metrics(live)
params_file = live.params_file
params_file = _adapt_path(live, params_file)
params = {params_file: params}

metrics_file = live.metrics_file
metrics_file = _adapt_path(live, metrics_file)
Expand All @@ -79,7 +79,7 @@ def get_studio_updates(live: Live):
plots_to_send[path] = _cast_to_numbers(plot[num_points_sent:])

plots_to_send = {k: {"data": v} for k, v in plots_to_send.items()}
plots_to_send.update(_adapt_images(live))
plots_to_send.update(_adapt_images(live, images))

return metrics, params, plots_to_send

Expand All @@ -100,7 +100,11 @@ def increment_num_points_sent_to_studio(live, plots):


@catch_and_warn(DvcException, logger)
def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901
def post_to_studio( # noqa: C901
live: Live,
event: Literal["start", "data", "done"],
data: Optional[dict[str, Any]] = None,
):
if event in live._studio_events_to_skip:
return

Expand All @@ -111,8 +115,9 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
if subdir := live._subdir:
kwargs["subdir"] = subdir
elif event == "data":
metrics, params, plots = get_studio_updates(live)
kwargs["step"] = live.step # type: ignore
assert data is not None # noqa: S101
metrics, params, plots = _get_studio_updates(live, data)
kwargs["step"] = data["step"] # type: ignore
kwargs["metrics"] = metrics
kwargs["params"] = params
kwargs["plots"] = plots
Expand All @@ -128,15 +133,17 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
studio_repo_url=live._repo_url,
**kwargs, # type: ignore
)

if not response:
logger.warning(f"`post_to_studio` `{event}` failed.")
if event == "start":
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("data")
live._studio_events_to_skip.add("done")
elif event == "data":
assert data is not None # noqa: S101
live = increment_num_points_sent_to_studio(live, plots)
live._latest_studio_step = live.step
live._latest_studio_step = data["step"]

if event == "done":
live._studio_events_to_skip.add("done")
Expand Down
54 changes: 36 additions & 18 deletions tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -64,7 +65,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live.step += 1
live.log_metric("foo", 2)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -78,7 +80,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):

mocked_post.reset_mock()
live.save_dvc_exp()
post_to_studio(live, "done")
data = live._get_live_data()
post_to_studio(live, "done", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down Expand Up @@ -126,13 +129,15 @@ def test_post_to_studio_failed_data_request(
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post = mocker.patch("requests.post", return_value=valid_response)
live.step += 1
live.log_metric("foo", 2)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand Down Expand Up @@ -247,7 +252,8 @@ def test_post_to_studio_dvc_studio_config(
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token"

Expand All @@ -270,7 +276,8 @@ def test_post_to_studio_skip_if_no_token(
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

assert mocked_post.call_count == 0

Expand All @@ -281,7 +288,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po
live = Live()
live.log_metric("eval/loss", 1)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

plots_path = Path(live.plots_dir)
loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix()
Expand Down Expand Up @@ -311,7 +319,8 @@ def test_post_to_studio_inside_dvc_exp(
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
assert "start" not in call_types
Expand All @@ -330,7 +339,8 @@ def test_post_to_studio_inside_subdir(
live = Live()
live.log_metric("foo", 1)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()

Expand Down Expand Up @@ -361,7 +371,8 @@ def test_post_to_studio_inside_subdir_dvc_exp(
live = Live()
live.log_metric("foo", 1)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()

Expand Down Expand Up @@ -416,7 +427,8 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0)))
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix()

Expand Down Expand Up @@ -461,7 +473,8 @@ def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post, _ = mocked_studio_post
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
Expand All @@ -488,7 +501,8 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):

live.log_metric("foo", 1)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -504,7 +518,8 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
live.step += 1
live.log_metric("foo", 2)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down Expand Up @@ -538,7 +553,8 @@ def test_post_to_studio_skip_if_no_repo_url(
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

assert mocked_post.call_count == 0

Expand All @@ -557,7 +573,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
live.log_metric("foo", 1)
live.log_metric("bar", 0.1)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -576,7 +593,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
live.log_metric("foo", 3)
live.log_metric("bar", 0.2)
live.make_summary()
post_to_studio(live, "data")
data = live._get_live_data()
post_to_studio(live, "data", data)

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down

0 comments on commit d34234a

Please sign in to comment.