Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce not given and apply it to dataset's target #1612

Merged
merged 21 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b9c59a1
Introduce not given and apply it to dataset's target
kevinmessiaen Nov 15, 2023
15c8ecf
Fixed typo
kevinmessiaen Nov 15, 2023
0e2e4f2
Added test to validate explicit target as None
kevinmessiaen Nov 15, 2023
57ecaf3
Fixed test where trying to assign value to a @property
kevinmessiaen Nov 16, 2023
3b877c5
Actually test for no warning being raised
kevinmessiaen Nov 16, 2023
21bf1b0
Code optimization
kevinmessiaen Nov 16, 2023
2a52a37
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 16, 2023
f7ec53b
Fixed compatibility issue with Python 3.9 due to usage of `|`
kevinmessiaen Nov 16, 2023
9eb16ab
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
rabah-khalek Nov 16, 2023
c464a55
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
rabah-khalek Nov 17, 2023
7d9df77
Removed check for model type to ignore target warnings
kevinmessiaen Nov 20, 2023
d4a7816
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 20, 2023
906f4ff
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 24, 2023
24d15a3
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 27, 2023
438ca6d
Moved target validation from model to dataset
kevinmessiaen Nov 27, 2023
ee7b5f7
Updated documentation and notebook
kevinmessiaen Nov 27, 2023
454e1cb
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 28, 2023
d440242
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 29, 2023
49d10bc
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 29, 2023
2dab169
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
rabah-khalek Nov 29, 2023
4f179f0
Merge branch 'main' into feature/gsk-2118-introduce-not_given-type-2
kevinmessiaen Nov 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions giskard/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,37 @@ class Kwargs:
pass


_T = TypeVar("_T")


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).

For example:

```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...

get(timout=1) # 1s timeout
get(timout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""

def __bool__(self) -> Literal[False]:
return False

def __repr__(self) -> str:
return "NOT_GIVEN"


NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()


def _get_plugin_method_full_name(func):
from giskard.ml_worker.testing.registry.registry import plugins_root

Expand Down
2 changes: 1 addition & 1 deletion giskard/core/dataset_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def validate_optional_target(ds: Dataset):
if ds.target is None:
if not ds.is_target_given:
warning(
"You did not provide the optional argument 'target'. "
"'target' is the column name in df corresponding to the actual target variable (ground truth)."
Expand Down
4 changes: 2 additions & 2 deletions giskard/core/model_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
from giskard.datasets.base import Dataset
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
from giskard.models.base import BaseModel, WrapperModel
from .dataset_validation import validate_optional_target
from ..utils import fullname
from ..utils.analytics_collector import analytics, get_dataset_properties, get_model_properties
from .dataset_validation import validate_optional_target


@configured_validate_arguments
def validate_model(model: BaseModel, validate_ds: Optional[Dataset] = None, print_validation_message: bool = True):
try:
if model.meta.model_type != SupportedModelTypes.TEXT_GENERATION and validate_ds is not None:
if validate_ds is not None:
validate_optional_target(validate_ds)
kevinmessiaen marked this conversation as resolved.
Show resolved Hide resolved
_do_validate_model(model, validate_ds)
except (ValueError, TypeError) as err:
Expand Down
16 changes: 11 additions & 5 deletions giskard/datasets/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from giskard.client.giskard_client import GiskardClient
from giskard.client.io_utils import compress, save_df
from giskard.client.python_utils import warning
from giskard.core.core import DatasetMeta, SupportedColumnTypes
from giskard.core.core import DatasetMeta, SupportedColumnTypes, NOT_GIVEN, NotGivenOr
from giskard.core.errors import GiskardImportError
from giskard.core.validation import configured_validate_arguments
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction, SlicingFunctionType
Expand Down Expand Up @@ -144,7 +144,7 @@ class Dataset(ColumnMetadataMixin):
"""

name: Optional[str]
target: Optional[str]
_target: NotGivenOr[Optional[str]]
column_types: Dict[str, str]
df: pd.DataFrame
id: uuid.UUID
Expand All @@ -156,7 +156,7 @@ def __init__(
self,
df: pd.DataFrame,
name: Optional[str] = None,
target: Optional[Hashable] = None,
target: NotGivenOr[Optional[Hashable]] = NOT_GIVEN,
cat_columns: Optional[List[str]] = None,
column_types: Optional[Dict[Hashable, str]] = None,
id: Optional[uuid.UUID] = None,
Expand Down Expand Up @@ -186,7 +186,7 @@ def __init__(

self.name = name
self.df = pd.DataFrame(df)
self.target = target
self._target = target

if validation:
from giskard.core.dataset_validation import validate_dtypes, validate_target_exists
Expand Down Expand Up @@ -229,7 +229,13 @@ def __init__(

logger.info("Your 'pandas.DataFrame' is successfully wrapped by Giskard's 'Dataset' wrapper class.")

self.data_processor = DataProcessor()
@property
def is_target_given(self) -> bool:
return self._target is not NOT_GIVEN

@property
def target(self) -> Optional[str]:
return self._target or None

def add_slicing_function(self, slicing_function: SlicingFunction):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_errors(dataset_name, model_name, request):
# dataset type error
dataset_copy = dataset.copy()
dataset_copy.df = [[0.6, 0.4]]
dataset_copy.target = [1]
dataset_copy._target = [1]

with pytest.raises(Exception) as e:
_evaluate(dataset_copy, model, evaluator_config)
Expand Down
45 changes: 30 additions & 15 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import pandas as pd
import numpy as np
from pydantic import ValidationError
import pytest
import uuid

import numpy as np
import pandas as pd
import pytest
import requests_mock
from pydantic import ValidationError

from giskard.datasets.base import Dataset
from giskard.core.dataset_validation import validate_optional_target
from giskard.client.dtos import DatasetMetaInfo

from giskard.core.dataset_validation import validate_optional_target
from giskard.datasets.base import Dataset
from tests import utils
from tests.communications.test_dto_serialization import is_required, get_fields, get_name


# FIXME: conflict on `name` between Giskard Hub (@NotBlank) and Python client (optional in DatasetMeta and DatasetMetaInfo)
MANDATORY_FIELDS = [
"id",
Expand Down Expand Up @@ -57,15 +55,28 @@ def test_factory():
assert isinstance(my_dataset, Dataset)


def test_valid_df_column_types():
# Option 0: none of column_types, cat_columns, infer_column_types = True are provided
def test_validate_optional_target():
with pytest.warns(
UserWarning,
match=r"You did not provide the optional argument 'target'\. 'target' is the column name "
r"in df corresponding to the actual target variable \(ground truth\)\.",
):
my_dataset = Dataset(valid_df)
validate_optional_target(my_dataset)

with pytest.warns(None) as record:
my_dataset = Dataset(valid_df, target=None)
validate_optional_target(my_dataset)

my_dataset = Dataset(valid_df, target="text_column")
validate_optional_target(my_dataset)

assert len(record) == 0


def test_valid_df_column_types():
# Option 0: none of column_types, cat_columns, infer_column_types = True are provided
my_dataset = Dataset(valid_df)
assert my_dataset.column_types == {
"categorical_column": "category",
"text_column": "text",
Expand Down Expand Up @@ -158,7 +169,6 @@ def test_numeric_column_names():


def test_infer_column_types():

# if df_size >= 100 ==> category_threshold = floor(log10(df_size))
assert Dataset(pd.DataFrame({"f": [1, 2] * 50})).column_types["f"] == "category"
assert Dataset(pd.DataFrame({"f": ["a", "b"] * 50})).column_types["f"] == "category"
Expand Down Expand Up @@ -228,8 +238,9 @@ def test_dataset_meta_info():
mandatory_field_names = []
optional_field_names = []
for name, field in get_fields(klass).items():
mandatory_field_names.append(get_name(name, field)) if is_required(field) else \
optional_field_names.append(get_name(name, field))
mandatory_field_names.append(get_name(name, field)) if is_required(field) else optional_field_names.append(
get_name(name, field)
)
assert set(mandatory_field_names) == set(MANDATORY_FIELDS)
assert set(optional_field_names) == set(OPTIONAL_FIELDS)

Expand All @@ -242,7 +253,9 @@ def test_fetch_dataset_meta(request):
with utils.MockedClient(mock_all=False) as (client, mr):
meta_info = utils.mock_dataset_meta_info(dataset, project_key)
meta_info.pop(op)
mr.register_uri(method=requests_mock.GET, url=utils.get_url_for_dataset(dataset, project_key), json=meta_info)
mr.register_uri(
method=requests_mock.GET, url=utils.get_url_for_dataset(dataset, project_key), json=meta_info
)

# Should not raise
client.load_dataset_meta(project_key, uuid=str(dataset.id))
Expand All @@ -251,7 +264,9 @@ def test_fetch_dataset_meta(request):
with utils.MockedClient(mock_all=False) as (client, mr):
meta_info = utils.mock_dataset_meta_info(dataset, project_key)
meta_info.pop(op)
mr.register_uri(method=requests_mock.GET, url=utils.get_url_for_dataset(dataset, project_key), json=meta_info)
mr.register_uri(
method=requests_mock.GET, url=utils.get_url_for_dataset(dataset, project_key), json=meta_info
)

# Should raise due to missing of values
with pytest.raises(ValidationError):
Expand Down