# Substrings that mark a dictionary key as holding a sensitive value. Shared by
# ismatchingkey() and recursive_key_operation() so their defaults cannot drift.
_DEFAULT_KEYS_TO_MATCH = ("key", "token", "password")


def redact(value: str) -> str:
    """Redact all but the last four characters of the given string.

    Strings of four characters or fewer are returned unchanged, because there
    is nothing in front of the visible tail to replace.

    Args:
        value (str): The string to be redacted.

    Returns:
        str: The string with every character except the last four replaced
            by asterisks.
    """
    visible_tail = value[-4:]
    return "*" * (len(value) - len(visible_tail)) + visible_tail


def ismatchingkey(
    target_key: str,
    keys_to_match: tuple[str, ...] = _DEFAULT_KEYS_TO_MATCH,
) -> bool:
    """Check whether the target key contains any of the given fragments.

    The comparison is case-insensitive, so "API_KEY" and "api_key" both match
    the fragment "key".

    Args:
        target_key (str): The key to be checked. Non-string keys never match.
        keys_to_match (tuple[str, ...], optional): Fragments to look for in
            the target key. Defaults to ("key", "token", "password").

    Returns:
        bool: True if any fragment occurs in the target key, False otherwise.
    """
    # Dictionaries may carry non-string keys; those cannot name a secret and
    # must not make telemetry raise.
    if not isinstance(target_key, str):
        return False
    folded_key = target_key.casefold()
    return any(fragment.casefold() in folded_key for fragment in keys_to_match)


def _load_json_container(s):
    """Parse ``s`` as JSON; return the dict/list it encodes, or None otherwise."""
    try:
        parsed = json.loads(s)
    except (ValueError, TypeError):
        return None
    return parsed if isinstance(parsed, (dict, list)) else None


def can_convert_to_dict(s):
    """Check if a string can be converted to a dictionary.

    The input is parsed as JSON; the answer is True only when parsing succeeds
    and the decoded value is a JSON object (a Python dict). JSON scalars and
    arrays are valid JSON but are not dictionaries, so they yield False.

    Args:
        s (str): The input string to be checked.

    Returns:
        bool: True if the string decodes to a dictionary, False otherwise.
    """
    try:
        return isinstance(json.loads(s), dict)
    except (ValueError, TypeError):
        return False


def recursive_key_operation(data, operation, keys_to_match=None):
    """Apply ``operation`` to every string value stored under a matching key.

    Dictionaries and lists are traversed recursively and updated in place.
    A string that encodes a JSON object or array is decoded, processed the
    same way, and re-encoded as JSON. Any other value is returned untouched.

    Args:
        data (dict or list or str): The dictionary, list or JSON string to
            traverse.
        operation (function): Called with each matched string value; its
            return value replaces the original.
        keys_to_match (list or tuple, optional): Key fragments to match, as
            understood by ``ismatchingkey``. Defaults to
            ("key", "token", "password").

    Returns:
        dict or list or str: The processed dictionary, list or string.
    """
    if keys_to_match is None:
        keys_to_match = _DEFAULT_KEYS_TO_MATCH

    if isinstance(data, str):
        decoded = _load_json_container(data)
        if decoded is None:
            # Plain text or a JSON scalar: nothing keyed to inspect, and
            # round-tripping it would only alter its spelling.
            return data
        processed = recursive_key_operation(decoded, operation, keys_to_match)
        # Re-encode as JSON (not str()) so the payload stays parseable;
        # default=str covers operations that return non-JSON-native objects.
        return json.dumps(processed, default=str)

    if isinstance(data, dict):
        # Only existing keys are reassigned, so iterating items() is safe.
        for key, value in data.items():
            if isinstance(value, str) and ismatchingkey(key, keys_to_match):
                data[key] = operation(value)
            else:
                # Non-string values under a matching key, and all other
                # values, may still contain nested secrets.
                data[key] = recursive_key_operation(value, operation, keys_to_match)
    elif isinstance(data, list):
        for index, item in enumerate(data):
            data[index] = recursive_key_operation(item, operation, keys_to_match)

    return data
