From 3949008081f70738fe578fac10011a6f36e10311 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Thu, 22 Aug 2024 10:12:44 +0200 Subject: [PATCH] Fix runpaths bug in evaluate ensemble --- src/ert/run_models/evaluate_ensemble.py | 15 ++++--- .../run_models/test_model_factory.py | 40 ++++++++++++++++++- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/src/ert/run_models/evaluate_ensemble.py b/src/ert/run_models/evaluate_ensemble.py index 05fad857bb9..645944d87ce 100644 --- a/src/ert/run_models/evaluate_ensemble.py +++ b/src/ert/run_models/evaluate_ensemble.py @@ -8,7 +8,7 @@ from ert.ensemble_evaluator import EvaluatorServerConfig from ert.run_context import RunContext from ert.run_models.run_arguments import EvaluateEnsembleRunArguments -from ert.storage import Ensemble, Storage +from ert.storage import Storage from . import BaseRunModel @@ -39,26 +39,25 @@ def __init__( queue_config: QueueConfig, status_queue: SimpleQueue[StatusEvents], ): + ensemble_uuid = UUID(simulation_arguments.ensemble_id) + ensemble = storage.get_ensemble(ensemble_uuid) + self.ensemble = ensemble super().__init__( config, storage, queue_config, status_queue, + number_of_iterations=1, + start_iteration=ensemble.iteration, active_realizations=simulation_arguments.active_realizations, minimum_required_realizations=simulation_arguments.minimum_required_realizations, ) - self.ensemble_id = simulation_arguments.ensemble_id def run_experiment( self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False ) -> RunContext: self.setPhaseName("Running evaluate experiment...") - - ensemble_id = self.ensemble_id - ensemble_uuid = UUID(ensemble_id) - ensemble = self._storage.get_ensemble(ensemble_uuid) - assert isinstance(ensemble, Ensemble) - + ensemble = self.ensemble experiment = ensemble.experiment self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id)) self.set_env_key("_ERT_ENSEMBLE_ID", str(ensemble.id)) diff --git a/tests/unit_tests/run_models/test_model_factory.py b/tests/unit_tests/run_models/test_model_factory.py index f1217d7d419..cdb43a4ed01 100644 --- a/tests/unit_tests/run_models/test_model_factory.py +++ b/tests/unit_tests/run_models/test_model_factory.py @@ -5,7 +5,7 @@ import pytest import ert -from ert.config import ErtConfig +from ert.config import ErtConfig, ModelConfig from ert.libres_facade import LibresFacade from ert.run_models import ( EnsembleExperiment, @@ -15,6 +15,8 @@ SingleTestRun, model_factory, ) +from ert.run_models.evaluate_ensemble import EvaluateEnsemble +from ert.run_models.run_arguments import EvaluateEnsembleRunArguments @pytest.mark.parametrize( @@ -200,4 +202,38 @@ def test_multiple_data_assimilation_restart_paths( ) base_path = tmp_path / "simulations" expected_path = [str(base_path / expected) for expected in expected_path] - assert model.paths == expected_path + assert set(model.paths) == set(expected_path) + + +@pytest.mark.parametrize( + "ensemble_iteration, expected_path", + [ + [0, ["realization-0/iter-0"]], + [1, ["realization-0/iter-1"]], + [2, ["realization-0/iter-2"]], + [100, ["realization-0/iter-100"]], + ], +) +def test_evaluate_ensemble_paths( + tmp_path, monkeypatch, ensemble_iteration, expected_path +): + monkeypatch.chdir(tmp_path) + monkeypatch.setattr( + ert.run_models.base_run_model.BaseRunModel, "validate", MagicMock() + ) + run_args = EvaluateEnsembleRunArguments( + active_realizations=[True], + minimum_required_realizations=1, + ensemble_id=str(uuid1(0)), + random_seed=1234, + ensemble_size=1, + ) + storage_mock = MagicMock() + ensemble_mock = MagicMock() + ensemble_mock.iteration = ensemble_iteration + config = ErtConfig(model_config=ModelConfig(num_realizations=1)) + storage_mock.get_ensemble.return_value = ensemble_mock + model = EvaluateEnsemble(run_args, config, storage_mock, MagicMock(), MagicMock()) + base_path = tmp_path / "simulations" + expected_path = [str(base_path / expected) for expected in expected_path] + assert set(model.paths) == set(expected_path)