Skip to content

Commit

Permalink
Add compatibility for multiple Pydantic versions (#66)
Browse files Browse the repository at this point in the history
### Changes:
 * Add pydantic_version_compact.py with `_get_major_pydantic_version` function to handle version-specific logic.
 * Updated test cases to use `pydantic_version_compatibility` for version compatibility.
 * Adjusted error handling in tests to accommodate different Pydantic versions.
 * Refactored imports and added necessary logic to support both Pydantic v1 and v2.
 * Updated continuous integration and delivery workflows to test for both version of pydantic.
  • Loading branch information
shaharbar1 authored Oct 21, 2024
1 parent cddda5c commit 8594158
Show file tree
Hide file tree
Showing 17 changed files with 485 additions and 55 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/continuous_delivery.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
pydantic-version: [ "1.10.*", "2.*" ]

steps:
- name: Checkout repository
Expand All @@ -28,6 +29,7 @@ jobs:
export PATH="$HOME/.poetry/bin:$PATH"
- name: Install project dependencies with Poetry
run: |
poetry add pydantic@${{ matrix.pydantic-version }}
poetry install
- name: Style check
run: |
Expand All @@ -51,14 +53,14 @@ jobs:
echo "VERSION_CHANGED=true" >> $GITHUB_ENV
fi
- name: Create a Git tag
if: ${{ env.VERSION_CHANGED == 'true' && matrix.python-version == '3.8' }}
if: ${{ env.VERSION_CHANGED == 'true' && matrix.python-version == '3.8' && matrix.pydantic-version == '2.*' }}
run: |
git config --global user.email "github-actions[bot]@users.noreply.github.com"
git config --global user.name "GitHub Actions"
git tag "${{ env.PACKAGE_VERSION }}"
git push origin "${{ env.PACKAGE_VERSION }}"
- name: Publish Draft Release
if: ${{ env.VERSION_CHANGED == 'true' && matrix.python-version == '3.8' }}
if: ${{ env.VERSION_CHANGED == 'true' && matrix.python-version == '3.8' && matrix.pydantic-version == '2.*' }}
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
Expand All @@ -83,7 +85,7 @@ jobs:
core.setFailed(`Draft release named "Draft" not found.`);
};
- name: Build and publish to pypi
if: ${{ env.VERSION_CHANGED == 'true' && matrix.python-version == '3.8' }}
if: ${{ env.VERSION_CHANGED == 'true' && matrix.python-version == '3.8' && matrix.pydantic-version == '2.*' }}
run: |
poetry config pypi-token.pypi ${{ secrets.PYPI_TOKEN }}
poetry publish --build
2 changes: 2 additions & 0 deletions .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
pydantic-version: [ "1.10.*", "2.*" ]

steps:
- name: Checkout repository
Expand All @@ -36,6 +37,7 @@ jobs:
export PATH="$HOME/.poetry/bin:$PATH"
- name: Install project dependencies with Poetry
run: |
poetry add pydantic@${{ matrix.pydantic-version }}
poetry install
- name: Style check
run: |
Expand Down
53 changes: 51 additions & 2 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,17 @@
# SOFTWARE.


from typing import Dict, List, NewType, Tuple, Union
from typing import Any, Dict, List, NewType, Tuple, Union

from pydantic import BaseModel, confloat, conint, constr
from pybandits.pydantic_version_compatibility import (
PYDANTIC_VERSION_1,
PYDANTIC_VERSION_2,
BaseModel,
confloat,
conint,
constr,
pydantic_version,
)

ActionId = NewType("ActionId", constr(min_length=1))
Float01 = NewType("Float_0_1", confloat(ge=0, le=1))
Expand All @@ -41,7 +49,48 @@
ACTION_IDS_PREFIX = "action_ids_"


class _classproperty(property):
def __get__(self, instance, owner):
return self.fget(owner)


class PyBanditsBaseModel(BaseModel, extra="forbid"):
"""
BaseModel of the PyBandits library.
"""

