From bf073b9ceb5a0a61f271c832abaa858b76485735 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Wed, 29 Jan 2025 09:51:24 -0800 Subject: [PATCH] fix(studio): package data to send in main thread (#860) --- src/dvclive/live.py | 31 ++++++++-- src/dvclive/studio.py | 57 +++++++++++------- tests/test_post_to_studio.py | 112 +++++++++++++++++++++++++++++------ 3 files changed, 155 insertions(+), 45 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 754137f..d73be31 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -55,6 +55,7 @@ inside_notebook, matplotlib_installed, open_file_in_browser, + parse_metrics, ) from .vscode import ( cleanup_dvclive_step_completed, @@ -135,7 +136,7 @@ def __init__( self._save_dvc_exp: bool = save_dvc_exp self._step: Optional[int] = None self._metrics: Dict[str, Any] = {} - self._images: Dict[str, Any] = {} + self._images: Dict[str, Image] = {} self._params: Dict[str, Any] = {} self._plots: Dict[str, Any] = {} self._artifacts: Dict[str, Dict] = {} @@ -901,19 +902,41 @@ def make_dvcyaml(self): """ make_dvcyaml(self) + def _get_live_data(self) -> Optional[dict[str, Any]]: + params = load_yaml(self.params_file) if os.path.isfile(self.params_file) else {} + plots, metrics = parse_metrics(self) + + # Plots can grow large, we don't want to keep in memory data + # that we 100% sent already + plots_to_send = {} + plots_start_idx = {} + for name, plot in plots.items(): + num_points_sent = self._num_points_sent_to_studio.get(name, 0) + plots_to_send[name] = plot[num_points_sent:] + plots_start_idx[name] = num_points_sent + + return { + "params": params, + "plots": plots_to_send, + "plots_start_idx": plots_start_idx, + "metrics": metrics, + "images": list(self._images.values()), + "step": self.step, + } + def post_data_to_studio(self): if not self._studio_queue: self._studio_queue = queue.Queue() def worker(): while True: - item = self._studio_queue.get() - post_to_studio(item, "data") + item, data = 
self._studio_queue.get() + post_to_studio(item, "data", data) self._studio_queue.task_done() threading.Thread(target=worker, daemon=True).start() - self._studio_queue.put(self) + self._studio_queue.put((self, self._get_live_data())) def _wait_for_studio_updates_posted(self): if self._studio_queue: diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 6ed398e..c940ee6 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -5,7 +5,7 @@ import math import os from pathlib import PureWindowsPath -from typing import TYPE_CHECKING, Literal, Mapping +from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional from dvc.exceptions import DvcException from dvc_studio_client.config import get_studio_config @@ -14,9 +14,9 @@ from .utils import catch_and_warn if TYPE_CHECKING: + from dvclive.plots.image import Image from dvclive.live import Live -from dvclive.serialize import load_yaml -from dvclive.utils import parse_metrics, rel_path, StrPath +from dvclive.utils import rel_path, StrPath logger = logging.getLogger("dvclive") @@ -50,23 +50,24 @@ def _adapt_image(image_path: StrPath): return base64.b64encode(fobj.read()).decode("utf-8") -def _adapt_images(live: Live): +def _adapt_images(live: Live, images: list[Image]): return { _adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)} - for image in live._images.values() + for image in images if image.step > live._latest_studio_step } -def get_studio_updates(live: Live): - if os.path.isfile(live.params_file): - params_file = live.params_file - params_file = _adapt_path(live, params_file) - params = {params_file: load_yaml(live.params_file)} - else: - params = {} +def _get_studio_updates(live: Live, data: dict[str, Any]): + params = data["params"] + plots = data["plots"] + plots_start_idx = data["plots_start_idx"] + metrics = data["metrics"] + images = data["images"] - plots, metrics = parse_metrics(live) + params_file = live.params_file + params_file = _adapt_path(live, 
params_file) + params = {params_file: params} metrics_file = live.metrics_file metrics_file = _adapt_path(live, metrics_file) @@ -75,11 +76,12 @@ def get_studio_updates(live: Live): plots_to_send = {} for name, plot in plots.items(): path = _adapt_path(live, name) - num_points_sent = live._num_points_sent_to_studio.get(path, 0) - plots_to_send[path] = _cast_to_numbers(plot[num_points_sent:]) + start_idx = plots_start_idx.get(name, 0) + num_points_sent = live._num_points_sent_to_studio.get(name, 0) + plots_to_send[path] = _cast_to_numbers(plot[num_points_sent - start_idx :]) plots_to_send = {k: {"data": v} for k, v in plots_to_send.items()} - plots_to_send.update(_adapt_images(live)) + plots_to_send.update(_adapt_images(live, images)) return metrics, params, plots_to_send @@ -91,8 +93,10 @@ def get_dvc_studio_config(live: Live): return get_studio_config(dvc_studio_config=config) -def increment_num_points_sent_to_studio(live, plots): - for name, plot in plots.items(): +def increment_num_points_sent_to_studio(live, plots_sent, data): + for name, _ in data["plots"].items(): + path = _adapt_path(live, name) + plot = plots_sent.get(path, {}) if "data" in plot: num_points_sent = live._num_points_sent_to_studio.get(name, 0) live._num_points_sent_to_studio[name] = num_points_sent + len(plot["data"]) @@ -100,7 +104,11 @@ def increment_num_points_sent_to_studio(live, plots): @catch_and_warn(DvcException, logger) -def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901 +def post_to_studio( # noqa: C901 + live: Live, + event: Literal["start", "data", "done"], + data: Optional[dict[str, Any]] = None, +): if event in live._studio_events_to_skip: return @@ -111,8 +119,9 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa if subdir := live._subdir: kwargs["subdir"] = subdir elif event == "data": - metrics, params, plots = get_studio_updates(live) - kwargs["step"] = live.step # type: ignore + assert data is not None # 
noqa: S101 + metrics, params, plots = _get_studio_updates(live, data) + kwargs["step"] = data["step"] # type: ignore kwargs["metrics"] = metrics kwargs["params"] = params kwargs["plots"] = plots @@ -128,6 +137,7 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa studio_repo_url=live._repo_url, **kwargs, # type: ignore ) + if not response: logger.warning(f"`post_to_studio` `{event}` failed.") if event == "start": @@ -135,8 +145,9 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa live._studio_events_to_skip.add("data") live._studio_events_to_skip.add("done") elif event == "data": - live = increment_num_points_sent_to_studio(live, plots) - live._latest_studio_step = live.step + assert data is not None # noqa: S101 + live = increment_num_points_sent_to_studio(live, plots, data) + live._latest_studio_step = data["step"] if event == "done": live._studio_events_to_skip.add("done") diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 3f585eb..0c6d119 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -1,4 +1,7 @@ +from collections import defaultdict +from copy import deepcopy from pathlib import Path +import unittest import pytest import time @@ -49,7 +52,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -64,7 +68,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): live.step += 1 live.log_metric("foo", 2) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -78,7 +83,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): 
mocked_post.reset_mock() live.save_dvc_exp() - post_to_studio(live, "done") + data = live._get_live_data() + post_to_studio(live, "done", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -126,13 +132,15 @@ def test_post_to_studio_failed_data_request( live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post = mocker.patch("requests.post", return_value=valid_response) live.step += 1 live.log_metric("foo", 2) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", **get_studio_call( @@ -187,6 +195,61 @@ def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_p assert expected_done_calls == actual_done_calls +def test_post_to_studio_snapshots_data_to_send( + tmp_dir, mocked_dvc_repo, mocked_studio_post +): + # Tests race condition between main app thread and Studio post thread + # where the main thread can be faster in producing metrics than the + # Studio post thread in sending them. 
+ mocked_post, _ = mocked_studio_post + + calls = defaultdict(dict) + + def _long_post(*_, **kwargs): + if kwargs["json"]["type"] == "data": + # Mock by default doesn't copy lists, dict, we share "body" var in + # some calls, thus we can't rely on `mocked_post.call_args_list` + json = deepcopy(kwargs)["json"] + step = json["step"] + for key in ["metrics", "params", "plots"]: + if key in json: + calls[step][key] = json[key] + time.sleep(0.1) + return unittest.mock.DEFAULT + + mocked_post.side_effect = lambda *args, **kwargs: _long_post(*args, **kwargs) + + live = Live() + for i in range(10): + live.log_metric("foo", i) + live.log_param(f"fooparam-{i}", i) + live.log_image(f"foo.{i}.png", ImagePIL.new("RGB", (i + 1, i + 1), (0, 0, 0))) + live.next_step() + + live._wait_for_studio_updates_posted() + + assert len(calls) == 10 + for i in range(10): + call = calls[i] + assert call["metrics"] == { + "dvclive/metrics.json": {"data": {"foo": i, "step": i}} + } + assert call["params"] == { + "dvclive/params.yaml": {f"fooparam-{k}": k for k in range(i + 1)} + } + # Check below that `plots` has the following shape + # { + # 'dvclive/plots/metrics/foo.tsv': {'data': [{'step': i, 'foo': float(i)}]}, + # f"dvclive/plots/images/foo.{i}.png": {'image': '...'} + # } + assert len(call["plots"]) == 2 + foo_data = call["plots"]["dvclive/plots/metrics/foo.tsv"]["data"] + assert len(foo_data) == 1 + assert foo_data[0]["step"] == i + assert foo_data[0]["foo"] == pytest.approx(float(i)) + assert call["plots"][f"dvclive/plots/images/foo.{i}.png"]["image"] + + def test_studio_updates_posted_on_end(tmp_path, mocked_dvc_repo, mocked_studio_post): mocked_post, valid_response = mocked_studio_post metrics_file = tmp_path / "metrics.json" @@ -247,7 +310,8 @@ def test_post_to_studio_dvc_studio_config( live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) assert 
mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token" @@ -270,7 +334,8 @@ def test_post_to_studio_skip_if_no_token( live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) assert mocked_post.call_count == 0 @@ -281,7 +346,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po live = Live() live.log_metric("eval/loss", 1) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) plots_path = Path(live.plots_dir) loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix() @@ -311,7 +377,8 @@ def test_post_to_studio_inside_dvc_exp( live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list] assert "start" not in call_types @@ -330,7 +397,8 @@ def test_post_to_studio_inside_subdir( live = Live() live.log_metric("foo", 1) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() @@ -361,7 +429,8 @@ def test_post_to_studio_inside_subdir_dvc_exp( live = Live() live.log_metric("foo", 1) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() @@ -416,7 +485,8 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) foo_path = (Path(live.plots_dir) / Image.subfolder / 
"foo.png").as_posix() @@ -461,7 +531,8 @@ def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_ live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post, _ = mocked_studio_post call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list] @@ -488,7 +559,8 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): live.log_metric("foo", 1) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -504,7 +576,8 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): live.step += 1 live.log_metric("foo", 2) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -538,7 +611,8 @@ def test_post_to_studio_skip_if_no_repo_url( live.log_metric("foo", 1) live.step = 0 live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) assert mocked_post.call_count == 0 @@ -557,7 +631,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post live.log_metric("foo", 1) live.log_metric("bar", 0.1) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live", @@ -576,7 +651,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post live.log_metric("foo", 3) live.log_metric("bar", 0.2) live.make_summary() - post_to_studio(live, "data") + data = live._get_live_data() + post_to_studio(live, "data", data) mocked_post.assert_called_with( "https://0.0.0.0/api/live",