Skip to content

Commit

Permalink
Add test that rust estimator parser matches pure python implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeshingles committed Jun 18, 2024
1 parent 2e4f3b4 commit 6218996
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 13 deletions.
32 changes: 23 additions & 9 deletions artistools/estimators/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def get_rankbatch_parquetfile(
batch_mpiranks: t.Sequence[int],
batchindex: int,
modelpath: Path | str | None = None,
use_rust: bool = True,
use_rust_parser: bool | None = True,
) -> Path:
modelpath = Path(folderpath).parent if modelpath is None else Path(modelpath)
folderpath = Path(folderpath)
Expand Down Expand Up @@ -248,20 +248,29 @@ def get_rankbatch_parquetfile(

time_start = time.perf_counter()

try:
from artistools.rustext import estimparse as rustestimparse
except ImportError:
warnings.warn("WARNING: Rust extension not available. Falling back to slow python reader.", stacklevel=2)
use_rust = False
if use_rust_parser is None or use_rust_parser:
try:
from artistools.rustext import estimparse as rustestimparse

use_rust_parser = True

except ImportError as err:
warnings.warn(
"WARNING: Rust extension not available. Falling back to slow python reader.", stacklevel=2
)
if use_rust_parser:
msg = "Rust extension not available"
raise ImportError(msg) from err
use_rust_parser = False

print(
f" reading {len(estfilepaths)} estimator files in {folderpath.relative_to(Path(folderpath).parent)} with {'fast rust reader' if use_rust else 'slow python reader'}...",
f" reading {len(estfilepaths)} estimator files in {folderpath.relative_to(Path(folderpath).parent)} with {'fast rust reader' if use_rust_parser else 'slow python reader'}...",
end="",
flush=True,
)

pldf_batch: pl.DataFrame
if use_rust:
if use_rust_parser:
pldf_batch = rustestimparse(str(folderpath), min(batch_mpiranks), max(batch_mpiranks))
pldf_batch = pldf_batch.with_columns(
pl.col(c).cast(pl.Int32)
Expand Down Expand Up @@ -315,6 +324,7 @@ def scan_estimators(
modelpath: Path | str = Path(),
modelgridindex: None | int | t.Sequence[int] = None,
timestep: None | int | t.Sequence[int] = None,
use_rust_parser: bool | None = None,
) -> pl.LazyFrame:
"""Read estimator files into a dictionary of (timestep, modelgridindex): estimators.
Expand Down Expand Up @@ -367,7 +377,11 @@ def scan_estimators(

parquetfiles = (
get_rankbatch_parquetfile(
modelpath=modelpath, folderpath=runfolder, batch_mpiranks=mpiranks, batchindex=batchindex
modelpath=modelpath,
folderpath=runfolder,
batch_mpiranks=mpiranks,
batchindex=batchindex,
use_rust_parser=use_rust_parser,
)
for runfolder in runfolders
for batchindex, mpiranks in mpirank_groups
Expand Down
23 changes: 23 additions & 0 deletions artistools/estimators/test_estimators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import shutil
from pathlib import Path
from unittest import mock

import matplotlib.axes
import numpy as np
import polars as pl
import polars.testing as pltest
import pytest

import artistools as at
Expand Down Expand Up @@ -306,3 +308,24 @@ def test_estimator_timeevolution(mockplot) -> None:
modelgridindex=0,
x="time",
)


@pytest.mark.benchmark()
def test_rust_estimator_parser() -> None:
    """Check that the rust and pure-python estimator parsers yield matching frames.

    Each parser is run against its own freshly copied model folder (with any
    parquet files left out of the copy), and the two collected DataFrames are
    then compared within a small numerical tolerance.
    """
    workdir = outputpath / "test_rust_estimator_parser"

    def parse_fresh_copy(use_rust_parser: bool) -> pl.DataFrame:
        # start from a clean folder so that nothing is left over from the other parser's run
        if workdir.exists():
            shutil.rmtree(workdir)
        workdir.mkdir(exist_ok=True)
        skip_parquet = shutil.ignore_patterns("*parquet*")
        shutil.copytree(modelpath, workdir, dirs_exist_ok=True, ignore=skip_parquet)
        lazyframe = at.estimators.scan_estimators(modelpath=workdir, use_rust_parser=use_rust_parser)
        return lazyframe.collect()

    # same order as before: rust parser first, then the python parser
    df_rust = parse_fresh_copy(True)
    df_python = parse_fresh_copy(False)

    pltest.assert_frame_equal(df_rust, df_python, rtol=1e-4, atol=1e-4)
10 changes: 9 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
def pytest_sessionstart(session) -> None:
    """Clear the test output of previous runs.

    Reads the test output folder ("path_testoutput") and the repository folder
    ("path_artistools_repository") from the artistools config. If the output
    folder exists, it is deleted, but only when it lies strictly inside the
    repository folder. The output folder is then (re)created.

    Args:
        session: the pytest session object passed to this hook (unused here).

    Raises:
        RuntimeError: if the existing output folder is not a descendant of the
            repository folder.
    """
    import shutil
    from pathlib import Path

    import artistools as at

    outputpath = at.get_config("path_testoutput")
    assert isinstance(outputpath, Path)
    repopath = at.get_config("path_artistools_repository")
    assert isinstance(repopath, Path)

    if outputpath.exists():
        resolved_output = outputpath.resolve()
        resolved_repo = repopath.resolve()
        # rmtree is destructive, so this guard is an explicit check rather than an
        # assert statement: asserts are removed when Python runs with -O.
        # Using .parents (not equality) also refuses to delete the repository itself.
        if resolved_repo not in resolved_output.parents:
            msg = f"Refusing to delete {resolved_output} as it is not a descendant of the repository {resolved_repo}"
            raise RuntimeError(msg)
        shutil.rmtree(outputpath)
    outputpath.mkdir(exist_ok=True)
22 changes: 19 additions & 3 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use polars::chunked_array::ChunkedArray;
use polars::datatypes::Float32Type;
use polars::prelude::*;
use polars::series::IntoSeries;
use pyo3_polars::{self, PyDataFrame};
use pyo3_polars::PyDataFrame;
use rayon::prelude::*;
use std::collections::HashMap;
use std::fs::File;
Expand Down Expand Up @@ -52,7 +52,7 @@ fn match_colsizes(coldata: &mut HashMap<String, Vec<f32>>, outputrownum: usize)
for singlecoldata in coldata.values_mut() {
if singlecoldata.len() < outputrownum {
assert_eq!(singlecoldata.len(), outputrownum - 1);
singlecoldata.push(f32::NAN);
singlecoldata.push(0.);
}
}
}
Expand All @@ -64,7 +64,7 @@ fn append_or_create(
outputrownum: &usize,
) {
if !coldata.contains_key(colname) {
coldata.insert(colname.clone(), vec![f32::NAN; *outputrownum - 1]);
coldata.insert(colname.clone(), vec![0.; *outputrownum - 1]);
}

let singlecoldata = coldata.get_mut(colname).unwrap();
Expand Down Expand Up @@ -126,6 +126,22 @@ fn parse_line(line: &str, mut coldata: &mut HashMap<String, Vec<f32>>, outputrow
} else {
let ionstageroman = ROMAN[ionstagestr.parse::<usize>().unwrap()];
outcolname = format!("{variablename}_{elsym}_{ionstageroman}");

if variablename.ends_with("*nne") {
let colname_nonne = format!(
"{}_{}_{}",
variablename.strip_suffix("*nne").unwrap(),
elsym,
ionstageroman
);
let colvalue_nonne = colvalue / coldata["nne"].last().unwrap();
append_or_create(
&mut coldata,
&colname_nonne,
colvalue_nonne,
outputrownum,
);
}
}
append_or_create(&mut coldata, &outcolname, colvalue, outputrownum);
}
Expand Down

0 comments on commit 6218996

Please sign in to comment.