def _apply_version_adjusted_method(self, v2_method_name: str, v1_method_name: str, **kwargs) -> Any:
"""
Apply the method with the given name, adjusting for the pydantic version.
Parameters
----------
v2_method_name : str
The method name for pydantic v2.
v1_method_name : str
The method name for pydantic v1.
"""
if pydantic_version == PYDANTIC_VERSION_1:
return getattr(self, v1_method_name)(**kwargs)
elif pydantic_version == PYDANTIC_VERSION_2:
return getattr(self, v2_method_name)(**kwargs)
else:
raise ValueError(f"Unsupported pydantic version: {pydantic_version}")

@classmethod
def _get_value_with_default(cls, key: str, values: Dict[str, Any]) -> Any:
return values.get(key, cls.model_fields[key].default)

if pydantic_version == PYDANTIC_VERSION_1:

@_classproperty
def model_fields(cls) -> Dict[str, Any]:
"""
Get the model fields.
Returns
-------
List[str]
The model fields.
"""
return cls.__fields__
2 changes: 1 addition & 1 deletion pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from numpy import array
from numpy.random import choice
from numpy.typing import ArrayLike
from pydantic import field_validator, validate_call

from pybandits.base import ActionId, BinaryReward, CmabPredictions
from pybandits.mab import BaseMab
from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC
from pybandits.pydantic_version_compatibility import field_validator, validate_call
from pybandits.strategy import (
BestActionIdentificationBandit,
ClassicBandit,
Expand Down
51 changes: 34 additions & 17 deletions pybandits/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, get_args

import numpy as np
from pydantic import field_validator, model_validator, validate_call

from pybandits.base import (
ACTION_IDS_PREFIX,
Expand All @@ -38,6 +37,14 @@
PyBanditsBaseModel,
)
from pybandits.model import Model
from pybandits.pydantic_version_compatibility import (
PYDANTIC_VERSION_1,
PYDANTIC_VERSION_2,
field_validator,
model_validator,
pydantic_version,
validate_call,
)
from pybandits.strategy import Strategy
from pybandits.utils import extract_argument_names_from_function

Expand Down Expand Up @@ -102,21 +109,31 @@ def at_least_one_action_is_defined(cls, v):
raise AttributeError("All actions should follow the same type.")
return v

@model_validator(mode="after")
def check_default_action(self):
if not self.epsilon and self.default_action:
raise AttributeError("A default action should only be defined when epsilon is defined.")
if self.default_action and self.default_action not in self.actions:
raise AttributeError("The default action must be valid action defined in the actions set.")
return self

@model_validator(mode="after")
def validate_default_action(self):
if not self.epsilon and self.default_action:
raise AttributeError("A default action should only be defined when epsilon is defined.")
if self.default_action and self.default_action not in self.actions:
raise AttributeError("The default action should be defined in the actions.")
return self
if pydantic_version == PYDANTIC_VERSION_1:

@model_validator(mode="before")
@classmethod
def check_default_action(cls, values):
epsilon = cls._get_value_with_default("epsilon", values)
default_action = cls._get_value_with_default("default_action", values)
if not epsilon and default_action:
raise AttributeError("A default action should only be defined when epsilon is defined.")
if default_action and default_action not in values["actions"]:
raise AttributeError("The default action must be valid action defined in the actions set.")
return values

elif pydantic_version == PYDANTIC_VERSION_2:

@model_validator(mode="after")
def check_default_action(self):
if not self.epsilon and self.default_action:
raise AttributeError("A default action should only be defined when epsilon is defined.")
if self.default_action and self.default_action not in self.actions:
raise AttributeError("The default action must be valid action defined in the actions set.")
return self

else:
raise ValueError(f"Unsupported pydantic version: {pydantic_version}")

############################################# Method Input Validators ##############################################

Expand Down Expand Up @@ -217,7 +234,7 @@ def get_state(self) -> (str, dict):
The internal state of the model (actions, scores, etc.).
"""
model_name = self.__class__.__name__
state: dict = self.model_dump()
state: dict = self._apply_version_adjusted_method("model_dump", "dict")
return model_name, state

@validate_call
Expand Down
72 changes: 52 additions & 20 deletions pybandits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,24 @@
import pymc.math as pmath
from numpy import array, c_, insert, mean, multiply, ones, sqrt, std
from numpy.typing import ArrayLike
from pydantic import (
Field,
NonNegativeFloat,
PositiveInt,
confloat,
model_validator,
validate_call,
)
from pymc import Bernoulli, Data, Deterministic, fit, sample
from pymc import Model as PymcModel
from pymc import StudentT as PymcStudentT
from pytensor.tensor import TensorVariable, dot
from scipy.stats import t

from pybandits.base import BinaryReward, Probability, PyBanditsBaseModel
from pybandits.pydantic_version_compatibility import (
PYDANTIC_VERSION_1,
PYDANTIC_VERSION_2,
Field,
NonNegativeFloat,
PositiveInt,
confloat,
model_validator,
pydantic_version,
validate_call,
)

UpdateMethods = Literal["MCMC", "VI"]

Expand Down Expand Up @@ -283,7 +286,12 @@ class BayesianLogisticRegression(Model):
"""

alpha: StudentT
betas: List[StudentT] = Field(..., min_length=1)
if pydantic_version == PYDANTIC_VERSION_1:
betas: List[StudentT] = Field(..., min_items=1)
elif pydantic_version == PYDANTIC_VERSION_2:
betas: List[StudentT] = Field(..., min_length=1)
else:
raise ValueError("Invalid version.")
update_method: UpdateMethods = "MCMC"
update_kwargs: Optional[dict] = None
_default_update_kwargs = dict(draws=1000, progressbar=False, return_inferencedata=False)
Expand All @@ -299,17 +307,41 @@ class BayesianLogisticRegression(Model):
)
_default_variational_inference_kwargs = dict(method="advi")

@model_validator(mode="after")
def arrange_update_kwargs(self):
if self.update_kwargs is None:
self.update_kwargs = self._default_update_kwargs
if self.update_method == "VI":
self.update_kwargs = {**self._default_variational_inference_kwargs, **self.update_kwargs}
elif self.update_method == "MCMC":
self.update_kwargs = {**self._default_mcmc_kwargs, **self.update_kwargs}
else:
raise ValueError("Invalid update method.")
return self
if pydantic_version == PYDANTIC_VERSION_1:

@model_validator(mode="before")
@classmethod
def arrange_update_kwargs(cls, values):
update_kwargs = cls._get_value_with_default("update_kwargs", values)
update_method = cls._get_value_with_default("update_method", values)
if update_kwargs is None:
update_kwargs = cls._default_update_kwargs
if update_method == "VI":
update_kwargs = {**cls._default_variational_inference_kwargs, **update_kwargs}
elif update_method == "MCMC":
update_kwargs = {**cls._default_mcmc_kwargs, **update_kwargs}
else:
raise ValueError("Invalid update method.")
values["update_kwargs"] = update_kwargs
values["update_method"] = update_method
return values

elif pydantic_version == PYDANTIC_VERSION_2:

@model_validator(mode="after")
def arrange_update_kwargs(self):
if self.update_kwargs is None:
self.update_kwargs = self._default_update_kwargs
if self.update_method == "VI":
self.update_kwargs = {**self._default_variational_inference_kwargs, **self.update_kwargs}
elif self.update_method == "MCMC":
self.update_kwargs = {**self._default_mcmc_kwargs, **self.update_kwargs}
else:
raise ValueError("Invalid update method.")
return self

else:
raise ValueError(f"Unsupported pydantic version: {pydantic_version}")

@classmethod
def _stable_sigmoid(cls, x: Union[np.ndarray, TensorVariable]) -> Union[np.ndarray, TensorVariable]:
Expand Down
Loading

0 comments on commit 8594158

Please sign in to comment.