Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Albrja/mic 5720/mypy tests randomness index map #574

Merged
merged 8 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.3.16 - TBD/TBD/TBD**

- Type-hinting: Fix mypy errors in tests/framework/randomness/test_index_map.py

**3.2.16 - 01/22/25**

- Type-hinting: Fix mypy errors in tests/framework/randomness/test_manager.py
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ exclude = [
'tests/framework/lookup/test_lookup.py',
'tests/framework/population/test_manager.py',
'tests/framework/population/test_population_view.py',
'tests/framework/randomness/test_index_map.py',
'tests/framework/randomness/test_reproducibility.py',
'tests/framework/randomness/test_stream.py',
'tests/framework/results/helpers.py',
Expand Down
7 changes: 6 additions & 1 deletion src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ def _hash(self, keys: pd.Index[Any], salt: ClockTime = 0) -> pd.Series[int]:
return new_map % len(self)

def _convert_to_ten_digit_int(
self, column: pd.Series[datetime | int | float]
self,
column: pd.Series[pd.Timestamp]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was typed incorrectly before but basically to make myppy happy, we have to have pd.Timestamp as the type to include datetime because pandas always converts datetimes to Timestamps. However, Timestamp is technically a part of datetime. See: https://pandas.pydata.org/docs/user_guide/timeseries.html

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So even though pd.Timestamp inherits from datatime, you have to include it explicitly (despite also having datetime as part of the signature)?

Also, glad you caught that incorrect mashing of series types!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - I think technically we could just make it pd.Timestamp and not have a series as datetime as an options since no matter what you do pandas will convert it but I t hink this is more intuitive since we do use datetime elsewhere and we have Time in types.py which is a Timestamp | datetime (and unfortunately I could not do a serious of that).

| pd.Series[datetime]
| pd.Series[int]
| pd.Series[float],
) -> pd.Series[int]:
"""Converts a column of datetimes, integers, or floats into a column
of 10 digit integers.
Expand All @@ -215,6 +219,7 @@ def _convert_to_ten_digit_int(
if pdt.is_datetime64_any_dtype(column):
integers = self._clip_to_seconds(column.astype(np.int64))
elif pdt.is_integer_dtype(column):
column = column.astype(int)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behavior is consistent with the other branch of the if/else block and makes mypy happy.

if not len(column >= 0) == len(column):
raise RandomnessError(
"Values in integer columns must be greater than or equal to zero."
Expand Down
84 changes: 51 additions & 33 deletions tests/framework/randomness/test_index_map.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
from __future__ import annotations

from collections.abc import Iterable
from datetime import datetime
from itertools import chain, combinations, product
from typing import Any

import numpy as np
import pandas as pd
import pytest
import pytest_mock
from scipy.stats import chisquare

from vivarium.framework.randomness import RandomnessError
from vivarium.framework.randomness.index_map import IndexMap


def almost_powerset(iterable):
def almost_powerset(iterable: Iterable[int | str]) -> list[tuple[int | str, ...]]:
"""almost_powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
s = list(iterable)
return list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
powerset = list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
return powerset


def generate_keys(number, types=("int", "float", "datetime"), seed=123456):
def generate_keys(
number: int,
types: tuple[str, str, str] = ("int", "float", "datetime"),
seed: int = 123456,
) -> pd.DataFrame:
rs = np.random.RandomState(seed=seed)

keys = {}
keys: dict[str, Any] = {}
if "datetime" in types:
year = rs.choice(np.arange(1980, 2018))
day = rs.choice(pd.date_range(f"01/01/{year}", periods=365))
Expand Down Expand Up @@ -49,33 +60,33 @@ def generate_keys(number, types=("int", "float", "datetime"), seed=123456):
seeds = list(rs.randint(10000000, size=1))


def id_fun(param):
def id_fun(param: tuple[float, Any, Any]) -> str:
return f"Size:{param[0]}, Types:{param[1]}, Seed:{param[2]}"


@pytest.fixture(scope="module", params=list(product(index_sizes, types, seeds)), ids=id_fun)
def map_size_and_hashed_values(request):
def map_size_and_hashed_values(request: pytest.FixtureRequest) -> tuple[int, pd.Series[int]]:
index_size, types_, seed = request.param
keys = generate_keys(*request.param).set_index(types_).index
m = IndexMap(key_columns=types_)
return len(m), m._hash(keys)


def test_digit_series():
def test_digit_series() -> None:
m = IndexMap()
k = pd.Series(123456789, index=range(10000))
for i in range(10):
assert len(m._digit(k, i).unique()) == 1
assert m._digit(k, i)[0] == 10 - (i + 1)


def test_clip_to_seconds_scalar():
def test_clip_to_seconds_scalar() -> None:
m = IndexMap()
k = pd.to_datetime("2010-01-25 06:25:31.123456789")
assert m._clip_to_seconds(k.value) == int(str(k.value)[:10])
assert (m._clip_to_seconds(pd.Series(k.value)) == int(str(k.value)[:10])).all()


def test_clip_to_seconds_series():
def test_clip_to_seconds_series() -> None:
m = IndexMap()
stamp = 1234567890
k = (
Expand All @@ -87,39 +98,39 @@ def test_clip_to_seconds_series():
assert m._clip_to_seconds(k).unique()[0] == stamp


def test_spread_series():
def test_spread_series() -> None:
m = IndexMap()
s = pd.Series(1234567890, index=range(10000))
assert len(m._spread(s).unique()) == 1
assert m._spread(s).unique()[0] == 4072825790


def test_shift_series():
def test_shift_series() -> None:
m = IndexMap()
s = pd.Series(1.1234567890, index=range(10000))
assert len(m._shift(s).unique()) == 1
assert m._shift(s).unique()[0] == 1234567890


def test_convert_to_ten_digit_int():
def test_convert_to_ten_digit_int() -> None:
m = IndexMap()
v = 1234567890
datetime_col = pd.date_range(
date_col = pd.date_range(
pd.to_datetime(v, unit="s"), periods=10000, freq="ns"
).to_series()
int_col = pd.Series(v, index=range(10000))
float_col = pd.Series(1.1234567890, index=range(10000))
bad_col = pd.Series("a", index=range(10000))

assert m._convert_to_ten_digit_int(datetime_col).unique()[0] == v
assert m._convert_to_ten_digit_int(date_col).unique()[0] == v
assert m._convert_to_ten_digit_int(int_col).unique()[0] == 4072825790
assert m._convert_to_ten_digit_int(float_col).unique()[0] == v
with pytest.raises(RandomnessError):
m._convert_to_ten_digit_int(bad_col)
m._convert_to_ten_digit_int(bad_col) # type: ignore [arg-type]


@pytest.mark.skip("This fails because the hash needs work")
def test_hash_collisions(map_size_and_hashed_values):
def test_hash_collisions(map_size_and_hashed_values: tuple[int, pd.Series[int]]) -> None:
n, h = map_size_and_hashed_values
k = len(h)

Expand All @@ -134,12 +145,12 @@ def test_hash_collisions(map_size_and_hashed_values):


@pytest.mark.skip("This fails because the hash needs work")
def test_hash_uniformity(map_size_and_hashed_values):
def test_hash_uniformity(map_size_and_hashed_values: tuple[int, pd.Series[int]]) -> None:
n, h = map_size_and_hashed_values

k = len(h)
num_bins = k // 5 # Want about 5 items per bin for chi-squared
bins = np.linspace(0, n + 1, num_bins)
bins = pd.Series(np.linspace(0, n + 1, num_bins))

binned_data = pd.cut(h, bins)
distribution = pd.value_counts(binned_data).sort_index()
Expand All @@ -149,58 +160,65 @@ def test_hash_uniformity(map_size_and_hashed_values):


@pytest.fixture(scope="function")
def index_map(mocker):
def index_map(mocker: pytest_mock.MockFixture) -> type[IndexMap]:
mock_index_map = IndexMap

def hash_mock(k, salt):
def hash_mock(
k: pd.Index[int], salt: pd.Series[datetime] | pd.Series[int] | pd.Series[float]
) -> pd.Series[int]:
seed = 123456
salt = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
rs = np.random.RandomState(seed=seed + salt)
digit_series = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe salt_series?

rs = np.random.RandomState(seed=seed + digit_series)
return pd.Series(rs.randint(0, len(k) * 10, size=len(k)), index=k)

mocker.patch.object(mock_index_map, "_hash", side_effect=hash_mock)

return mock_index_map


def test_update_empty_bad_keys(index_map):
def test_update_empty_bad_keys(index_map: type[IndexMap]) -> None:
keys = pd.DataFrame({"A": ["a"] * 10}, index=range(10))
m = index_map(key_columns=list(keys.columns))
with pytest.raises(RandomnessError):
m.update(keys, pd.to_datetime("2023-01-01"))


def test_update_nonempty_bad_keys(index_map):
def test_update_nonempty_bad_keys(index_map: type[IndexMap]) -> None:
keys = generate_keys(1000)
m = index_map(key_columns=list(keys.columns))
m.update(keys, pd.to_datetime("2023-01-01"))
with pytest.raises(RandomnessError):
m.update(keys, pd.to_datetime("2023-01-01"))


def test_update_empty_good_keys(index_map):
def test_update_empty_good_keys(index_map: type[IndexMap]) -> None:
keys = generate_keys(1000)
m = index_map(key_columns=list(keys.columns))
m.update(keys, pd.to_datetime("2023-01-01"))
key_index = keys.set_index(list(keys.columns)).index
assert len(m._map) == len(keys), "All keys not in mapping"
map = m._map
assert isinstance(map, pd.Series)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if you need this assertion if you typed map specifically, map: pd.Series = m._map

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I specifically had to do this because _map can either be pd.Series[int] or None but in this case it has to be a series. For some reason mypy doesnt like inline typing vs the assertion.


assert len(map) == len(keys), "All keys not in mapping"
assert (
m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
), "Extra keys in mapping"
assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"
assert len(map.unique()) == len(keys), "Duplicate values in mapping"


def test_update_nonempty_good_keys(index_map):
def test_update_nonempty_good_keys(index_map: type[IndexMap]) -> None:
keys = generate_keys(2000)
m = index_map(key_columns=list(keys.columns))
keys1, keys2 = keys[:1000], keys[1000:]

m.update(keys1, pd.to_datetime("2023-01-01"))
m.update(keys2, pd.to_datetime("2023-01-01"))
map = m._map
assert isinstance(map, pd.Series)

key_index = keys.set_index(list(keys.columns)).index
assert len(m._map) == len(keys), "All keys not in mapping"
assert len(map) == len(keys), "All keys not in mapping"
assert (
m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
), "Extra keys in mapping"
assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"
assert len(map.unique()) == len(keys), "Duplicate values in mapping"
Loading