diff --git a/giskard/ml_worker/core/savable.py b/giskard/ml_worker/core/savable.py index 99bcb5c5f4..aa5a770981 100644 --- a/giskard/ml_worker/core/savable.py +++ b/giskard/ml_worker/core/savable.py @@ -14,6 +14,7 @@ from giskard.client.giskard_client import GiskardClient from giskard.core.core import SMT, SavableMeta +from giskard.ml_worker.exceptions.giskard_exception import python_env_exception_helper from giskard.ml_worker.testing.registry.registry import tests_registry from giskard.settings import settings @@ -173,7 +174,12 @@ def load(cls, local_dir: Path, uuid: str, meta: SMT): if local_dir.exists(): with open(local_dir / "data.pkl", "rb") as f: - _function = cloudpickle.load(f) + try: + # According to https://github.com/cloudpipe/cloudpickle#cloudpickle: + # Cloudpickle can only be used to send objects between the exact same version of Python. + _function = cloudpickle.load(f) + except Exception as e: + raise python_env_exception_helper(cls.__name__, e) else: try: func = getattr(sys.modules[meta.module], meta.name) diff --git a/giskard/ml_worker/exceptions/giskard_exception.py b/giskard/ml_worker/exceptions/giskard_exception.py index 252cae3691..27f98a5ab2 100644 --- a/giskard/ml_worker/exceptions/giskard_exception.py +++ b/giskard/ml_worker/exceptions/giskard_exception.py @@ -1,3 +1,39 @@ +from typing import Optional, Tuple +import platform + + class GiskardException(Exception): def __init__(self, *args: object) -> None: super().__init__(*args) + + +class GiskardPythonEnvException(GiskardException): + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class GiskardPythonVerException(GiskardPythonEnvException): + def __init__(self, cls_name, e, required_py_ver, *args: object) -> None: + super().__init__( + f"Failed to load '{cls_name}' due to {e.__class__.__name__}.\n" + f"Make sure you are loading it in the environment with matched Python version (required {required_py_ver}, loading with {platform.python_version()}).", + *args, + ) + + +class GiskardPythonDepException(GiskardPythonEnvException): + def __init__(self, cls_name, e, *args: object) -> None: + super().__init__( + f"Failed to load '{cls_name}' due to {e.__class__.__name__}.\n" + "Make sure you are loading it in the environment with matched dependencies.", + *args, + ) + + +def python_env_exception_helper(cls_name, e: Exception, required_py_ver: Optional[Tuple[str, str, str]] = None): + if required_py_ver is not None and required_py_ver[:2] != platform.python_version_tuple()[:2]: + # Python major and minor versions are not matched + # Notice that there could be some false positive, check: https://github.com/Giskard-AI/giskard/pull/1620 + return GiskardPythonVerException(cls_name, e, required_py_ver=required_py_ver) + # We assume the other cases as the dependency issues + return GiskardPythonDepException(cls_name, e) diff --git a/giskard/ml_worker/testing/registry/giskard_test.py b/giskard/ml_worker/testing/registry/giskard_test.py index f482d4e239..064a5256d7 100644 --- a/giskard/ml_worker/testing/registry/giskard_test.py +++ b/giskard/ml_worker/testing/registry/giskard_test.py @@ -12,6 +12,7 @@ from giskard.core.core import SMT, TestFunctionMeta from giskard.core.validation import configured_validate_arguments from giskard.ml_worker.core.savable import Artifact +from giskard.ml_worker.exceptions.giskard_exception import python_env_exception_helper from giskard.ml_worker.testing.registry.registry import get_object_uuid, tests_registry from giskard.ml_worker.testing.test_result import TestResult from giskard.utils.analytics_collector import analytics @@ -75,7 +76,10 @@ def _load_meta_locally(cls, local_dir, uuid: str) -> Optional[TestFunctionMeta]: def load(cls, local_dir: Path, uuid: str, meta: TestFunctionMeta): if local_dir.exists(): with open(Path(local_dir) / "data.pkl", "rb") as f: - func = pickle.load(f) + try: + func = pickle.load(f) + except Exception as e: + raise python_env_exception_helper(cls.__name__, e) elif hasattr(sys.modules[meta.module], meta.name): func = getattr(sys.modules[meta.module], meta.name) else: diff --git a/giskard/models/base/model.py b/giskard/models/base/model.py index 7e28b01796..596ba70c44 100644 --- a/giskard/models/base/model.py +++ b/giskard/models/base/model.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Type, Union +from typing import Iterable, List, Optional, Tuple, Type, Union import builtins import importlib @@ -22,7 +22,7 @@ from ...core.core import ModelMeta, ModelType, SupportedModelTypes from ...core.validation import configured_validate_arguments from ...datasets.base import Dataset -from ...ml_worker.exceptions.giskard_exception import GiskardException +from ...ml_worker.exceptions.giskard_exception import GiskardException, python_env_exception_helper from ...ml_worker.utils.logging import Timer from ...models.cache import ModelCache from ...path_utils import get_size @@ -207,11 +207,16 @@ def is_text_generation(self) -> bool: return self.meta.model_type == SupportedModelTypes.TEXT_GENERATION @classmethod - def determine_model_class(cls, meta, local_dir): + def determine_model_class(cls, meta, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None): class_file = Path(local_dir) / MODEL_CLASS_PKL if class_file.exists(): with open(class_file, "rb") as f: - clazz = cloudpickle.load(f) + try: + # According to https://github.com/cloudpipe/cloudpickle#cloudpickle: + # Cloudpickle can only be used to send objects between the exact same version of Python. + clazz = cloudpickle.load(f) + except Exception as e: + raise python_env_exception_helper(cls.__name__, e, required_py_ver=model_py_ver) if not issubclass(clazz, BaseModel): raise ValueError(f"Unknown model class: {clazz}. Models should inherit from 'BaseModel' class") return clazz @@ -440,7 +445,7 @@ def download(cls, client: Optional[GiskardClient], project_key, model_id): if client is None: # internal worker case, no token based http client [deprecated, to be removed] assert local_dir.exists(), f"Cannot find existing model {project_key}.{model_id} in {local_dir}" - _, meta = cls.read_meta_from_local_dir(local_dir) + meta_response, meta = cls.read_meta_from_local_dir(local_dir) else: client.load_artifact(local_dir, posixpath.join(project_key, "models", model_id)) meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id) @@ -461,7 +466,11 @@ def download(cls, client: Optional[GiskardClient], project_key, model_id): loader_class=file_meta["loader_class"], ) - clazz = cls.determine_model_class(meta, local_dir) + model_py_ver = ( + tuple(meta_response.languageVersion.split(".")) if "PYTHON" == meta_response.language.upper() else None + ) + + clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver) constructor_params = meta.__dict__ constructor_params["id"] = str(model_id) @@ -469,11 +478,11 @@ def download(cls, client: Optional[GiskardClient], project_key, model_id): del constructor_params["loader_module"] del constructor_params["loader_class"] - model = clazz.load(local_dir, **constructor_params) + model = clazz.load(local_dir, model_py_ver=model_py_ver, **constructor_params) return model @classmethod - def read_meta_from_local_dir(cls, local_dir): + def read_meta_from_local_dir(cls, local_dir) -> Tuple[ModelMetaInfo, ModelMeta]: with (Path(local_dir) / META_FILENAME).open(encoding="utf-8") as f: file_meta = yaml.load(f, Loader=yaml.Loader) meta = ModelMeta( @@ -486,8 +495,26 @@ def read_meta_from_local_dir(cls, local_dir): loader_module=file_meta["loader_module"], loader_class=file_meta["loader_class"], ) - # dirty implementation to return id like this, to be decided if meta properties can just be BaseModel properties - return file_meta["id"], meta + + # Bring more information, such as language and language version + extra_meta = ModelMetaInfo( + id=file_meta["id"], + name=meta.name, + modelType=file_meta["model_type"], + featureNames=meta.feature_names if meta.feature_names is not None else [], + threshold=meta.classification_threshold, + description=meta.description, + classificationLabels=meta.classification_labels + if meta.classification_labels is None + else list(map(str, meta.classification_labels)), + languageVersion=file_meta["language_version"], + language=file_meta["language"], + size=file_meta["size"], + classificationLabelsDtype=None, + createdDate="", + projectId=-1, + ) + return extra_meta, meta @classmethod def cast_labels(cls, meta_response: ModelMetaInfo) -> List[Union[str, Type]]: @@ -499,7 +526,7 @@ def cast_labels(cls, meta_response: ModelMetaInfo) -> List[Union[str, Type]]: return labels_ @classmethod - def load(cls, local_dir, **kwargs): + def load(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None, **kwargs): class_file = Path(local_dir) / MODEL_CLASS_PKL model_id, meta = cls.read_meta_from_local_dir(local_dir) @@ -510,7 +537,12 @@ def load(cls, local_dir, **kwargs): if class_file.exists(): with open(class_file, "rb") as f: - clazz = cloudpickle.load(f) + try: + # According to https://github.com/cloudpipe/cloudpickle#cloudpickle: + # Cloudpickle can only be used to send objects between the exact same version of Python. + clazz = cloudpickle.load(f) + except Exception as e: + raise python_env_exception_helper(cls.__name__, e, required_py_ver=model_py_ver) clazz_kwargs = {} clazz_kwargs.update(constructor_params) clazz_kwargs.update(kwargs) diff --git a/giskard/models/base/serialization.py b/giskard/models/base/serialization.py index 353812c00e..d9926e2864 100644 --- a/giskard/models/base/serialization.py +++ b/giskard/models/base/serialization.py @@ -1,10 +1,12 @@ import pickle from pathlib import Path -from typing import Union +from typing import Optional, Tuple, Union import cloudpickle import mlflow +from giskard.ml_worker.exceptions.giskard_exception import python_env_exception_helper + from .wrapper import WrapperModel @@ -49,12 +51,17 @@ def save_model(self, local_path: Union[str, Path]) -> None: ) @classmethod - def load_model(cls, local_dir): + def load_model(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None): local_path = Path(local_dir) model_path = local_path / "model.pkl" if model_path.exists(): with open(model_path, "rb") as f: - model = cloudpickle.load(f) + try: + # According to https://github.com/cloudpipe/cloudpickle#cloudpickle: + # Cloudpickle can only be used to send objects between the exact same version of Python. + model = cloudpickle.load(f) + except Exception as e: + raise python_env_exception_helper(cls.__name__, e, required_py_ver=model_py_ver) return model else: raise ValueError( diff --git a/giskard/models/base/wrapper.py b/giskard/models/base/wrapper.py index d7873496be..3d295145cb 100644 --- a/giskard/models/base/wrapper.py +++ b/giskard/models/base/wrapper.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable, Optional, Union +from typing import Any, Callable, Iterable, Optional, Tuple, Union import logging import pickle @@ -16,6 +16,7 @@ from ...core.validation import configured_validate_arguments from ..utils import warn_once from .model import BaseModel +from giskard.ml_worker.exceptions.giskard_exception import python_env_exception_helper logger = logging.getLogger(__name__) @@ -242,21 +243,28 @@ def save_wrapper_meta(self, local_path): ) @classmethod - def load(cls, local_dir, **kwargs): + def load(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None, **kwargs): constructor_params = cls.load_constructor_params(local_dir, **kwargs) - return cls(model=cls.load_model(local_dir), **constructor_params) + if model_py_ver is None: + # Try to extract Python version from meta info under local dir + meta_response, _ = cls.read_meta_from_local_dir(local_dir) + model_py_ver = ( + tuple(meta_response.languageVersion.split(".")) if "PYTHON" == meta_response.language.upper() else None + ) + + return cls(model=cls.load_model(local_dir, model_py_ver), **constructor_params) @classmethod - def load_constructor_params(cls, local_dir, **kwargs): + def load_constructor_params(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None, **kwargs): params = cls.load_wrapper_meta(local_dir) params["data_preprocessing_function"] = cls.load_data_preprocessing_function(local_dir) params["model_postprocessing_function"] = cls.load_model_postprocessing_function(local_dir) params.update(kwargs) - model_id, meta = cls.read_meta_from_local_dir(local_dir) + extra_meta, meta = cls.read_meta_from_local_dir(local_dir) constructor_params = meta.__dict__ - constructor_params["id"] = model_id + constructor_params["id"] = extra_meta.id constructor_params = constructor_params.copy() constructor_params.update(params) @@ -264,13 +272,15 @@ def load_constructor_params(cls, local_dir, **kwargs): @classmethod @abstractmethod - def load_model(cls, path: Union[str, Path]): + def load_model(cls, path: Union[str, Path], model_py_ver: Optional[Tuple[str, str, str]] = None): """Loads the wrapped ``model`` object. Parameters ---------- path : Union[str, Path] Path from which the model should be loaded. + model_py_ver : Optional[Tuple[str, str, str]] + Python version used to save the model, to validate if model loading failed. """ ... @@ -280,7 +290,12 @@ def load_data_preprocessing_function(cls, local_path: Union[str, Path]): file_path = local_path / "giskard-data-preprocessing-function.pkl" if file_path.exists(): with open(file_path, "rb") as f: - return cloudpickle.load(f) + try: + # According to https://github.com/cloudpipe/cloudpickle#cloudpickle: + # Cloudpickle can only be used to send objects between the exact same version of Python. + return cloudpickle.load(f) + except Exception as e: + raise python_env_exception_helper("Data Preprocessing Function", e) return None @classmethod @@ -289,7 +304,12 @@ def load_model_postprocessing_function(cls, local_path: Union[str, Path]): file_path = local_path / "giskard-model-postprocessing-function.pkl" if file_path.exists(): with open(file_path, "rb") as f: - return cloudpickle.load(f) + try: + # According to https://github.com/cloudpipe/cloudpickle#cloudpickle: + # Cloudpickle can only be used to send objects between the exact same version of Python. + return cloudpickle.load(f) + except Exception as e: + raise python_env_exception_helper("Data Postprocessing Function", e) return None @classmethod diff --git a/giskard/models/catboost.py b/giskard/models/catboost.py index 187015bc9f..a07f7485bc 100644 --- a/giskard/models/catboost.py +++ b/giskard/models/catboost.py @@ -1,3 +1,4 @@ +from typing import Optional, Tuple import mlflow from .sklearn import SKLearnModel @@ -12,7 +13,7 @@ def save_model(self, local_path, mlflow_meta: mlflow.models.Model): mlflow.catboost.save_model(self.model, path=local_path, mlflow_model=mlflow_meta) @classmethod - def load_model(cls, local_dir): + def load_model(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None): return mlflow.catboost.load_model(local_dir) def to_mlflow(self, artifact_path: str = "catboost-model-from-giskard", **kwargs): diff --git a/giskard/models/huggingface.py b/giskard/models/huggingface.py index 2b114fe7a7..3f40d1e503 100644 --- a/giskard/models/huggingface.py +++ b/giskard/models/huggingface.py @@ -92,7 +92,7 @@ class explicitly using :class:`giskard.models.huggingface.HuggingFaceModel`. the `model_postprocessing_function` argument. This function should take the raw output of your model and return a numpy array of probabilities. """ -from typing import Any, Callable, Iterable, Optional, Union +from typing import Any, Callable, Iterable, Optional, Tuple, Union import logging from pathlib import Path @@ -196,7 +196,7 @@ def __init__( pass @classmethod - def load_model(cls, local_path): + def load_model(cls, local_path, model_py_ver: Optional[Tuple[str, str, str]] = None): huggingface_meta_file = Path(local_path) / "giskard-model-huggingface-meta.yaml" if huggingface_meta_file.exists(): with open(huggingface_meta_file) as f: diff --git a/giskard/models/langchain.py b/giskard/models/langchain.py index ce123ea3a5..7fea74d49f 100644 --- a/giskard/models/langchain.py +++ b/giskard/models/langchain.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Iterable, Optional, Union, Dict +from typing import Any, Callable, Iterable, Optional, Tuple, Union, Dict import pandas as pd @@ -57,16 +57,16 @@ def save_artifacts(self, artifact_dir) -> None: ... @classmethod - def load(cls, local_dir, **kwargs): + def load(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None, **kwargs): constructor_params = cls.load_constructor_params(local_dir, **kwargs) artifacts = cls.load_artifacts(Path(local_dir) / "artifacts") or dict() constructor_params.update(artifacts) - return cls(model=cls.load_model(local_dir, **artifacts), **constructor_params) + return cls(model=cls.load_model(local_dir, model_py_ver=model_py_ver, **artifacts), **constructor_params) @classmethod - def load_model(cls, local_dir, **kwargs): + def load_model(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None, **kwargs): from langchain.chains import load_chain path = Path(local_dir) diff --git a/giskard/models/pytorch.py b/giskard/models/pytorch.py index 4fd9c1628b..401939d3bd 100644 --- a/giskard/models/pytorch.py +++ b/giskard/models/pytorch.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union, get_args +from typing import Literal, Optional, Tuple, Union, get_args import collections import importlib @@ -139,7 +139,7 @@ def __init__( ) @classmethod - def load_model(cls, local_dir): + def load_model(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None): return mlflow.pytorch.load_model(local_dir) def save_model(self, local_path, mlflow_meta: mlflow.models.Model): @@ -216,9 +216,9 @@ def save(self, local_path: Union[str, Path]) -> None: self.save_pytorch_meta(local_path) @classmethod - def load(cls, local_dir, **kwargs): + def load(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None, **kwargs): kwargs.update(cls.load_pytorch_meta(local_dir)) - return super().load(local_dir, **kwargs) + return super().load(local_dir, model_py_ver=model_py_ver, **kwargs) @classmethod def load_pytorch_meta(cls, local_dir): diff --git a/giskard/models/sklearn.py b/giskard/models/sklearn.py index cf3d3bd7dd..4618919528 100644 --- a/giskard/models/sklearn.py +++ b/giskard/models/sklearn.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Tuple import mlflow import pandas as pd @@ -66,7 +66,7 @@ def save_model(self, local_path, mlflow_meta): ) @classmethod - def load_model(cls, local_dir): + def load_model(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None): return mlflow.sklearn.load_model(local_dir) def model_predict(self, df): diff --git a/giskard/models/tensorflow.py b/giskard/models/tensorflow.py index 7cb43841a3..acab098497 100644 --- a/giskard/models/tensorflow.py +++ b/giskard/models/tensorflow.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Tuple import logging @@ -41,7 +41,7 @@ def __init__( ) @classmethod - def load_model(cls, local_path): + def load_model(cls, local_path, model_py_ver: Optional[Tuple[str, str, str]] = None): return mlflow.tensorflow.load_model(local_path) def save_model(self, local_path, mlflow_meta: mlflow.models.Model): diff --git a/tests/models/fixtures/func/3.10/giskard-model-meta.yaml b/tests/models/fixtures/func/3.10/giskard-model-meta.yaml new file mode 100644 index 0000000000..a983342099 --- /dev/null +++ b/tests/models/fixtures/func/3.10/giskard-model-meta.yaml @@ -0,0 +1,14 @@ +classification_labels: +- 0 +- 1 +description: No description +feature_names: null +id: 54a90cbf-3b44-43d4-8070-8926c8339791 +language: PYTHON +language_version: 3.10.13 +loader_class: PredictionFunctionModel +loader_module: giskard.models.function +model_type: CLASSIFICATION +name: PredictionFunctionModel +size: 0 +threshold: 0.5 diff --git a/tests/models/fixtures/func/3.10/giskard-model-wrapper-meta.yaml b/tests/models/fixtures/func/3.10/giskard-model-wrapper-meta.yaml new file mode 100644 index 0000000000..f57f73a329 --- /dev/null +++ b/tests/models/fixtures/func/3.10/giskard-model-wrapper-meta.yaml @@ -0,0 +1 @@ +batch_size: null diff --git a/tests/models/fixtures/func/3.10/model.pkl b/tests/models/fixtures/func/3.10/model.pkl new file mode 100644 index 0000000000..5a8778813f Binary files /dev/null and b/tests/models/fixtures/func/3.10/model.pkl differ diff --git a/tests/models/fixtures/func/3.11/giskard-model-meta.yaml b/tests/models/fixtures/func/3.11/giskard-model-meta.yaml new file mode 100644 index 0000000000..1d6513bc3c --- /dev/null +++ b/tests/models/fixtures/func/3.11/giskard-model-meta.yaml @@ -0,0 +1,14 @@ +classification_labels: +- 0 +- 1 +description: No description +feature_names: null +id: 01297928-e182-4b50-b9aa-57058dcbadbd +language: PYTHON +language_version: 3.11.5 +loader_class: PredictionFunctionModel +loader_module: giskard.models.function +model_type: CLASSIFICATION +name: PredictionFunctionModel +size: 0 +threshold: 0.5 diff --git a/tests/models/fixtures/func/3.11/giskard-model-wrapper-meta.yaml b/tests/models/fixtures/func/3.11/giskard-model-wrapper-meta.yaml new file mode 100644 index 0000000000..f57f73a329 --- /dev/null +++ b/tests/models/fixtures/func/3.11/giskard-model-wrapper-meta.yaml @@ -0,0 +1 @@ +batch_size: null diff --git a/tests/models/fixtures/func/3.11/model.pkl b/tests/models/fixtures/func/3.11/model.pkl new file mode 100644 index 0000000000..7c9904a16e Binary files /dev/null and b/tests/models/fixtures/func/3.11/model.pkl differ diff --git a/tests/models/fixtures/func/3.9/giskard-model-meta.yaml b/tests/models/fixtures/func/3.9/giskard-model-meta.yaml new file mode 100644 index 0000000000..bfe93dcbec --- /dev/null +++ b/tests/models/fixtures/func/3.9/giskard-model-meta.yaml @@ -0,0 +1,13 @@ +classification_labels: +- 0 +- 1 +feature_names: null +id: 592e6f31-5d04-4915-9d26-798f21277a28 +language: PYTHON +language_version: 3.9.6 +loader_class: PredictionFunctionModel +loader_module: giskard.models.function +model_type: CLASSIFICATION +name: PredictionFunctionModel +size: 0 +threshold: 0.5 diff --git a/tests/models/fixtures/func/3.9/giskard-model-wrapper-meta.yaml b/tests/models/fixtures/func/3.9/giskard-model-wrapper-meta.yaml new file mode 100644 index 0000000000..f57f73a329 --- /dev/null +++ b/tests/models/fixtures/func/3.9/giskard-model-wrapper-meta.yaml @@ -0,0 +1 @@ +batch_size: null diff --git a/tests/models/fixtures/func/3.9/model.pkl b/tests/models/fixtures/func/3.9/model.pkl new file mode 100644 index 0000000000..6eb0430ffb Binary files /dev/null and b/tests/models/fixtures/func/3.9/model.pkl differ diff --git a/tests/models/test_function_model.py b/tests/models/test_function_model.py index f401e239cf..3d242fc944 100644 --- a/tests/models/test_function_model.py +++ b/tests/models/test_function_model.py @@ -1,9 +1,11 @@ +from pathlib import Path import numpy as np import pandas as pd import pytest +import platform -import tests.utils from giskard import Dataset, Model +from giskard.ml_worker.exceptions.giskard_exception import GiskardPythonVerException from giskard.models.function import PredictionFunctionModel @@ -22,6 +24,8 @@ def test_prediction_function_upload(): lambda df: np.ones(len(df)), model_type="classification", classification_labels=[0, 1] ) + import tests.utils + tests.utils.verify_model_upload(gsk_model, Dataset(df=pd.DataFrame({"x": [1, 2, 3], "y": [1, 0, 1]}), target="y")) @@ -64,3 +68,27 @@ def test_single_feature(): with pytest.raises(Exception) as e: validate_model(giskard_model, giskard_dataset) assert e.match(r"Your model returned an error when we passed a 'pandas.Dataframe' as input.*") + + +COMPAT_TABLE = { + "3.9": ["3.9", "3.10"], + "3.10": ["3.9", "3.10"], + "3.11": ["3.11"], +} + + +@pytest.mark.parametrize("py_ver", ["3.9", "3.10", "3.11"]) +def test_prediction_function_load(py_ver): + model_path = Path(__file__).parent / "fixtures" / "func" / py_ver + if ".".join(platform.python_version_tuple()[:2]) in COMPAT_TABLE[py_ver]: + model = Model.load(model_path) + assert model is not None + else: + with pytest.raises(GiskardPythonVerException): + Model.load(model_path) + + +if __name__ == "__main__": + py_ver = ".".join(platform.python_version_tuple()[:2]) + model_path = Path(__file__).parent / "fixtures" / "func" / py_ver + Model(lambda df: np.ones(len(df)), model_type="classification", classification_labels=[0, 1]).save(model_path) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 7221e2b1ec..49e7f83369 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1,4 +1,5 @@ import tempfile +from typing import Optional, Tuple from pathlib import Path import numpy as np @@ -73,7 +74,7 @@ def save_model(self, path): Path(path).joinpath("custom_data").touch() @classmethod - def load_model(cls, path): + def load_model(cls, path, model_py_ver: Optional[Tuple[str, str, str]] = None): call_count["load"] = call_count["load"] + 1 def model(x): diff --git a/tests/models/test_wrapper_model.py b/tests/models/test_wrapper_model.py index 68a3a5e820..017873887e 100644 --- a/tests/models/test_wrapper_model.py +++ b/tests/models/test_wrapper_model.py @@ -1,4 +1,5 @@ import tempfile +from typing import Optional, Tuple import numpy as np import pandas as pd @@ -17,7 +18,7 @@ def model_predict(self, data): return [0] * len(data) @classmethod - def load_model(cls, path): + def load_model(cls, path, model_py_ver: Optional[Tuple[str, str, str]] = None): pass def save_model(self, path): @@ -62,7 +63,7 @@ def model_predict(self, data): return [0] * len(data) @classmethod - def load_model(cls, path): + def load_model(cls, path, model_py_ver: Optional[Tuple[str, str, str]] = None): pass def save_model(self, path): diff --git a/tests/test_custom_model.py b/tests/test_custom_model.py index be27509647..d307e90a16 100644 --- a/tests/test_custom_model.py +++ b/tests/test_custom_model.py @@ -1,6 +1,6 @@ import re from pathlib import Path -from typing import Union +from typing import Union, Optional, Tuple from giskard.core.core import SupportedModelTypes from giskard.models.base import BaseModel, WrapperModel @@ -14,7 +14,7 @@ def test_custom_model(linear_regression_diabetes: BaseModel): class MyModel(WrapperModel): @classmethod - def load_model(cls, local_dir): + def load_model(cls, local_dir, model_py_ver: Optional[Tuple[str, str, str]] = None): pass def save_model(self, local_path: Union[str, Path]) -> None: