Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix for MLFlow validator spans during async execution #1164

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions guardrails/integrations/databricks/ml_flow_instrumentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def instrument(self):
export.validate
)
setattr(export, "validate", wrapped_validator_validate)

wrapped_validator_async_validate = (
self._instrument_validator_async_validate(export.async_validate)
)
setattr(export, "async_validate", wrapped_validator_async_validate)

setattr(guardrails.hub, validator_name, export) # type: ignore

def _instrument_guard(
Expand Down Expand Up @@ -387,6 +393,14 @@ def trace_validator_wrapper(*args, **kwargs):
init_kwargs = validator_self._kwargs

validator_span_name = f"{validator_name}.validate"

# Skip this instrumentation in the case of async
# when the parent span cannot be fetched from the current context
# because Validator.validate is running in a ThreadPoolExecutor
parent_span = mlflow.get_current_active_span()
if not parent_span:
return validator_validate(*args, **kwargs)

with mlflow.start_span(
name=validator_span_name,
span_type="validator",
Expand Down Expand Up @@ -425,3 +439,64 @@ def trace_validator_wrapper(*args, **kwargs):
raise e

return trace_validator_wrapper

def _instrument_validator_async_validate(
    self,
    validator_async_validate: Callable[..., Coroutine[Any, Any, ValidationResult]],
):
    """Wrap an async validator entry point so each call runs in an MLflow span.

    Args:
        validator_async_validate: The coroutine function to instrument. It
            is awaited with exactly the positional and keyword arguments
            that are passed to the returned wrapper.

    Returns:
        A coroutine function that opens a ``"<validator_name>.validate"``
        span of type ``"validator"``, awaits the wrapped function inside
        it, records validator attributes on the span, and returns the
        wrapped function's result.

    Raises:
        Exception: Anything raised inside the span is re-raised unchanged
            after the span status has been set to ERROR and the validator
            attributes have been recorded with ``result=None``.
    """

    @wraps(validator_async_validate)
    async def trace_async_validator_wrapper(*args, **kwargs):
        # Fallback metadata, used when the first positional argument is
        # not a Validator instance.
        validator_name = "validator"
        obj_id = id(validator_async_validate)
        on_fail_descriptor = "unknown"
        init_kwargs = {}
        validation_session_id = "unknown"

        # Tolerate an empty ``args`` so the call is still forwarded to the
        # wrapped function instead of failing here with an IndexError.
        validator_self = args[0] if args else None
        if isinstance(validator_self, Validator):
            validator_name = validator_self.rail_alias
            obj_id = id(validator_self)
            on_fail_descriptor = validator_self.on_fail_descriptor
            init_kwargs = validator_self._kwargs

        validator_span_name = f"{validator_name}.validate"

        with mlflow.start_span(
            name=validator_span_name,
            span_type="validator",
            attributes={
                "guardrails.version": GUARDRAILS_VERSION,
                "type": "guardrails/guard/step/validator",
                "async": True,
            },
        ) as validator_span:

            def record_attributes(result):
                # Single place that forwards the call arguments plus the
                # collected validator metadata to the span.
                add_validator_attributes(
                    *args,
                    validator_span=validator_span,  # type: ignore
                    validator_name=validator_name,
                    obj_id=obj_id,
                    on_fail_descriptor=on_fail_descriptor,
                    result=result,
                    init_kwargs=init_kwargs,
                    validation_session_id=validation_session_id,
                    **kwargs,
                )

            try:
                resp = await validator_async_validate(*args, **kwargs)
                record_attributes(resp)
                return resp
            except Exception:
                validator_span.set_status(status=SpanStatusCode.ERROR)
                record_attributes(None)
                # Bare ``raise`` re-raises the active exception as-is.
                raise

    return trace_async_validator_wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ async def test__instrument_async_runner_call(self, mocker):

def test__instrument_validator_validate(self, mocker):
mock_span = MockSpan()
mock_start_span = mocker.patch(
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.get_current_active_span",
return_value=mock_span,
)
mock_start_span = mocker.patch(
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.start_span",
return_value=mock_span,
Expand Down Expand Up @@ -630,3 +634,52 @@ def test__instrument_validator_validate(self, mocker):
init_kwargs={},
validation_session_id="unknown",
)

@pytest.mark.asyncio
async def test__instrument_validator_async_validate(self, mocker):
    """The async validate wrapper opens one span and records attributes once."""
    span = MockSpan()
    start_span_patch = mocker.patch(
        "guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.start_span",
        return_value=span,
    )
    add_attributes_patch = mocker.patch(
        "guardrails.integrations.databricks.ml_flow_instrumentor.add_validator_attributes"
    )

    # Imported after patching, as the instrumentor module is the patch target.
    from guardrails.integrations.databricks import MlFlowInstrumentor
    from tests.unit_tests.mocks.mock_hub import MockValidator

    instrumentor = MlFlowInstrumentor("mock experiment")
    traced_async_validate = instrumentor._instrument_validator_async_validate(
        MockValidator.async_validate
    )
    validator = MockValidator()

    result = await traced_async_validate(validator, True, {})

    # The span must be opened exactly once with the async marker set.
    expected_span_attributes = {
        "guardrails.version": GUARDRAILS_VERSION,
        "type": "guardrails/guard/step/validator",
        "async": True,
    }
    start_span_patch.assert_called_once_with(
        name="mock-validator.validate",
        span_type="validator",
        attributes=expected_span_attributes,
    )

    # This is the call made inside the wrapper, not the wrapped call above.
    expected_metadata = {
        "validator_span": span,  # type: ignore
        "validator_name": "mock-validator",
        "obj_id": id(validator),
        "on_fail_descriptor": "exception",
        "result": result,
        "init_kwargs": {},
        "validation_session_id": "unknown",
    }
    add_attributes_patch.assert_called_once_with(
        validator,
        True,
        {},
        **expected_metadata,
    )
Loading