Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Jul 7, 2023
2 parents 0e445e3 + 8d4ca56 commit 66274d2
Show file tree
Hide file tree
Showing 20 changed files with 790 additions and 772 deletions.
2 changes: 2 additions & 0 deletions bibat/examples/baseball/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ else
INSTALL_CMDSTAN_FLAGS =
endif

env: $(ENV_MARKER)

$(ACTIVATE_VENV):
python -m venv .venv --prompt=baseball

Expand Down
70 changes: 35 additions & 35 deletions bibat/examples/baseball/baseball/data_preparation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Provide functions prepare_data_x.
"""Provides functions prepare_data_x.
These functions should take in a dataframe of measurements and return a
PreparedData object.
Expand All @@ -10,14 +10,15 @@
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame, Series
from pydantic.dataclasses import dataclass
from pydantic import BaseModel

from baseball.util import CoordDict
from baseball import util

NAME_FILE = "name.txt"
COORDS_FILE = "coords.json"
MEASUREMENTS_FILE = "measurements.csv"
N_CV_FOLDS = 10

HERE = os.path.dirname(__file__)
DATA_DIR = os.path.join(HERE, "..", "data")
RAW_DIR = os.path.join(DATA_DIR, "raw")
Expand Down Expand Up @@ -66,40 +67,14 @@ class MeasurementsDF(pa.SchemaModel):
n_success: Series[int] = pa.Field(ge=0)


@dataclass
class PreparedData:
class PreparedData(BaseModel, arbitrary_types_allowed=True):
"""What prepared data looks like in this analysis."""

name: str
coords: CoordDict
coords: util.CoordDict
measurements: DataFrame[MeasurementsDF]


def load_prepared_data(directory: str) -> PreparedData:
    """Read a PreparedData object back from files in *directory*.

    Expects the three files that write_prepared_data produces: a json
    coordinates file, a plain-text name file and a csv of measurements.

    :param directory: path to the directory containing the files
    :return: a PreparedData object assembled from the files
    """
    coords_path = os.path.join(directory, COORDS_FILE)
    name_path = os.path.join(directory, NAME_FILE)
    measurements_path = os.path.join(directory, MEASUREMENTS_FILE)
    with open(coords_path, "r") as coords_file:
        coords = json.load(coords_file)
    with open(name_path, "r") as name_file:
        name = name_file.read()
    return PreparedData(
        name=name,
        coords=coords,
        measurements=pd.read_csv(measurements_path),
    )


def write_prepared_data(prepped: PreparedData, directory):
    """Write prepared data files to a directory.

    Writes three files: the measurements table as csv, the coordinates as
    json and the name as plain text.

    :param prepped: PreparedData object to serialise
    :param directory: target directory; created (recursively) if absent
    """
    # makedirs with exist_ok avoids the check-then-create race of
    # `if not exists: mkdir` and also works when parent directories are
    # missing (os.mkdir would raise FileNotFoundError in that case).
    os.makedirs(directory, exist_ok=True)
    prepped.measurements.to_csv(os.path.join(directory, MEASUREMENTS_FILE))
    with open(os.path.join(directory, COORDS_FILE), "w") as f:
        json.dump(prepped.coords, f)
    with open(os.path.join(directory, NAME_FILE), "w") as f:
        f.write(prepped.name)


def prepare_data_2006(measurements_raw: pd.DataFrame) -> PreparedData:
"""Prepare the 2006 data."""
measurements = measurements_raw.rename(
Expand All @@ -112,9 +87,9 @@ def prepare_data_2006(measurements_raw: pd.DataFrame) -> PreparedData:
name="2006",
coords={
"player_season": measurements["player_season"].tolist(),
"season": measurements["season"].tolist(),
"season": measurements["season"].astype(str).tolist(),
},
measurements=measurements,
measurements=DataFrame[MeasurementsDF](measurements),
)


Expand Down Expand Up @@ -181,7 +156,32 @@ def filter_batters(df: pd.DataFrame):
name="bdb",
coords={
"player_season": measurements["player_season"].tolist(),
"season": measurements["season"].tolist(),
"season": measurements["season"].astype(str).tolist(),
},
measurements=measurements,
measurements=DataFrame[MeasurementsDF](measurements),
)


def load_prepared_data(directory: str) -> PreparedData:
    """Load prepared data from files in directory.

    Expects the three files written by write_prepared_data: a json
    coordinates file, a plain-text name file and a csv of measurements.

    :param directory: path to the directory containing the files
    :return: a PreparedData object assembled from the files
    """
    with open(os.path.join(directory, COORDS_FILE), "r") as f:
        coords = json.load(f)
    with open(os.path.join(directory, NAME_FILE), "r") as f:
        name = f.read()
    measurements = pd.read_csv(os.path.join(directory, MEASUREMENTS_FILE))
    return PreparedData(
        name=name,
        coords=coords,
        # NOTE(review): wrapping in the pandera typed DataFrame presumably
        # validates the csv against the MeasurementsDF schema — confirm.
        measurements=DataFrame[MeasurementsDF](measurements),
    )


def write_prepared_data(prepped: PreparedData, directory):
    """Write prepared data files to a directory.

    Writes three files: the measurements table as csv, the coordinates as
    json and the name as plain text.

    :param prepped: PreparedData object to serialise
    :param directory: target directory; created (recursively) if absent
    """
    # makedirs with exist_ok avoids the check-then-create race of
    # `if not exists: mkdir` and also works when parent directories are
    # missing (os.mkdir would raise FileNotFoundError in that case).
    os.makedirs(directory, exist_ok=True)
    prepped.measurements.to_csv(os.path.join(directory, MEASUREMENTS_FILE))
    with open(os.path.join(directory, COORDS_FILE), "w") as f:
        json.dump(prepped.coords, f)
    with open(os.path.join(directory, NAME_FILE), "w") as f:
        f.write(prepped.name)
25 changes: 12 additions & 13 deletions bibat/examples/baseball/baseball/inference_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Dict, List, Optional

import toml
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import BaseModel, Field, field_validator, model_validator

from baseball import fitting_mode, stan_input_functions

Expand Down Expand Up @@ -73,36 +73,35 @@ def __init__(self, **data):
]
super().__init__(**data)

@root_validator
def check_folds(cls, values):
@model_validator(mode="after")
def check_folds(cls, m: "InferenceConfiguration"):
"""Check that there is a number of folds if required."""
if any(m == "kfold" for m in values["fitting_mode_names"]):
if "mode_options" not in values.keys():
if any(m == "kfold" for m in m.fitting_mode_names):
if m.mode_options is None:
raise ValueError(
"Mode 'kfold' requires a mode_options.kfold table."
)
mode_options = values["mode_options"]
if "kfold" not in mode_options.keys():
if "kfold" not in m.mode_options.keys():
raise ValueError(
"Mode 'kfold' requires a mode_options.kfold table."
)
elif "n_folds" not in mode_options["kfold"].keys():
elif "n_folds" not in m.mode_options["kfold"].keys():
raise ValueError("Set 'n_folds' field in kfold mode options.")
else:
assert int(mode_options["kfold"]["n_folds"]), (
assert int(m.mode_options["kfold"]["n_folds"]), (
f"Could not coerce n_folds choice "
f"{mode_options['kfold']['n_folds']} to int."
f"{m.mode_options['kfold']['n_folds']} to int."
)
return values
return m

@validator("stan_file")
@field_validator("stan_file")
def check_stan_file_exists(cls, v):
    """Check that the stan file exists.

    :param v: value of the stan_file field, a path relative to STAN_DIR
    :raises ValueError: if no file with that name exists in STAN_DIR
    :return: the value unchanged
    """
    if not os.path.exists(os.path.join(STAN_DIR, v)):
        raise ValueError(f"{v} is not a file in {STAN_DIR}.")
    return v

@validator("fitting_modes")
@field_validator("fitting_modes")
def check_modes(cls, v):
"""Check that the provided modes exist."""
for mode in v:
Expand Down
27 changes: 13 additions & 14 deletions bibat/examples/baseball/baseball/stan/custom_functions.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
/* This file is for your stan functions */

real gpareto_lpdf(vector y, real ymin, real k, real sigma) {
  // generalised Pareto log pdf
  //   y: observations, each assumed >= ymin
  //   ymin: location (lower bound of support)
  //   k: shape; when k < 0 the support is bounded above by ymin - sigma/k
  //   sigma: scale, must be positive
  int N = rows(y);
  real inv_k = inv(k);
  // Reject observations outside the (bounded) support when k < 0.
  if (k<0 && max(y-ymin)/sigma > -inv_k)
    reject("k<0 and max(y-ymin)/sigma > -1/k; found k, sigma =", k, ", ", sigma);
  if (sigma<=0)
    reject("sigma<=0; found sigma =", sigma);
  // Use the general form unless k is numerically zero, in which case
  // fall back to the exponential limit to avoid dividing by k.
  if (fabs(k) > 1e-15)
    return -(1+inv_k)*sum(log1p((y-ymin) * (k/sigma))) -N*log(sigma);
  else
    return -sum(y-ymin)/sigma -N*log(sigma); // limit k->0
}
vector standardise_vector(vector v, real mu, real s){
  // Centre v at mu and divide by twice s (two-times-scale standardisation).
  return (v - mu) / (2 * s);
}
Expand Down Expand Up @@ -30,17 +43,3 @@ vector col_sds(matrix m){
out[c] = sd(m[,c]);
return out;
}

real gpareto_lpdf(vector y, real ymin, real k, real sigma) {
  // generalised Pareto log pdf
  //   y: observations, each assumed >= ymin
  //   ymin: location (lower bound of support)
  //   k: shape; when k < 0 the support is bounded above by ymin - sigma/k
  //   sigma: scale, must be positive
  int N = rows(y);
  real inv_k = inv(k);
  // Reject observations outside the (bounded) support when k < 0.
  if (k<0 && max(y-ymin)/sigma > -inv_k)
    reject("k<0 and max(y-ymin)/sigma > -1/k; found k, sigma =", k, ", ", sigma);
  if (sigma<=0)
    reject("sigma<=0; found sigma =", sigma);
  // Use the general form unless k is numerically zero, in which case
  // fall back to the exponential limit to avoid dividing by k.
  if (fabs(k) > 1e-15)
    return -(1+inv_k)*sum(log1p((y-ymin) * (k/sigma))) -N*log(sigma);
  else
    return -sum(y-ymin)/sigma -N*log(sigma); // limit k->0
}
Binary file modified bibat/examples/baseball/baseball/stan/gpareto
Binary file not shown.
Binary file modified bibat/examples/baseball/baseball/stan/normal
Binary file not shown.
1 change: 0 additions & 1 deletion bibat/examples/baseball/baseball/stan_input_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Functions for generating input to Stan from prepared data."""


from typing import Dict

import numpy as np
Expand Down
Loading

0 comments on commit 66274d2

Please sign in to comment.