diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 4041e6cd..c3c4f1d9 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,3 +1,7 @@
+**3.2.17 - 01/24/25**
+
+ - Type-hinting: Fix mypy errors in tests/framework/randomness/test_index_map.py
+
 **3.2.16 - 01/22/25**
 
  - Type-hinting: Fix mypy errors in tests/framework/randomness/test_manager.py
diff --git a/pyproject.toml b/pyproject.toml
index fbad18d6..8fd3a14b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,7 +53,6 @@ exclude = [
     'tests/framework/lookup/test_lookup.py',
     'tests/framework/population/test_manager.py',
     'tests/framework/population/test_population_view.py',
-    'tests/framework/randomness/test_index_map.py',
     'tests/framework/randomness/test_reproducibility.py',
     'tests/framework/randomness/test_stream.py',
     'tests/framework/results/helpers.py',
diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py
index e0a1230e..f5614fb9 100644
--- a/src/vivarium/framework/randomness/index_map.py
+++ b/src/vivarium/framework/randomness/index_map.py
@@ -192,7 +192,11 @@ def _hash(self, keys: pd.Index[Any], salt: ClockTime = 0) -> pd.Series[int]:
         return new_map % len(self)
 
     def _convert_to_ten_digit_int(
-        self, column: pd.Series[datetime | int | float]
+        self,
+        column: pd.Series[pd.Timestamp]
+        | pd.Series[datetime]
+        | pd.Series[int]
+        | pd.Series[float],
     ) -> pd.Series[int]:
         """Converts a column of datetimes, integers, or floats into a column
         of 10 digit integers.
@@ -215,6 +219,7 @@ def _convert_to_ten_digit_int(
         if pdt.is_datetime64_any_dtype(column):
             integers = self._clip_to_seconds(column.astype(np.int64))
         elif pdt.is_integer_dtype(column):
+            column = column.astype(int)
             if not len(column >= 0) == len(column):
                 raise RandomnessError(
                     "Values in integer columns must be greater than or equal to zero."
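
Reviewer note on the `index_map.py` hunks: the new `astype(int)` line exists because runtime dtype checks like `pdt.is_integer_dtype` are invisible to mypy, so the column stays typed as the four-way `Series` union; casting produces a concrete integer series (and normalizes extension dtypes such as nullable `Int64`) before the non-negativity check. Below is a minimal sketch of this dispatch-and-narrow pattern with illustrative names, not vivarium's implementation, and with the `_spread`/`_shift` steps for integer and float columns omitted:

    from __future__ import annotations

    from datetime import datetime

    import numpy as np
    import pandas as pd
    import pandas.api.types as pdt


    def to_ten_digit_int(
        column: pd.Series[pd.Timestamp]
        | pd.Series[datetime]
        | pd.Series[int]
        | pd.Series[float],
    ) -> pd.Series[int]:
        if pdt.is_datetime64_any_dtype(column):
            # Nanosecond epoch values have ~19 digits; the leading 10 encode
            # whole seconds, matching the _clip_to_seconds tests below.
            return column.astype(np.int64) // 10**9
        elif pdt.is_integer_dtype(column):
            # The runtime check does not narrow the union for mypy, so cast
            # to a plain int64 series before validating and returning.
            narrowed = column.astype(int)
            if not (narrowed >= 0).all():
                raise ValueError("Integer key columns must be non-negative.")
            return narrowed
        raise TypeError(f"Unsupported column dtype: {column.dtype}")

Per the tests below, the real method additionally spreads integer keys over the ten-digit space (`_spread` maps 1234567890 to 4072825790) and shifts floats by a factor of 10**9 (`_shift` maps 1.1234567890 to 1234567890); the sketch only shows the typing-relevant structure.
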
diff --git a/tests/framework/randomness/test_index_map.py b/tests/framework/randomness/test_index_map.py
index 67c1e33b..a4bf2fa9 100644
--- a/tests/framework/randomness/test_index_map.py
+++ b/tests/framework/randomness/test_index_map.py
@@ -1,24 +1,35 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from datetime import datetime
 from itertools import chain, combinations, product
+from typing import Any
 
 import numpy as np
 import pandas as pd
 import pytest
+import pytest_mock
 from scipy.stats import chisquare
 
 from vivarium.framework.randomness import RandomnessError
 from vivarium.framework.randomness.index_map import IndexMap
 
 
-def almost_powerset(iterable):
+def almost_powerset(iterable: Iterable[int | str]) -> list[tuple[int | str, ...]]:
     """almost_powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
     s = list(iterable)
-    return list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
+    powerset = list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
+    return powerset
 
 
-def generate_keys(number, types=("int", "float", "datetime"), seed=123456):
+def generate_keys(
+    number: int,
+    types: tuple[str, ...] = ("int", "float", "datetime"),
+    seed: int = 123456,
+) -> pd.DataFrame:
     rs = np.random.RandomState(seed=seed)
 
-    keys = {}
+    keys: dict[str, Any] = {}
     if "datetime" in types:
         year = rs.choice(np.arange(1980, 2018))
         day = rs.choice(pd.date_range(f"01/01/{year}", periods=365))
@@ -49,19 +60,19 @@
 seeds = list(rs.randint(10000000, size=1))
 
 
-def id_fun(param):
+def id_fun(param: tuple[float, Any, Any]) -> str:
     return f"Size:{param[0]}, Types:{param[1]}, Seed:{param[2]}"
 
 
 @pytest.fixture(scope="module", params=list(product(index_sizes, types, seeds)), ids=id_fun)
-def map_size_and_hashed_values(request):
+def map_size_and_hashed_values(request: pytest.FixtureRequest) -> tuple[int, pd.Series[int]]:
     index_size, types_, seed = request.param
     keys = generate_keys(*request.param).set_index(types_).index
     m = IndexMap(key_columns=types_)
     return len(m), m._hash(keys)
 
 
-def test_digit_series():
+def test_digit_series() -> None:
     m = IndexMap()
     k = pd.Series(123456789, index=range(10000))
     for i in range(10):
@@ -69,13 +80,13 @@
         assert m._digit(k, i)[0] == 10 - (i + 1)
 
 
-def test_clip_to_seconds_scalar():
+def test_clip_to_seconds_scalar() -> None:
     m = IndexMap()
     k = pd.to_datetime("2010-01-25 06:25:31.123456789")
-    assert m._clip_to_seconds(k.value) == int(str(k.value)[:10])
+    assert (m._clip_to_seconds(pd.Series(k.value)) == int(str(k.value)[:10])).all()
 
 
-def test_clip_to_seconds_series():
+def test_clip_to_seconds_series() -> None:
     m = IndexMap()
     stamp = 1234567890
     k = (
@@ -87,39 +98,39 @@
     assert m._clip_to_seconds(k).unique()[0] == stamp
 
 
-def test_spread_series():
+def test_spread_series() -> None:
     m = IndexMap()
     s = pd.Series(1234567890, index=range(10000))
     assert len(m._spread(s).unique()) == 1
     assert m._spread(s).unique()[0] == 4072825790
 
 
-def test_shift_series():
+def test_shift_series() -> None:
     m = IndexMap()
     s = pd.Series(1.1234567890, index=range(10000))
     assert len(m._shift(s).unique()) == 1
     assert m._shift(s).unique()[0] == 1234567890
 
 
-def test_convert_to_ten_digit_int():
+def test_convert_to_ten_digit_int() -> None:
     m = IndexMap()
     v = 1234567890
-    datetime_col = pd.date_range(
+    date_col = pd.date_range(
         pd.to_datetime(v, unit="s"), periods=10000, freq="ns"
     ).to_series()
     int_col = pd.Series(v, index=range(10000))
     float_col = pd.Series(1.1234567890, index=range(10000))
     bad_col = pd.Series("a", index=range(10000))
 
-    assert m._convert_to_ten_digit_int(datetime_col).unique()[0] == v
+    assert m._convert_to_ten_digit_int(date_col).unique()[0] == v
     assert m._convert_to_ten_digit_int(int_col).unique()[0] == 4072825790
     assert m._convert_to_ten_digit_int(float_col).unique()[0] == v
 
     with pytest.raises(RandomnessError):
-        m._convert_to_ten_digit_int(bad_col)
+        m._convert_to_ten_digit_int(bad_col)  # type: ignore [arg-type]
 
 
 @pytest.mark.skip("This fails because the hash needs work")
-def test_hash_collisions(map_size_and_hashed_values):
+def test_hash_collisions(map_size_and_hashed_values: tuple[int, pd.Series[int]]) -> None:
     n, h = map_size_and_hashed_values
     k = len(h)
@@ -134,12 +145,12 @@
 
 
 @pytest.mark.skip("This fails because the hash needs work")
-def test_hash_uniformity(map_size_and_hashed_values):
+def test_hash_uniformity(map_size_and_hashed_values: tuple[int, pd.Series[int]]) -> None:
     n, h = map_size_and_hashed_values
     k = len(h)
 
     num_bins = k // 5  # Want about 5 items per bin for chi-squared
-    bins = np.linspace(0, n + 1, num_bins)
+    bins = pd.Series(np.linspace(0, n + 1, num_bins))
     binned_data = pd.cut(h, bins)
 
     distribution = pd.value_counts(binned_data).sort_index()
@@ -149,13 +160,15 @@
 
 
 @pytest.fixture(scope="function")
-def index_map(mocker):
+def index_map(mocker: pytest_mock.MockFixture) -> type[IndexMap]:
     mock_index_map = IndexMap
 
-    def hash_mock(k, salt):
+    def hash_mock(
+        k: pd.Index[int], salt: datetime | int | float
+    ) -> pd.Series[int]:
         seed = 123456
-        salt = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
-        rs = np.random.RandomState(seed=seed + salt)
+        digit_series = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
+        rs = np.random.RandomState(seed=seed + digit_series)
         return pd.Series(rs.randint(0, len(k) * 10, size=len(k)), index=k)
 
     mocker.patch.object(mock_index_map, "_hash", side_effect=hash_mock)
@@ -163,14 +176,14 @@ def hash_mock(k, salt):
     return mock_index_map
 
 
-def test_update_empty_bad_keys(index_map):
+def test_update_empty_bad_keys(index_map: type[IndexMap]) -> None:
     keys = pd.DataFrame({"A": ["a"] * 10}, index=range(10))
     m = index_map(key_columns=list(keys.columns))
     with pytest.raises(RandomnessError):
         m.update(keys, pd.to_datetime("2023-01-01"))
 
 
-def test_update_nonempty_bad_keys(index_map):
+def test_update_nonempty_bad_keys(index_map: type[IndexMap]) -> None:
     keys = generate_keys(1000)
     m = index_map(key_columns=list(keys.columns))
     m.update(keys, pd.to_datetime("2023-01-01"))
@@ -178,29 +191,34 @@
         m.update(keys, pd.to_datetime("2023-01-01"))
 
 
-def test_update_empty_good_keys(index_map):
+def test_update_empty_good_keys(index_map: type[IndexMap]) -> None:
     keys = generate_keys(1000)
     m = index_map(key_columns=list(keys.columns))
     m.update(keys, pd.to_datetime("2023-01-01"))
 
     key_index = keys.set_index(list(keys.columns)).index
-    assert len(m._map) == len(keys), "All keys not in mapping"
+    map = m._map
+    assert isinstance(map, pd.Series)
+
+    assert len(map) == len(keys), "All keys not in mapping"
     assert (
-        m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
+        map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
     ), "Extra keys in mapping"
-    assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"
+    assert len(map.unique()) == len(keys), "Duplicate values in mapping"
 
 
-def test_update_nonempty_good_keys(index_map):
+def test_update_nonempty_good_keys(index_map: type[IndexMap]) -> None:
     keys = generate_keys(2000)
     m = index_map(key_columns=list(keys.columns))
     keys1, keys2 = keys[:1000], keys[1000:]
     m.update(keys1, pd.to_datetime("2023-01-01"))
     m.update(keys2, pd.to_datetime("2023-01-01"))
+    map = m._map
+    assert isinstance(map, pd.Series)
 
     key_index = keys.set_index(list(keys.columns)).index
-    assert len(m._map) == len(keys), "All keys not in mapping"
+    assert len(map) == len(keys), "All keys not in mapping"
     assert (
-        m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
+        map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
     ), "Extra keys in mapping"
-    assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"
+    assert len(map.unique()) == len(keys), "Duplicate values in mapping"