current_span.set_attribute( "llm.invocation_parameters", ser_invocation_parameters ) diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py index 27cf27a97..cbf5b6ac2 100644 --- a/guardrails/telemetry/runner_tracing.py +++ b/guardrails/telemetry/runner_tracing.py @@ -17,7 +17,13 @@ from guardrails.classes.output_type import OT from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.stores.context import get_guard_name -from guardrails.telemetry.common import get_tracer, add_user_attributes, serialize +from guardrails.telemetry.common import ( + get_tracer, + add_user_attributes, + serialize, + recursive_key_operation, + redact, +) from guardrails.utils.safe_get import safe_get from guardrails.version import GUARDRAILS_VERSION @@ -45,10 +51,14 @@ def add_step_attributes( ser_args = [serialize(arg) for arg in args] ser_kwargs = {k: serialize(v) for k, v in kwargs.items()} + inputs = { "args": [sarg for sarg in ser_args if sarg is not None], "kwargs": {k: v for k, v in ser_kwargs.items() if v is not None}, } + for k in inputs: + inputs[k] = recursive_key_operation(inputs[k], redact) + step_span.set_attribute("input.mime_type", "application/json") step_span.set_attribute("input.value", json.dumps(inputs)) @@ -239,6 +249,8 @@ def add_call_attributes( "args": [sarg for sarg in ser_args if sarg is not None], "kwargs": {k: v for k, v in ser_kwargs.items() if v is not None}, } + for k in inputs: + inputs[k] = recursive_key_operation(inputs[k], redact) call_span.set_attribute("input.mime_type", "application/json") call_span.set_attribute("input.value", json.dumps(inputs)) diff --git a/tests/unit_tests/redaction/test_matching_key.py b/tests/unit_tests/redaction/test_matching_key.py new file mode 100644 index 000000000..ff09584ae --- /dev/null +++ b/tests/unit_tests/redaction/test_matching_key.py @@ -0,0 +1,25 @@ +import unittest +from guardrails.telemetry.common import ismatchingkey + + +class 
TestIsMatchingKey(unittest.TestCase): + def test_key_matches_with_default_keys(self): + self.assertTrue(ismatchingkey("api_key")) + self.assertTrue(ismatchingkey("user_token")) + self.assertFalse(ismatchingkey("username")) + self.assertTrue(ismatchingkey("password")) + + def test_key_matches_with_custom_keys(self): + self.assertTrue(ismatchingkey("api_secret", keys_to_match=("secret",))) + self.assertTrue(ismatchingkey("client_id", keys_to_match=("id",))) + self.assertFalse(ismatchingkey("session", keys_to_match=("key", "token"))) + + def test_empty_key(self): + self.assertFalse(ismatchingkey("", keys_to_match=("key", "token"))) + + def test_empty_keys_to_match(self): + self.assertFalse(ismatchingkey("key", keys_to_match=())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/redaction/test_recursive_redaction.py b/tests/unit_tests/redaction/test_recursive_redaction.py new file mode 100644 index 000000000..2a252bdf0 --- /dev/null +++ b/tests/unit_tests/redaction/test_recursive_redaction.py @@ -0,0 +1,49 @@ +import unittest +from guardrails.telemetry.common import recursive_key_operation, redact +import ast + + +# Test suite for recursive_key_operation function +class TestRecursiveKeyOperation(unittest.TestCase): + def test_list(self): + data = '{"init_args": [], "init_kwargs": {"model": "gpt-4o-mini", \ + "api_base": "https://api.openai.com/v1", "api_key": "sk-1234"}}' + result = recursive_key_operation(data, redact) + assert ast.literal_eval(result)["init_kwargs"]["api_key"] == "***1234" + + def test_dict_kwargs(self): + data = { + "index": "0", + "api": '{"init_args": [], "init_kwargs": {"model": "gpt-4o-mini",\ + "api_base": "https://api.openai.com/v1", "api_key": "sk-1234"}}', + "messages": None, + "prompt_params": "{}", + "output_schema": '{"type": "string"}', + "output": None, + } + result = recursive_key_operation(data, redact) + assert ast.literal_eval(result["api"])["init_kwargs"]["api_key"] == "***1234" + + def 
test_nomatch(self): + data = {"somekey": "soemvalue"} + result = recursive_key_operation(data, redact) + self.assertEqual(result, data) + + def test_empty_dict(self): + data = {} + result = recursive_key_operation(data, redact) + self.assertEqual(result, data) + + def test_empty_list(self): + data = [] + result = recursive_key_operation(data, redact) + self.assertEqual(result, data) + + def test_non_string_value(self): + data = {"key": 123, "another_key": "value"} + result = recursive_key_operation(data, redact) + self.assertEqual(result, data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/redaction/test_redaction.py b/tests/unit_tests/redaction/test_redaction.py new file mode 100644 index 000000000..dda6b7ba2 --- /dev/null +++ b/tests/unit_tests/redaction/test_redaction.py @@ -0,0 +1,38 @@ +import unittest +from guardrails.telemetry.common import redact + + +class TestRedactFunction(unittest.TestCase): + def test_redact_long_string(self): + self.assertEqual(redact("supersecretpassword"), "***************word") + + def test_redact_short_string(self): + self.assertEqual(redact("test"), "test") + + def test_open_ai_example_key(self): + self.assertEqual( + redact("sk-hp37"), + "***hp37", + ) + + def test_redact_very_short_string(self): + self.assertEqual(redact("abc"), "abc") + + def test_redact_empty_string(self): + self.assertEqual(redact(""), "") + + def test_redact_exact_length(self): + self.assertEqual(redact("1234"), "1234") + + def test_redact_special_characters(self): + self.assertEqual(redact("ab!@#12"), "***@#12") + + def test_redact_single_character(self): + self.assertEqual(redact("a"), "a") + + def test_redact_spaces(self): + self.assertEqual(redact(" test"), "******test") + + +if __name__ == "__main__": + unittest.main()