Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ends_with validator to support lists and strings with test #564

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions guardrails/validators/ends_with.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List, Union

from guardrails.logger import logger
from guardrails.validator_base import (
Expand All @@ -10,7 +10,7 @@
)


@register_validator(name="ends-with", data_type="list")
@register_validator(name="ends-with", data_type=["string", "list"])
class EndsWith(Validator):
"""Validates that a list ends with a given value.

Expand All @@ -19,24 +19,38 @@ class EndsWith(Validator):
| Property | Description |
| ----------------------------- | --------------------------------- |
| Name for `format` attribute | `ends-with` |
| Supported data types | `list` |
| Programmatic fix | Append the given value to the list. |
| Supported data types | `list`, `string |
| Programmatic fix | Append the given value to the list or string |

Args:
end: The required last element.
"""

def __init__(self, end: str, on_fail: str = "fix"):
def __init__(self, end: Union[List[Any], Any, str], on_fail: str = "fix"):
super().__init__(on_fail=on_fail, end=end)
self._end = end

def validate(self, value: Any, metadata: Dict) -> ValidationResult:
logger.debug(f"Validating {value} ends with {self._end}...")

if not value[-1] == self._end:
end = self._end
if isinstance(value, list) and not isinstance(self._end, list):
end = [self._end]

ending_idxs = len(end)
if not value[-ending_idxs:] == end:
if isinstance(value, list) and isinstance(end, list):
fix_value = value + end
elif isinstance(value, str) and isinstance(end, str):
fix_value = value + end
else:
raise TypeError(
"Cannot concatenate `end` and `value` of different types"
)

return FailResult(
error_message=f"{value} must end with {self._end}",
fix_value=value + [self._end],
error_message=f"{value} must end with {end}",
fix_value=fix_value,
)

return PassResult()
29 changes: 29 additions & 0 deletions tests/unit_tests/validators/test_ends_with.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from guardrails.validators import EndsWith, FailResult, PassResult


@pytest.mark.parametrize(
"input, end, outcome, fix_value",
[
("Test string", "g", "pass", None),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't look like we run null or empty checks in the validator. Can we cover some of those cases in test to make sure they work correctly?

("Test string", "string", "pass", None),
(["Item 1", "Item 2"], ["Item 2"], "pass", None),
(["Item 1", "Item 2", "Item 3"], ["Item 2", "Item 3"], "pass", None),
("Test string", "Test", "fail", "Test stringTest"),
(["Item 1", "Item 2"], ["Item 1"], "fail", ["Item 1", "Item 2", "Item 1"]),
(["Item 1", "Item 2"], "Item 2", "pass", None),
(["Item 1", "Item 2"], "Item 3", "fail", ["Item 1", "Item 2", "Item 3"]),
],
)
def test_ends_with_validator(input, end, outcome, fix_value):
"""Test that the validator returns the expected outcome and fix_value."""
validator = EndsWith(end=end, on_fail="fix")
result: PassResult = validator.validate(input, {})

# Check that the result matches the expected outcome and fix_value
if outcome == "fail":
assert isinstance(result, FailResult)
assert result.fix_value == fix_value
else:
assert isinstance(result, PassResult)
Loading