Skip to content

Commit

Permalink
Fix runpaths bug in evaluate ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Aug 22, 2024
1 parent ac2d917 commit 3949008
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
15 changes: 7 additions & 8 deletions src/ert/run_models/evaluate_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_context import RunContext
from ert.run_models.run_arguments import EvaluateEnsembleRunArguments
from ert.storage import Ensemble, Storage
from ert.storage import Storage

from . import BaseRunModel

Expand Down Expand Up @@ -39,26 +39,25 @@ def __init__(
queue_config: QueueConfig,
status_queue: SimpleQueue[StatusEvents],
):
ensemble_uuid = UUID(simulation_arguments.ensemble_id)
ensemble = storage.get_ensemble(ensemble_uuid)
self.ensemble = ensemble
super().__init__(
config,
storage,
queue_config,
status_queue,
number_of_iterations=1,
start_iteration=ensemble.iteration,
active_realizations=simulation_arguments.active_realizations,
minimum_required_realizations=simulation_arguments.minimum_required_realizations,
)
self.ensemble_id = simulation_arguments.ensemble_id

def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
) -> RunContext:
self.setPhaseName("Running evaluate experiment...")

ensemble_id = self.ensemble_id
ensemble_uuid = UUID(ensemble_id)
ensemble = self._storage.get_ensemble(ensemble_uuid)
assert isinstance(ensemble, Ensemble)

ensemble = self.ensemble
experiment = ensemble.experiment
self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(ensemble.id))
Expand Down
40 changes: 38 additions & 2 deletions tests/unit_tests/run_models/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import ert
from ert.config import ErtConfig
from ert.config import ErtConfig, ModelConfig
from ert.libres_facade import LibresFacade
from ert.run_models import (
EnsembleExperiment,
Expand All @@ -15,6 +15,8 @@
SingleTestRun,
model_factory,
)
from ert.run_models.evaluate_ensemble import EvaluateEnsemble
from ert.run_models.run_arguments import EvaluateEnsembleRunArguments


@pytest.mark.parametrize(
Expand Down Expand Up @@ -200,4 +202,38 @@ def test_multiple_data_assimilation_restart_paths(
)
base_path = tmp_path / "simulations"
expected_path = [str(base_path / expected) for expected in expected_path]
assert model.paths == expected_path
assert set(model.paths) == set(expected_path)


@pytest.mark.parametrize(
"ensemble_iteration, expected_path",
[
[0, ["realization-0/iter-0"]],
[1, ["realization-0/iter-1"]],
[2, ["realization-0/iter-2"]],
[100, ["realization-0/iter-100"]],
],
)
def test_evaluate_ensemble_paths(
tmp_path, monkeypatch, ensemble_iteration, expected_path
):
monkeypatch.chdir(tmp_path)
monkeypatch.setattr(
ert.run_models.base_run_model.BaseRunModel, "validate", MagicMock()
)
run_args = EvaluateEnsembleRunArguments(
active_realizations=[True],
minimum_required_realizations=1,
ensemble_id=str(uuid1(0)),
random_seed=1234,
ensemble_size=1,
)
storage_mock = MagicMock()
ensemble_mock = MagicMock()
ensemble_mock.iteration = ensemble_iteration
config = ErtConfig(model_config=ModelConfig(num_realizations=1))
storage_mock.get_ensemble.return_value = ensemble_mock
model = EvaluateEnsemble(run_args, config, storage_mock, MagicMock(), MagicMock())
base_path = tmp_path / "simulations"
expected_path = [str(base_path / expected) for expected in expected_path]
assert set(model.paths) == set(expected_path)

0 comments on commit 3949008

Please sign in to comment.