Skip to content

Commit

Permalink
Added Vertex AI spans for request parameters (open-telemetry#3192)
Browse files Browse the repository at this point in the history
* Added Vertex AI spans for request parameters

* small fixes, get CI passing

* Use standard OTel tracing error handling

* move nested util

* Actually use GAPIC client since thats what we use under the hood

Also this is what LangChain uses

* Comment out seed for now

* Remove unnecessary dict.get() calls

* Typing improvements to check that we support both v1 and v1beta1

* Add more teest cases for error conditions and fix span name bug

* fix typing

* Add todos for error.type
  • Loading branch information
aabmass authored Jan 22, 2025
1 parent 3f50c08 commit ec3c51d
Show file tree
Hide file tree
Showing 12 changed files with 848 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Added Vertex AI spans for request parameters
([#3192](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3192))
- Initial VertexAI instrumentation
([#3123](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3123))
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,17 @@

from typing import Any, Collection

from wrapt import (
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
)

from opentelemetry._events import get_event_logger
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.vertexai.package import _instruments
from opentelemetry.instrumentation.vertexai.patch import (
generate_content_create,
)
from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
from opentelemetry.semconv.schemas import Schemas
from opentelemetry.trace import get_tracer

Expand All @@ -55,20 +63,34 @@ def instrumentation_dependencies(self) -> Collection[str]:
def _instrument(self, **kwargs: Any):
"""Enable VertexAI instrumentation."""
tracer_provider = kwargs.get("tracer_provider")
_tracer = get_tracer(
tracer = get_tracer(
__name__,
"",
tracer_provider,
schema_url=Schemas.V1_28_0.value,
)
event_logger_provider = kwargs.get("event_logger_provider")
_event_logger = get_event_logger(
event_logger = get_event_logger(
__name__,
"",
schema_url=Schemas.V1_28_0.value,
event_logger_provider=event_logger_provider,
)
# TODO: implemented in later PR

wrap_function_wrapper(
module="google.cloud.aiplatform_v1beta1.services.prediction_service.client",
name="PredictionServiceClient.generate_content",
wrapper=generate_content_create(
tracer, event_logger, is_content_enabled()
),
)
wrap_function_wrapper(
module="google.cloud.aiplatform_v1.services.prediction_service.client",
name="PredictionServiceClient.generate_content",
wrapper=generate_content_create(
tracer, event_logger, is_content_enabled()
),
)

def _uninstrument(self, **kwargs: Any) -> None:
"""TODO: implemented in later PR"""
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import (
TYPE_CHECKING,
Any,
Callable,
MutableSequence,
)

from opentelemetry._events import EventLogger
from opentelemetry.instrumentation.vertexai.utils import (
GenerateContentParams,
get_genai_request_attributes,
get_span_name,
)
from opentelemetry.trace import SpanKind, Tracer

if TYPE_CHECKING:
from google.cloud.aiplatform_v1.services.prediction_service import client
from google.cloud.aiplatform_v1.types import (
content,
prediction_service,
)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.types import (
content as content_v1beta1,
)
from google.cloud.aiplatform_v1beta1.types import (
prediction_service as prediction_service_v1beta1,
)


# Use parameter signature from
# https://github.com/googleapis/python-aiplatform/blob/v1.76.0/google/cloud/aiplatform_v1/services/prediction_service/client.py#L2088
# to handle named vs positional args robustly
def _extract_params(
request: prediction_service.GenerateContentRequest
| prediction_service_v1beta1.GenerateContentRequest
| dict[Any, Any]
| None = None,
*,
model: str | None = None,
contents: MutableSequence[content.Content]
| MutableSequence[content_v1beta1.Content]
| None = None,
**_kwargs: Any,
) -> GenerateContentParams:
# Request vs the named parameters are mututally exclusive or the RPC will fail
if not request:
return GenerateContentParams(
model=model or "",
contents=contents,
)

if isinstance(request, dict):
return GenerateContentParams(**request)

return GenerateContentParams(
model=request.model,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
tool_config=request.tool_config,
labels=request.labels,
safety_settings=request.safety_settings,
generation_config=request.generation_config,
)


def generate_content_create(
tracer: Tracer, event_logger: EventLogger, capture_content: bool
):
"""Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""

def traced_method(
wrapped: Callable[
...,
prediction_service.GenerateContentResponse
| prediction_service_v1beta1.GenerateContentResponse,
],
instance: client.PredictionServiceClient
| client_v1beta1.PredictionServiceClient,
args: Any,
kwargs: Any,
):
params = _extract_params(*args, **kwargs)
span_attributes = get_genai_request_attributes(params)

span_name = get_span_name(span_attributes)
with tracer.start_as_current_span(
name=span_name,
kind=SpanKind.CLIENT,
attributes=span_attributes,
) as _span:
# TODO: emit request events
# if span.is_recording():
# for message in kwargs.get("messages", []):
# event_logger.emit(
# message_to_event(message, capture_content)
# )

# TODO: set error.type attribute
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
result = wrapped(*args, **kwargs)
# TODO: handle streaming
# if is_streaming(kwargs):
# return StreamWrapper(
# result, span, event_logger, capture_content
# )

# TODO: add response attributes and events
# if span.is_recording():
# _set_response_attributes(
# span, result, event_logger, capture_content
# )
return result

return traced_method
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from dataclasses import dataclass
from os import environ
from typing import (
TYPE_CHECKING,
Mapping,
Sequence,
)

from opentelemetry.semconv._incubating.attributes import (
gen_ai_attributes as GenAIAttributes,
)
from opentelemetry.util.types import AttributeValue

if TYPE_CHECKING:
from google.cloud.aiplatform_v1.types import content, tool
from google.cloud.aiplatform_v1beta1.types import (
content as content_v1beta1,
)
from google.cloud.aiplatform_v1beta1.types import (
tool as tool_v1beta1,
)


@dataclass(frozen=True)
class GenerateContentParams:
model: str
contents: (
Sequence[content.Content] | Sequence[content_v1beta1.Content] | None
) = None
system_instruction: content.Content | content_v1beta1.Content | None = None
tools: Sequence[tool.Tool] | Sequence[tool_v1beta1.Tool] | None = None
tool_config: tool.ToolConfig | tool_v1beta1.ToolConfig | None = None
labels: Mapping[str, str] | None = None
safety_settings: (
Sequence[content.SafetySetting]
| Sequence[content_v1beta1.SafetySetting]
| None
) = None
generation_config: (
content.GenerationConfig | content_v1beta1.GenerationConfig | None
) = None


def get_genai_request_attributes(
params: GenerateContentParams,
operation_name: GenAIAttributes.GenAiOperationNameValues = GenAIAttributes.GenAiOperationNameValues.CHAT,
):
model = _get_model_name(params.model)
generation_config = params.generation_config
attributes: dict[str, AttributeValue] = {
GenAIAttributes.GEN_AI_OPERATION_NAME: operation_name.value,
GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.VERTEX_AI.value,
GenAIAttributes.GEN_AI_REQUEST_MODEL: model,
}

if not generation_config:
return attributes

# Check for optional fields
# https://proto-plus-python.readthedocs.io/en/stable/fields.html#optional-fields
if "temperature" in generation_config:
attributes[GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE] = (
generation_config.temperature
)
if "top_p" in generation_config:
attributes[GenAIAttributes.GEN_AI_REQUEST_TOP_P] = (
generation_config.top_p
)
if "max_output_tokens" in generation_config:
attributes[GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS] = (
generation_config.max_output_tokens
)
if "presence_penalty" in generation_config:
attributes[GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY] = (
generation_config.presence_penalty
)
if "frequency_penalty" in generation_config:
attributes[GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY] = (
generation_config.frequency_penalty
)
# Uncomment once GEN_AI_REQUEST_SEED is released in 1.30
# https://github.com/open-telemetry/semantic-conventions/pull/1710
# if "seed" in generation_config:
# attributes[GenAIAttributes.GEN_AI_REQUEST_SEED] = (
# generation_config.seed
# )
if "stop_sequences" in generation_config:
attributes[GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES] = (
generation_config.stop_sequences
)

return attributes


_MODEL_STRIP_RE = re.compile(
r"^projects/(.*)/locations/(.*)/publishers/google/models/"
)


def _get_model_name(model: str) -> str:
return _MODEL_STRIP_RE.sub("", model)


OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = (
"OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
)


def is_content_enabled() -> bool:
capture_content = environ.get(
OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
)

return capture_content.lower() == "true"


def get_span_name(span_attributes: Mapping[str, AttributeValue]) -> str:
name = span_attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]
model = span_attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]
if not model:
return f"{name}"
return f"{name} {model}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
interactions:
- request:
body: |-
{
"contents": [
{
"role": "user",
"parts": [
{
"text": "Say this is a test"
}
]
}
]
}
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '141'
Content-Type:
- application/json
User-Agent:
- python-requests/2.32.3
method: POST
uri: https://us-central1-aiplatform.googleapis.com/v1/projects/fake-project/locations/us-central1/publishers/google/models/gemini-1.5-flash-002:generateContent?%24alt=json%3Benum-encoding%3Dint
response:
body:
string: |-
{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": "Okay, I understand. I'm ready for your test. Please proceed.\n"
}
]
},
"finishReason": 1,
"avgLogprobs": -0.005692833348324424
}
],
"usageMetadata": {
"promptTokenCount": 5,
"candidatesTokenCount": 19,
"totalTokenCount": 24
},
"modelVersion": "gemini-1.5-flash-002"
}
headers:
Content-Type:
- application/json; charset=UTF-8
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
content-length:
- '453'
status:
code: 200
message: OK
version: 1
Loading

0 comments on commit ec3c51d

Please sign in to comment.