Skip to content

Commit

Permalink
Skip input_schema for hftransformersv2 back compat (#88)
Browse files Browse the repository at this point in the history
* Skip input_schema for hftransformersv2 back compat

Signed-off-by: Walter Martin <[email protected]>

* lint

Signed-off-by: Walter Martin <[email protected]>

* only check for hftransformers if input arg is dict

Signed-off-by: Walter Martin <[email protected]>

---------

Signed-off-by: Walter Martin <[email protected]>
  • Loading branch information
wamartin-aml authored Apr 3, 2024
1 parent 9f5ace8 commit 78c4e9c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
8 changes: 6 additions & 2 deletions inference_schema/schema_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def input_schema(param_name, param_type, convert_to_provided_type=True, optional

@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
def decorator_input(user_run, instance, args, kwargs):
if convert_to_provided_type:
is_hftransformersv2 = False
if len(args) > 0 and type(args[0]) is dict:
args_keys = args[0].keys()
is_hftransformersv2 = len(args_keys) == 2 and "parameters" in args_keys and "input_string" in args_keys
# skip all of this for hftransformersv2
if convert_to_provided_type and not is_hftransformersv2:
args = list(args)

if param_name not in kwargs.keys() and not optional:
decorators = _get_decorators(user_run)
arg_names = inspect.getfullargspec(decorators[-1]).args
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,19 @@ def standard_py_func(param):
return standard_py_func


@pytest.fixture(scope="session")
def decorated_standard_func_parameters(standard_sample_input, sample_param_dict):
@input_schema('input_data', StandardPythonParameterType(standard_sample_input))
@input_schema('params', StandardPythonParameterType(sample_param_dict), optional=False)
def standard_params_func(input_data, params=None):
if params is not None:
assert type(params) is dict
beams = params['num_beams'] if params is not None else 0
return input_data["input_string"], beams

return standard_params_func


@pytest.fixture(scope="session")
def standard_sample_input_multitype_list():
return ['foo', 1]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_standard_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,15 @@ def test_float_int_handling(self, decorated_float_func):
int_input = 1
result = decorated_float_func(int_input)
assert int_input == result

def test_standard_params_handling_hftransformersv2(self, decorated_standard_func_parameters):
input_data = {
"input_string": ["the meaning of life is"],
"parameters": {
"num_beams": 2,
"max_length": 512
}
}
result = decorated_standard_func_parameters(input_data)
assert result[0][0] == "the meaning of life is"
assert result[1] == 0

0 comments on commit 78c4e9c

Please sign in to comment.