Fix control casing #18

Merged
merged 21 commits into from
Sep 5, 2024

aaravnavani
Contributor

This PR fixes this issue.

For example, if the competitor passed in is "Apple" and we validate the text "Fun fact about apple is that apple is both a fruit and company. apples come in many different colors and flavors. What if apple is a tech company that makes phones, computers, and tablets? apple is headquartered in Cupertino, California.", the validator now correctly identifies lowercase "apple" as a competitor.

curl command to test against the FastAPI setup running on localhost:

curl -X POST "http://localhost:8000/validate" \
-H "Content-Type: application/json" \
-d '{
    "inputs": [
        {
            "name": "text",
            "shape": [1],
            "data": [
                "Fun fact about apple is that apple is both a fruit and company. apples come in many different colors and flavors. What if apple is a tech company that makes phones, computers, and tablets? apple is headquartered in Cupertino, California."
            ],
            "datatype": "BYTES"
        },
        {
            "name": "competitors",
            "shape": [1],
            "data": ["Apple"],
            "datatype": "BYTES"
        }
    ]
}'

gives output:

{"modelname":"en_core_web_trf","modelversion":"1","outputs":[{"name":"result0","shape":[1],"data":[["apple","apple","apple"]],"datatype":"BYTES"}]}
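For anyone scripting this check instead of using curl, here is a minimal Python sketch that builds the same request body and pulls the matches out of the response shown above. The helper names `build_payload` and `extract_matches` are hypothetical and not part of the validator; the field layout is copied from the curl example, and the `shape` value for competitors is an assumption beyond the single-competitor case.

```python
import json
from typing import List


def build_payload(text: str, competitors: List[str]) -> dict:
    """Build a request body shaped like the one in the curl command."""
    return {
        "inputs": [
            {"name": "text", "shape": [1], "data": [text], "datatype": "BYTES"},
            {
                "name": "competitors",
                # Assumption: shape tracks the number of competitors.
                # The curl example only shows one competitor with shape [1].
                "shape": [len(competitors)],
                "data": competitors,
                "datatype": "BYTES",
            },
        ]
    }


def extract_matches(response_body: str) -> List[str]:
    """Return the entities matched for the first (only) input text."""
    outputs = json.loads(response_body)["outputs"]
    return outputs[0]["data"][0]


# Response copied from the output above.
sample = (
    '{"modelname":"en_core_web_trf","modelversion":"1","outputs":'
    '[{"name":"result0","shape":[1],"data":[["apple","apple","apple"]],'
    '"datatype":"BYTES"}]}'
)
print(extract_matches(sample))  # ['apple', 'apple', 'apple']
```

The payload from `build_payload` can be serialized with `json.dumps` and posted to the `/validate` endpoint with any HTTP client.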

@CalebCourier
Contributor

We should add the same fixes for local inference to the validator's main.py.

@aaravnavani
Contributor Author

@CalebCourier from the tests that I ran, it seems to detect apple as a competitor with local inference. This script for example:

# Import Guard and Validator
from guardrails import Guard
from guardrails.hub import CompetitorCheck


# Setup Guard
guard = Guard().use(
    CompetitorCheck, ["Apple", "Samsung"], "exception", use_local=True,
)
response = guard.validate("apple just released a new iPhone.")  # Validator fails

print(response)

prints:

guardrails.errors.ValidationError: Validation failed for field with errors: Found the following competitors: [['Apple']]. Please avoid naming those competitors next time

It seems to be handled in this line.

@CalebCourier
Contributor

@aaravnavani that's only the initial filter. It also needs to be handled here and potentially in filtering methods.

As a side note, the fact that it's listing competitors as an Array<Array<String>> seems off.
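On the nesting: `_inference_local` is typed as returning `List[List[str]]`, so the error message may simply be printing that structure as-is. If a flat list is what the message should show, one possible approach is sketched below. `flatten_competitors` is a hypothetical helper, not something in the validator, and whether to de-duplicate is a design choice for the maintainers.

```python
from itertools import chain
from typing import List


def flatten_competitors(found: List[List[str]]) -> List[str]:
    """Collapse per-text match lists into one list, dropping repeats but keeping order."""
    # dict.fromkeys preserves first-seen order while removing duplicates.
    return list(dict.fromkeys(chain.from_iterable(found)))


print(flatten_competitors([["Apple"]]))  # ['Apple']
```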

@CalebCourier
Contributor

I'm not seeing that behaviour locally. If I try to run that exact code snippet this is what I get:

ValidationOutcome(
    call_id='14844740304',
    raw_llm_output='apple just released a new iPhone.',
    validated_output='apple just released a new iPhone.',
    reask=None,
    validation_passed=True,
    error=None
)

@@ -159,7 +159,7 @@ def _inference_local(self, model_input: Any) -> List[List[str]]:
     doc = self.nlp(t)
     located_entities = []
     for ent in doc.ents:
-        if ent.text in competitors:
+        if ent.text.lower() in [comp.lower() for comp in competitors]:
Contributor

Small nit: extract the competitors list comprehension to a variable
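A rough sketch of what that extraction could look like, written as a standalone function so it runs without spaCy. The function name `filter_competitor_entities` and the plain-string entity input are assumptions for illustration; in the validator the loop iterates over `doc.ents` and reads `ent.text`. The sketch also swaps the list for a set, which is a small change beyond the nit itself.

```python
from typing import Iterable, List


def filter_competitor_entities(
    entity_texts: Iterable[str], competitors: List[str]
) -> List[str]:
    """Keep entities whose text matches a competitor name, ignoring case."""
    # Build the lowercase lookup once, instead of once per entity.
    competitors_lower = {comp.lower() for comp in competitors}

    located_entities = []
    for text in entity_texts:
        if text.lower() in competitors_lower:
            # Assumption: the matched entity text is collected as written.
            located_entities.append(text)
    return located_entities


print(filter_competitor_entities(["apple", "Samsung", "Google"], ["Apple", "Samsung"]))
# ['apple', 'Samsung']
```

Hoisting the lowercase collection out of the loop avoids rebuilding it for every entity in `doc.ents`.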

@CalebCourier CalebCourier merged commit 538343b into main Sep 5, 2024
1 check passed