
Add opentelemetry integration #467

Closed · wants to merge 1 commit
7 changes: 7 additions & 0 deletions arq/opentelemetry/__init__.py
@@ -0,0 +1,7 @@
"""Observability utilities related to the `arq` library."""

from arq.opentelemetry.consume import instrument_job
from arq.opentelemetry.produce import InstrumentedArqRedis
from arq.opentelemetry.propagator import ArqJobTextMapPropagator

__all__ = ["instrument_job", "InstrumentedArqRedis", "ArqJobTextMapPropagator"]
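
As a usage sketch (the job name send_welcome_email and the worker wiring here are hypothetical; create_pool and RedisSettings are arq's standard entry points), the exports are intended to combine roughly like this:

from arq import create_pool
from arq.connections import RedisSettings

from arq.opentelemetry import InstrumentedArqRedis, instrument_job


@instrument_job  # consumer side: each run of the job becomes a CONSUMER span
async def send_welcome_email(ctx, user_id):
    ...


async def main():
    pool = await create_pool(RedisSettings())
    # Producer side: enqueue_job() opens a PRODUCER span and injects the
    # current trace context into the job's keyword arguments.
    redis = InstrumentedArqRedis(pool)
    await redis.enqueue_job('send_welcome_email', user_id=42)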
107 changes: 107 additions & 0 deletions arq/opentelemetry/consume.py
@@ -0,0 +1,107 @@
"""Module that instruments tracing on enqueue_job() and the various
job coroutines."""

from functools import wraps
from time import time_ns
from typing import Any, Awaitable, Callable, Dict, List, ParamSpec, TypeVar

from opentelemetry import context, metrics, trace
from opentelemetry.propagate import extract

from arq.worker import Retry

meter = metrics.get_meter(__name__)
job_counter = meter.create_counter(
name="arq.jobs.counter",
unit="1",
description="Counts the number of jobs created with Arq.",
)


def span_name_consumer(func: Callable) -> str:
    """Get the span name to use when running a job. The name follows the conventions in
    https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/messaging.md#conventions

    The spec distinguishes 'receive' from 'process'; the two can be elided when they happen
    close together, but the distinction matters for batch processing. For now, we always
    use "process".
    """
    return f"{func.__qualname__} process"


def get_consumer_attributes(
_func: Callable[..., Awaitable[Any]], _args: List[Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Get attributes that apply when running a job."""
attributes = {}
attributes["span.name"] = span_name_consumer(_func)
attributes["messaging.operation"] = "process"
if "_queue_name" in kwargs:
attributes["messaging.source.name"] = kwargs["_queue_name"]

return attributes


TRACE_CONTEXT_PARAM_NAME = "_trace_context"
JobArgs = ParamSpec("JobArgs")
JobReturnType = TypeVar("JobReturnType")


def instrument_job(
wrapped: Callable[JobArgs, Awaitable[JobReturnType]]
) -> Callable[JobArgs, Awaitable[JobReturnType]]:
"""Decorate a job definition such that it can be traced."""

@wraps(wrapped)
    async def wrapper(*args: JobArgs.args, **kwargs: JobArgs.kwargs) -> JobReturnType:
if TRACE_CONTEXT_PARAM_NAME in kwargs:
# The .extract() method gets us the trace data...
token = context.attach(
extract(
# IMPORTANT! Manually remove TRACE_CONTEXT_PARAM_NAME
# to prevent it being passed as an argument to the job.
carrier=kwargs.pop(TRACE_CONTEXT_PARAM_NAME, {}),
)
)
else:
# We're running a job coroutine but not as a job - perhaps it was awaited
# by something else. So we have no trace data in its kwargs.
#
# In that case, we'll stick with the current context.
token = None

tracer = trace.get_tracer(__name__)
span = tracer.start_span(
span_name_consumer(wrapped),
kind=trace.SpanKind.CONSUMER,
start_time=time_ns(),
)
if span.is_recording():
attributes = get_consumer_attributes(wrapped, list(args), kwargs)
for key, value in attributes.items():
span.set_attribute(key, value)

        with trace.use_span(span, end_on_exit=True, set_status_on_exception=False):
            try:
                result = await wrapped(*args, **kwargs)
            except Retry:
                # A Retry is part of normal arq operation, not a failure.
                span.set_status(trace.Status(trace.StatusCode.OK))
                raise
            except Exception as exc:
                span.set_status(
                    trace.Status(
                        trace.StatusCode.ERROR,
                        description=f"{type(exc).__name__}: {exc}",
                    )
                )
                raise
            finally:
                # Detach the extracted context even if the job raised.
                if token is not None:
                    context.detach(token)
        return result

return wrapper
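
Because the wrapper pops _trace_context before awaiting the job, a decorated coroutine can also be awaited directly without any tracing kwargs. A sketch, assuming the default W3C trace-context propagator is configured (the job, its arguments, and the traceparent value are hypothetical):

@instrument_job
async def resize_image(ctx, image_id):
    ...

# As arq would run it: the carrier is extracted and attached, and
# _trace_context is never passed through to the job body.
await resize_image(
    {},
    image_id=7,
    _trace_context={'traceparent': '00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01'},
)

# Awaited directly: no carrier, so the current context is kept.
await resize_image({}, image_id=7)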
170 changes: 170 additions & 0 deletions arq/opentelemetry/produce.py
@@ -0,0 +1,170 @@
"""Code to add trace information producing jobs."""

from datetime import datetime, timedelta
from time import time_ns
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Optional,
    ParamSpec,
    Union,
    cast,
)
from uuid import uuid4

from opentelemetry import context, metrics, trace
from opentelemetry.propagate import inject
from redis.asyncio.connection import ConnectKwargs

from arq import ArqRedis
from arq.jobs import Job
from arq.opentelemetry.shared import shared_messaging_attributes

P = ParamSpec("P")
EnqueueJobType = Callable[P, Awaitable[Optional[Job]]]

meter = metrics.get_meter(__name__)


def span_name_producer(*enqueue_args: Any, **enqueue_kwargs: Any) -> str:
    """Get the span name to use when enqueuing a job. The name follows the conventions in
    https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/messaging.md#conventions
    and is derived from the name of the job passed as an argument to enqueue_job().
    """
    job_name = enqueue_kwargs.get("function", enqueue_args[0] if enqueue_args else "unknown")
    return f"{job_name} send"


def get_producer_attributes(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Get attributes that apply when enqueuing a job."""
    attributes = {}
    attributes["span.name"] = span_name_producer(*args, **kwargs)
attributes["messaging.operation"] = "publish"
if "_queue_name" in kwargs:
attributes["messaging.destination.name"] = kwargs["_queue_name"]

return attributes


def get_message_attributes(*args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Get attributes specific to a message when enqueuing a job."""
attributes = {}
if "_job_id" in kwargs:
attributes["messaging.message.id"] = kwargs["_job_id"]

return attributes
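
With hypothetical values, the two helpers yield attribute dicts like:

get_producer_attributes('send_welcome_email', _queue_name='arq:queue')
# {'span.name': 'send_welcome_email send',
#  'messaging.operation': 'publish',
#  'messaging.destination.name': 'arq:queue'}

get_message_attributes(_job_id='abc123')
# {'messaging.message.id': 'abc123'}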


TRACE_CONTEXT_PARAM_NAME = "_trace_context"

job_counter = meter.create_counter(
name="arq.jobs.counter",
unit="1",
description="Counts the number of jobs created with Arq.",
)


def get_wrap_enqueue_job(redis_host: str, redis_port: int) -> EnqueueJobType:
shared_attributes = shared_messaging_attributes(redis_host, redis_port)

async def wrap_enqueue_job(
enqueue_job_func: EnqueueJobType,
*args: P.args,
**kwargs: P.kwargs,
) -> Optional[Job]:
"""Add an extra parameter into the job we're enqueueing, which holds
trace context information."""
token = context.attach(context.get_current())

        attributes = get_producer_attributes(*args, **kwargs)
        attributes.update(shared_attributes)
        tracer = trace.get_tracer(__name__)
        span = tracer.start_span(
            span_name_producer(*args, **kwargs),
            kind=trace.SpanKind.PRODUCER,
            start_time=time_ns(),
        )
        if span.is_recording():
            message_attributes = get_message_attributes(*args, **kwargs)
            message_attributes.update(attributes)
            span.set_attributes(message_attributes)

job_counter.add(1, attributes=attributes)

# Inject our context into the job definition.
kwargs.setdefault(TRACE_CONTEXT_PARAM_NAME, {})
inject(carrier=kwargs[TRACE_CONTEXT_PARAM_NAME])

        with trace.use_span(span, end_on_exit=True):
            try:
                result = await enqueue_job_func(*args, **kwargs)
                # If enqueue_job() returned an arq.jobs.Job, record its job ID
                # as the message ID on the span.
                if isinstance(result, Job):
                    span.set_attribute("messaging.message.id", result.job_id)
            finally:
                # Detach the context even if enqueuing failed.
                context.detach(token)
        return result

return wrap_enqueue_job


class InstrumentedArqRedis(ArqRedis):
"""InstrumentedArqRedis is an ArqRedis instance that adds tracing
information to the jobs it enqueues.
"""

wrapper: Callable[..., Awaitable[Optional[Job]]]

def __init__(
self,
arq_redis: ArqRedis,
) -> None:
connection_kwargs = cast(
ConnectKwargs, arq_redis.connection_pool.connection_kwargs
)
self.wrapper = get_wrap_enqueue_job(
connection_kwargs.get("host", "localhost"),
connection_kwargs.get("port", 6379),
)
super().__init__(
connection_pool=arq_redis.connection_pool,
job_serializer=arq_redis.job_serializer,
job_deserializer=arq_redis.job_deserializer,
default_queue_name=arq_redis.default_queue_name,
expires_extra_ms=arq_redis.expires_extra_ms,
)

async def enqueue_job(
self,
function: str,
*args: Any,
_job_id: Optional[str] = None,
_queue_name: Optional[str] = None,
_defer_until: Optional[datetime] = None,
_defer_by: Union[None, int, float, timedelta] = None,
_expires: Union[None, int, float, timedelta] = None,
_job_try: Optional[int] = None,
**kwargs: Any,
) -> Optional[Job]:
# Allow _queue_name to be included in trace.
if _queue_name is None:
_queue_name = self.default_queue_name
_job_id = _job_id or uuid4().hex

return await self.wrapper(
super().enqueue_job,
function,
*args,
_job_id=_job_id,
_queue_name=_queue_name,
_defer_until=_defer_until,
_defer_by=_defer_by,
_expires=_expires,
_job_try=_job_try,
**kwargs,
)
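
A producer-side sketch (pool construction as in the __init__.py example above; the job name is hypothetical). Because this override fills in _job_id and _queue_name before delegating to the wrapper, the PRODUCER span always carries messaging.message.id and messaging.destination.name:

pool = await create_pool(RedisSettings())
redis = InstrumentedArqRedis(pool)
job = await redis.enqueue_job('send_welcome_email', user_id=42)
# arq returns None when a job with the same ID already exists.
print(job.job_id if job is not None else 'job not enqueued')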