Albrja/mic 5720/mypy tests randomness index map #574
@@ -192,7 +192,11 @@ def _hash(self, keys: pd.Index[Any], salt: ClockTime = 0) -> pd.Series[int]:
         return new_map % len(self)

     def _convert_to_ten_digit_int(
-        self, column: pd.Series[datetime | int | float]
+        self,
+        column: pd.Series[pd.Timestamp]
+        | pd.Series[datetime]
+        | pd.Series[int]
+        | pd.Series[float],
     ) -> pd.Series[int]:
         """Converts a column of datetimes, integers, or floats into a column
         of 10 digit integers.
@@ -215,6 +219,7 @@ def _convert_to_ten_digit_int(
         if pdt.is_datetime64_any_dtype(column):
            integers = self._clip_to_seconds(column.astype(np.int64))
         elif pdt.is_integer_dtype(column):
+            column = column.astype(int)
             if not len(column >= 0) == len(column):
                 raise RandomnessError(
                     "Values in integer columns must be greater than or equal to zero."

Review comment on column = column.astype(int): This behavior is consistent with the other branch of the if/else block and makes mypy happy.
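As context for that review note, here is a small standalone sketch (mine, not the library's code) of what the added astype(int) normalizes: several concrete dtypes all pass pdt.is_integer_dtype, and converting them to the platform int mirrors the datetime branch's astype(np.int64) call while giving mypy a plain pd.Series[int] to work with.

import pandas as pd
import pandas.api.types as pdt

# Hypothetical inputs: each passes pdt.is_integer_dtype, but with a
# different concrete dtype than the plain int assumed downstream.
for raw in (
    pd.Series([1, 2, 3], dtype="int32"),
    pd.Series([1, 2, 3], dtype="uint8"),
    pd.Series([1, 2, 3], dtype="Int64"),  # nullable extension dtype
):
    assert pdt.is_integer_dtype(raw)
    column = raw.astype(int)  # uniform platform int from here on
    print(raw.dtype, "->", column.dtype)  # e.g. int32 -> int64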
@@ -1,24 +1,35 @@
 from __future__ import annotations

+from collections.abc import Iterable
 from datetime import datetime
 from itertools import chain, combinations, product
+from typing import Any

 import numpy as np
 import pandas as pd
 import pytest
+import pytest_mock
 from scipy.stats import chisquare

 from vivarium.framework.randomness import RandomnessError
 from vivarium.framework.randomness.index_map import IndexMap


-def almost_powerset(iterable):
+def almost_powerset(iterable: Iterable[int | str]) -> list[tuple[int | str, ...]]:
     """almost_powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
     s = list(iterable)
-    return list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
+    powerset = list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))
+    return powerset


-def generate_keys(number, types=("int", "float", "datetime"), seed=123456):
+def generate_keys(
+    number: int,
+    types: tuple[str, str, str] = ("int", "float", "datetime"),
+    seed: int = 123456,
+) -> pd.DataFrame:
     rs = np.random.RandomState(seed=seed)

-    keys = {}
+    keys: dict[str, Any] = {}
     if "datetime" in types:
         year = rs.choice(np.arange(1980, 2018))
         day = rs.choice(pd.date_range(f"01/01/{year}", periods=365))
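The keys: dict[str, Any] annotation above works around a common mypy sticking point: mypy infers an empty dict's value type from the first assignment it sees, so later assignments of a different value type fail to type-check unless the dict is annotated up front. A minimal sketch of the failure mode, with hypothetical values:

from typing import Any

import numpy as np

keys: dict[str, Any] = {}
# Without the annotation, mypy infers dict[str, ndarray] from the first
# assignment and then rejects the differently typed assignment below.
keys["int"] = np.arange(3)
keys["datetime"] = ["2020-01-01", "2020-01-02"]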
@@ -49,33 +60,33 @@ def generate_keys(number, types=("int", "float", "datetime"), seed=123456):
 seeds = list(rs.randint(10000000, size=1))


-def id_fun(param):
+def id_fun(param: tuple[float, Any, Any]) -> str:
     return f"Size:{param[0]}, Types:{param[1]}, Seed:{param[2]}"


 @pytest.fixture(scope="module", params=list(product(index_sizes, types, seeds)), ids=id_fun)
-def map_size_and_hashed_values(request):
+def map_size_and_hashed_values(request: pytest.FixtureRequest) -> tuple[int, pd.Series[int]]:
     index_size, types_, seed = request.param
     keys = generate_keys(*request.param).set_index(types_).index
     m = IndexMap(key_columns=types_)
     return len(m), m._hash(keys)


-def test_digit_series():
+def test_digit_series() -> None:
     m = IndexMap()
     k = pd.Series(123456789, index=range(10000))
     for i in range(10):
         assert len(m._digit(k, i).unique()) == 1
         assert m._digit(k, i)[0] == 10 - (i + 1)


-def test_clip_to_seconds_scalar():
+def test_clip_to_seconds_scalar() -> None:
     m = IndexMap()
     k = pd.to_datetime("2010-01-25 06:25:31.123456789")
     assert m._clip_to_seconds(k.value) == int(str(k.value)[:10])
     assert (m._clip_to_seconds(pd.Series(k.value)) == int(str(k.value)[:10])).all()


-def test_clip_to_seconds_series():
+def test_clip_to_seconds_series() -> None:
     m = IndexMap()
     stamp = 1234567890
     k = (
@@ -87,39 +98,39 @@ def test_clip_to_seconds_series():
     assert m._clip_to_seconds(k).unique()[0] == stamp


-def test_spread_series():
+def test_spread_series() -> None:
     m = IndexMap()
     s = pd.Series(1234567890, index=range(10000))
     assert len(m._spread(s).unique()) == 1
     assert m._spread(s).unique()[0] == 4072825790


-def test_shift_series():
+def test_shift_series() -> None:
     m = IndexMap()
     s = pd.Series(1.1234567890, index=range(10000))
     assert len(m._shift(s).unique()) == 1
     assert m._shift(s).unique()[0] == 1234567890


-def test_convert_to_ten_digit_int():
+def test_convert_to_ten_digit_int() -> None:
     m = IndexMap()
     v = 1234567890
-    datetime_col = pd.date_range(
+    date_col = pd.date_range(
         pd.to_datetime(v, unit="s"), periods=10000, freq="ns"
     ).to_series()
     int_col = pd.Series(v, index=range(10000))
     float_col = pd.Series(1.1234567890, index=range(10000))
     bad_col = pd.Series("a", index=range(10000))

-    assert m._convert_to_ten_digit_int(datetime_col).unique()[0] == v
+    assert m._convert_to_ten_digit_int(date_col).unique()[0] == v
     assert m._convert_to_ten_digit_int(int_col).unique()[0] == 4072825790
     assert m._convert_to_ten_digit_int(float_col).unique()[0] == v
     with pytest.raises(RandomnessError):
-        m._convert_to_ten_digit_int(bad_col)
+        m._convert_to_ten_digit_int(bad_col)  # type: ignore [arg-type]


 @pytest.mark.skip("This fails because the hash needs work")
-def test_hash_collisions(map_size_and_hashed_values):
+def test_hash_collisions(map_size_and_hashed_values: tuple[int, pd.Series[int]]) -> None:
     n, h = map_size_and_hashed_values
     k = len(h)
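The `# type: ignore [arg-type]` added above is the usual way to keep a deliberately ill-typed call in a test: bad_col is inferred as a string series, which is not in the new parameter union, so mypy would reject the very call the test needs to make at runtime. A hedged sketch of the pattern, with hypothetical names:

import pandas as pd

def takes_int_series(column: "pd.Series[int]") -> int:
    # Stand-in for a function that only accepts integer series.
    return len(column)

bad = pd.Series(["a", "b"])  # inferred as pd.Series[str]
# mypy: incompatible type "Series[str]"; expected "Series[int]" -- which is
# exactly the invalid input the test wants to exercise at runtime.
takes_int_series(bad)  # type: ignore[arg-type]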
@@ -134,12 +145,12 @@ def test_hash_collisions(map_size_and_hashed_values):


 @pytest.mark.skip("This fails because the hash needs work")
-def test_hash_uniformity(map_size_and_hashed_values):
+def test_hash_uniformity(map_size_and_hashed_values: tuple[int, pd.Series[int]]) -> None:
     n, h = map_size_and_hashed_values

     k = len(h)
     num_bins = k // 5  # Want about 5 items per bin for chi-squared
-    bins = np.linspace(0, n + 1, num_bins)
+    bins = pd.Series(np.linspace(0, n + 1, num_bins))

     binned_data = pd.cut(h, bins)
     distribution = pd.value_counts(binned_data).sort_index()
@@ -149,58 +160,65 @@


 @pytest.fixture(scope="function")
-def index_map(mocker):
+def index_map(mocker: pytest_mock.MockFixture) -> type[IndexMap]:
     mock_index_map = IndexMap

-    def hash_mock(k, salt):
+    def hash_mock(
+        k: pd.Index[int], salt: pd.Series[datetime] | pd.Series[int] | pd.Series[float]
+    ) -> pd.Series[int]:
         seed = 123456
-        salt = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
-        rs = np.random.RandomState(seed=seed + salt)
+        digit_series = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
+        rs = np.random.RandomState(seed=seed + digit_series)
         return pd.Series(rs.randint(0, len(k) * 10, size=len(k)), index=k)

     mocker.patch.object(mock_index_map, "_hash", side_effect=hash_mock)

     return mock_index_map

Review comment on the digit_series line: nit: maybe


-def test_update_empty_bad_keys(index_map):
+def test_update_empty_bad_keys(index_map: type[IndexMap]) -> None:
     keys = pd.DataFrame({"A": ["a"] * 10}, index=range(10))
     m = index_map(key_columns=list(keys.columns))
     with pytest.raises(RandomnessError):
         m.update(keys, pd.to_datetime("2023-01-01"))


-def test_update_nonempty_bad_keys(index_map):
+def test_update_nonempty_bad_keys(index_map: type[IndexMap]) -> None:
     keys = generate_keys(1000)
     m = index_map(key_columns=list(keys.columns))
     m.update(keys, pd.to_datetime("2023-01-01"))
     with pytest.raises(RandomnessError):
         m.update(keys, pd.to_datetime("2023-01-01"))


-def test_update_empty_good_keys(index_map):
+def test_update_empty_good_keys(index_map: type[IndexMap]) -> None:
     keys = generate_keys(1000)
     m = index_map(key_columns=list(keys.columns))
     m.update(keys, pd.to_datetime("2023-01-01"))
     key_index = keys.set_index(list(keys.columns)).index
-    assert len(m._map) == len(keys), "All keys not in mapping"
+    map = m._map
+    assert isinstance(map, pd.Series)
+
+    assert len(map) == len(keys), "All keys not in mapping"
     assert (
-        m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
+        map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
     ), "Extra keys in mapping"
-    assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"
+    assert len(map.unique()) == len(keys), "Duplicate values in mapping"

Review comment on the isinstance assertion: I wonder if you need this assertion if you typed map specifically.

Author reply: I specifically had to do this because _map can be either pd.Series[int] or None, but in this case it has to be a Series. For some reason mypy doesn't like inline typing vs. the assertion.
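To make that thread concrete, here is a minimal self-contained sketch (a hypothetical Holder class, not IndexMap itself) of the narrowing the assertion performs: an attribute typed pd.Series[int] | None cannot be used as a Series until mypy has proof it is not None, and assert isinstance supplies that proof at the use site.

from __future__ import annotations

import pandas as pd

class Holder:
    # Mirrors the Optional _map attribute described in the author's reply.
    def __init__(self) -> None:
        self._map: pd.Series[int] | None = None

h = Holder()
h._map = pd.Series([1, 2, 3])

map_ = h._map  # for mypy this is still "Series[int] | None"
assert isinstance(map_, pd.Series)  # narrows away the None branch
print(len(map_.unique()))  # fine for both mypy and the runtime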


-def test_update_nonempty_good_keys(index_map):
+def test_update_nonempty_good_keys(index_map: type[IndexMap]) -> None:
     keys = generate_keys(2000)
     m = index_map(key_columns=list(keys.columns))
     keys1, keys2 = keys[:1000], keys[1000:]

     m.update(keys1, pd.to_datetime("2023-01-01"))
     m.update(keys2, pd.to_datetime("2023-01-01"))
+    map = m._map
+    assert isinstance(map, pd.Series)

     key_index = keys.set_index(list(keys.columns)).index
-    assert len(m._map) == len(keys), "All keys not in mapping"
+    assert len(map) == len(keys), "All keys not in mapping"
     assert (
-        m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
+        map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
     ), "Extra keys in mapping"
-    assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"
+    assert len(map.unique()) == len(keys), "Duplicate values in mapping"
Review comment on the _convert_to_ten_digit_int signature: This was typed incorrectly before. Basically, to make mypy happy, we have to include pd.Timestamp in the type alongside datetime, because pandas always converts datetimes to Timestamps. However, Timestamp is technically a part of datetime. See: https://pandas.pydata.org/docs/user_guide/timeseries.html

Reply: So even though pd.Timestamp inherits from datetime, you have to include it explicitly (despite also having datetime as part of the signature)? Also, glad you caught that incorrect mashing of series types!

Author reply: Yes. I think technically we could make it just pd.Timestamp and drop the Series-of-datetime option, since pandas will convert it no matter what you do, but I think this is more intuitive since we use datetime elsewhere, and we have Time in types.py, which is Timestamp | datetime (and unfortunately I could not make a Series of that).
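A quick runnable illustration (added here, not part of the PR) of the conversion the thread describes: pandas stores datetime.datetime values as pd.Timestamp, and pd.Timestamp is itself a datetime subclass.

from datetime import datetime

import pandas as pd

s = pd.Series([datetime(2023, 1, 1)])
print(type(s.iloc[0]))                     # pandas Timestamp, not datetime
print(issubclass(pd.Timestamp, datetime))  # True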