diff --git a/arq/opentelemetry/__init__.py b/arq/opentelemetry/__init__.py
new file mode 100644
index 00000000..d33f673c
--- /dev/null
+++ b/arq/opentelemetry/__init__.py
@@ -0,0 +1,7 @@
+"""Observability utilities related to the `arq` library."""
+
+from arq.opentelemetry.consume import instrument_job
+from arq.opentelemetry.produce import InstrumentedArqRedis
+from arq.opentelemetry.propagator import ArqJobTextMapPropagator
+
+__all__ = ["instrument_job", "InstrumentedArqRedis", "ArqJobTextMapPropagator"]
diff --git a/arq/opentelemetry/consume.py b/arq/opentelemetry/consume.py
new file mode 100644
index 00000000..9f4e4950
--- /dev/null
+++ b/arq/opentelemetry/consume.py
@@ -0,0 +1,107 @@
+"""Module that instruments tracing on the job coroutines run by the worker."""
+
+from functools import wraps
+from time import time_ns
+from typing import Any, Awaitable, Callable, Dict, List, ParamSpec, TypeVar
+
+from opentelemetry import context, metrics, trace
+from opentelemetry.propagate import extract
+
+from arq.worker import Retry
+
+meter = metrics.get_meter(__name__)
+job_counter = meter.create_counter(
+    name="arq.jobs.counter",
+    unit="1",
+    description="Counts the number of jobs created with Arq.",
+)
+
+
+def span_name_consumer(func: Callable) -> str:
+    """Get the span name to use when running a job. Name is based on conventions in
+    https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/messaging.md#conventions
+
+    There's a difference between 'receive' and 'process', which can be elided when they happen
+    close together but is relevant for batch processing. For now, we just use "process".
+    """
+    return f"{func.__qualname__} process"
+
+
+def get_consumer_attributes(
+    _func: Callable[..., Awaitable[Any]], _args: List[Any], kwargs: Dict[str, Any]
+) -> Dict[str, Any]:
+    """Get attributes that apply when running a job."""
+    attributes = {}
+    attributes["span.name"] = span_name_consumer(_func)
+    attributes["messaging.operation"] = "process"
+    if "_queue_name" in kwargs:
+        attributes["messaging.source.name"] = kwargs["_queue_name"]
+
+    return attributes
+
+
+TRACE_CONTEXT_PARAM_NAME = "_trace_context"
+JobArgs = ParamSpec("JobArgs")
+JobReturnType = TypeVar("JobReturnType")
+
+
+def instrument_job(
+    wrapped: Callable[JobArgs, Awaitable[JobReturnType]]
+) -> Callable[JobArgs, Awaitable[JobReturnType]]:
+    """Decorate a job definition such that it can be traced."""
+
+    @wraps(wrapped)
+    async def wrapper(*args: JobArgs.args, **kwargs: JobArgs.kwargs) -> Any:
+        if TRACE_CONTEXT_PARAM_NAME in kwargs:
+            # The .extract() method gets us the trace data...
+            token = context.attach(
+                extract(
+                    # IMPORTANT! Manually remove TRACE_CONTEXT_PARAM_NAME
+                    # to prevent it from being passed as an argument to the job.
+                    carrier=kwargs.pop(TRACE_CONTEXT_PARAM_NAME, {}),
+                )
+            )
+        else:
+            # We're running a job coroutine but not as a job - perhaps it was awaited
+            # by something else. So we have no trace data in its kwargs.
+            #
+            # In that case, we'll stick with the current context.
+            token = None
+
+        tracer = trace.get_tracer(__name__)
+        span = tracer.start_span(
+            span_name_consumer(wrapped),
+            kind=trace.SpanKind.CONSUMER,
+            start_time=time_ns(),
+        )
+        if span.is_recording():
+            attributes = get_consumer_attributes(wrapped, list(args), kwargs)
+            for key, value in attributes.items():
+                span.set_attribute(key, value)
+
+        try:
+            with trace.use_span(
+                span, end_on_exit=True, set_status_on_exception=False
+            ) as span:
+                try:
+                    result = await wrapped(*args, **kwargs)
+                except Retry as exc:
+                    # A Retry is part of normal operation, not a failure.
+                    span.set_status(trace.Status(trace.StatusCode.OK))
+                    raise exc
+                except Exception as exc:
+                    span.set_status(
+                        trace.Status(
+                            trace.StatusCode.ERROR,
+                            description=f"{type(exc).__name__}: {exc}",
+                        )
+                    )
+                    raise exc
+        finally:
+            # Restore the previous context even if the job raised.
+            if token is not None:
+                context.detach(token)
+        return result
+
+    return wrapper
diff --git a/arq/opentelemetry/produce.py b/arq/opentelemetry/produce.py
new file mode 100644
index 00000000..f5de9761
--- /dev/null
+++ b/arq/opentelemetry/produce.py
@@ -0,0 +1,170 @@
+"""Code to add trace information when producing (enqueuing) jobs."""
+
+from datetime import datetime, timedelta
+from time import time_ns
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Optional,
+    ParamSpec,
+    Union,
+    cast,
+)
+from uuid import uuid4
+
+from opentelemetry import context, metrics, trace
+from opentelemetry.propagate import inject
+from redis.asyncio.connection import ConnectKwargs
+
+from arq import ArqRedis
+from arq.jobs import Job
+from arq.opentelemetry.shared import shared_messaging_attributes
+
+P = ParamSpec("P")
+EnqueueJobType = Callable[P, Awaitable[Optional[Job]]]
+
+meter = metrics.get_meter(__name__)
+
+
+def span_name_producer(*enqueue_args: Any, **enqueue_kwargs: Any) -> str:
+    """Get the span name to use when enqueuing a job. Name is based on conventions in
+    https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/messaging.md#conventions
+    and is derived from the name of the job that was passed as an argument to enqueue_job().
+    """
+    job_name = enqueue_kwargs.get("function", enqueue_args[0] if enqueue_args else "")
+    return f"{job_name} send"
+
+
+def get_producer_attributes(*args: Any, **kwargs: Any) -> Dict[str, Any]:
+    """Get attributes that apply when enqueuing a job."""
+    attributes = {}
+    attributes["span.name"] = span_name_producer(*args, **kwargs)
+    attributes["messaging.operation"] = "publish"
+    if "_queue_name" in kwargs:
+        attributes["messaging.destination.name"] = kwargs["_queue_name"]
+
+    return attributes
+
+
+def get_message_attributes(*args: Any, **kwargs: Any) -> Dict[str, Any]:
+    """Get attributes specific to a message when enqueuing a job."""
+    attributes = {}
+    if "_job_id" in kwargs:
+        attributes["messaging.message.id"] = kwargs["_job_id"]
+
+    return attributes
+
+
+TRACE_CONTEXT_PARAM_NAME = "_trace_context"
+
+job_counter = meter.create_counter(
+    name="arq.jobs.counter",
+    unit="1",
+    description="Counts the number of jobs created with Arq.",
+)
+
+
+def get_wrap_enqueue_job(redis_host: str, redis_port: int) -> EnqueueJobType:
+    shared_attributes = shared_messaging_attributes(redis_host, redis_port)
+
+    async def wrap_enqueue_job(
+        enqueue_job_func: EnqueueJobType,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Optional[Job]:
+        """Add an extra parameter into the job we're enqueueing, which holds
+        trace context information."""
+        token = context.attach(context.get_current())
+
+        attributes = get_producer_attributes(*args, **kwargs)
+        attributes.update(shared_attributes)
+        tracer = trace.get_tracer(__name__)
+        span = tracer.start_span(
+            span_name_producer(*args, **kwargs),
+            kind=trace.SpanKind.PRODUCER,
+            start_time=time_ns(),
+        )
+        if span.is_recording():
+            message_attributes = get_message_attributes(*args, **kwargs)
+            message_attributes.update(attributes)
+            span.set_attributes(message_attributes)
+
+        job_counter.add(1, attributes=attributes)
+
+        with trace.use_span(span, end_on_exit=True):
+            # Inject the producer span's context into the job definition so the
+            # consumer side can continue the trace.
+            kwargs.setdefault(TRACE_CONTEXT_PARAM_NAME, {})
+            inject(carrier=kwargs[TRACE_CONTEXT_PARAM_NAME])
+
+            result = await enqueue_job_func(*args, **kwargs)
+            # If we were given an arq.jobs.Job instance as a result, put its job ID into our
+            # span attributes.
+            if result is not None and isinstance(result, Job):
+                span.set_attribute("messaging.message.id", result.job_id)
+
+        context.detach(token)
+        return result
+
+    return wrap_enqueue_job
+
+
+class InstrumentedArqRedis(ArqRedis):
+    """InstrumentedArqRedis is an ArqRedis instance that adds tracing
+    information to the jobs it enqueues.
+    """
+
+    wrapper: Callable[..., Awaitable[Optional[Job]]]
+
+    def __init__(
+        self,
+        arq_redis: ArqRedis,
+    ) -> None:
+        connection_kwargs = cast(
+            ConnectKwargs, arq_redis.connection_pool.connection_kwargs
+        )
+        self.wrapper = get_wrap_enqueue_job(
+            connection_kwargs.get("host", "localhost"),
+            connection_kwargs.get("port", 6379),
+        )
+        super().__init__(
+            connection_pool=arq_redis.connection_pool,
+            job_serializer=arq_redis.job_serializer,
+            job_deserializer=arq_redis.job_deserializer,
+            default_queue_name=arq_redis.default_queue_name,
+            expires_extra_ms=arq_redis.expires_extra_ms,
+        )
+
+    async def enqueue_job(
+        self,
+        function: str,
+        *args: Any,
+        _job_id: Optional[str] = None,
+        _queue_name: Optional[str] = None,
+        _defer_until: Optional[datetime] = None,
+        _defer_by: Union[None, int, float, timedelta] = None,
+        _expires: Union[None, int, float, timedelta] = None,
+        _job_try: Optional[int] = None,
+        **kwargs: Any,
+    ) -> Optional[Job]:
+        # Fill in defaults here so that _queue_name and _job_id can be included in the trace.
+        if _queue_name is None:
+            _queue_name = self.default_queue_name
+        _job_id = _job_id or uuid4().hex
+
+        return await self.wrapper(
+            super().enqueue_job,
+            function,
+            *args,
+            _job_id=_job_id,
+            _queue_name=_queue_name,
+            _defer_until=_defer_until,
+            _defer_by=_defer_by,
+            _expires=_expires,
+            _job_try=_job_try,
+            **kwargs,
+        )
diff --git a/arq/opentelemetry/propagator.py b/arq/opentelemetry/propagator.py
new file mode 100644
index 00000000..88efc0cc
--- /dev/null
+++ b/arq/opentelemetry/propagator.py
@@ -0,0 +1,119 @@
+"""A TextMapPropagator implementation for tracing between Arq jobs."""
+
+from typing import Dict, List, Optional, Set
+
+from opentelemetry import trace
+from opentelemetry.context import Context
+from opentelemetry.propagators.textmap import (
+    CarrierT,
+    Getter,
+    Setter,
+    TextMapPropagator,
+    default_getter,
+    default_setter,
+)
+from opentelemetry.trace import INVALID_SPAN_ID, INVALID_TRACE_ID, TraceFlags
+
+JobCarrierType = Dict[str, str]
+ARQ_TRACE_ID_HEADER = "arq-tracer-trace-id"
+ARQ_SPAN_ID_HEADER = "arq-tracer-span-id"
+ARQ_TRACE_FLAGS_HEADER = "arq-tracer-trace-flags"
+
+
+class ArqJobTextMapPropagator(TextMapPropagator):
+    """Converts between the SpanContext objects that OpenTelemetry uses and the format in
+    which that context is passed in job parameters (in an extra _trace_context parameter, as a dictionary).
+
+    The actual writing to / reading from job parameters is done by other code.
+
+    The span context parameters are:
+    * trace_id (a hex-encoded integer)
+    * span_id (a hex-encoded integer)
+    * trace_flags (an integer that represents a bitmask) - has a default of 0
+
+    It should meet the OpenTelemetry Propagators API specification.
+    https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/context/api-propagators.md
+    """
+
+    def extract(
+        self,
+        carrier: CarrierT,
+        context: Optional[Context] = None,
+        getter: Getter[CarrierT] = default_getter,
+    ) -> Context:
+        """Extracts SpanContext from the carrier.
+
+        For this TextMapPropagator, the carrier is the parameters of an arq job."""
+        if context is None:
+            context = Context()
+
+        trace_id: int = _extract_identifier(
+            getter.get(carrier, ARQ_TRACE_ID_HEADER), INVALID_TRACE_ID
+        )
+        span_id: int = _extract_identifier(
+            getter.get(carrier, ARQ_SPAN_ID_HEADER), INVALID_SPAN_ID
+        )
+        trace_flags_str: Optional[str] = _extract_first_element(
+            getter.get(carrier, ARQ_TRACE_FLAGS_HEADER)
+        )
+
+        if trace_id == INVALID_TRACE_ID or span_id == INVALID_SPAN_ID:
+            return context
+
+        if trace_flags_str is None:
+            trace_flags = TraceFlags.DEFAULT
+        else:
+            trace_flags = int(trace_flags_str)
+
+        span_context = trace.SpanContext(
+            trace_id=trace_id,
+            span_id=span_id,
+            is_remote=True,
+            trace_flags=trace.TraceFlags(trace_flags),
+        )
+        return trace.set_span_in_context(trace.NonRecordingSpan(span_context), context)
+
+    def inject(
+        self,
+        carrier: CarrierT,
+        context: Optional[Context] = None,
+        setter: Setter[CarrierT] = default_setter,
+    ) -> None:
+        """Injects SpanContext into the carrier.
+
+        For this TextMapPropagator, the carrier is the parameters of an arq job."""
+        span = trace.get_current_span(context)
+        span_context = span.get_span_context()
+        if span_context == trace.INVALID_SPAN_CONTEXT:
+            return
+
+        setter.set(carrier, ARQ_TRACE_ID_HEADER, hex(span_context.trace_id))
+        setter.set(carrier, ARQ_SPAN_ID_HEADER, hex(span_context.span_id))
+        setter.set(carrier, ARQ_TRACE_FLAGS_HEADER, str(span_context.trace_flags))
+
+    @property
+    def fields(self) -> Set[str]:
+        return {
+            ARQ_TRACE_ID_HEADER,
+            ARQ_SPAN_ID_HEADER,
+            ARQ_TRACE_FLAGS_HEADER,
+        }
+
+
+def _extract_first_element(
+    items: Optional[List[str]], default: Optional[str] = None
+) -> Optional[str]:
+    if items is None:
+        return default
+    return next(iter(items), default)
+
+
+def _extract_identifier(items: Optional[List[str]], default: int) -> int:
+    header = _extract_first_element(items)
+    if header is None:
+        return default
+
+    try:
+        return int(header, 16)
+    except ValueError:
+        return default
diff --git a/arq/opentelemetry/shared.py b/arq/opentelemetry/shared.py
new file mode 100644
index 00000000..09291bfc
--- /dev/null
+++ b/arq/opentelemetry/shared.py
@@ -0,0 +1,18 @@
+"""Helpers shared between producers and consumers."""
+
+from typing import Dict
+
+
+def shared_messaging_attributes(
+    redis_host: str, redis_port: int
+) -> Dict[str, str | int]:
+    """Get semantic attributes that apply to all spans."""
+    # TODO: I think resource attributes should be included by the tracer already
+    # but I need to check
+    return {
+        "server.address": redis_host,
+        "server.port": redis_port,
+        "messaging.system": "redis",
+        "messaging.destination.kind": "queue",
+        "messaging.protocol": "RESP",
+    }
diff --git a/pyproject.toml b/pyproject.toml
index 79fc137d..cccb6619 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,7 @@ dependencies = [
     'redis[hiredis]>=4.2.0,<5',
     'click>=8.0',
 ]
-optional-dependencies = {watch = ['watchfiles>=0.16'] }
+optional-dependencies = {watch = ['watchfiles>=0.16'], opentelemetry = ['opentelemetry-sdk>=1.20'] }
 dynamic = ['version']
 
 [project.scripts]
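
Usage notes (not part of the patch). A minimal producer-side sketch, assuming the new extra is installed (presumably via pip install 'arq[opentelemetry]') and that a job named "download_content" exists; the host, port, and URL are placeholders. Wrapping the pool in InstrumentedArqRedis makes enqueue_job() emit a PRODUCER span and copy the active trace context into the job's _trace_context keyword argument.

    import asyncio

    from arq import create_pool
    from arq.connections import RedisSettings
    from arq.opentelemetry import InstrumentedArqRedis


    async def main() -> None:
        # Plain arq pool; host/port here are illustrative defaults.
        pool = await create_pool(RedisSettings(host="localhost", port=6379))
        redis = InstrumentedArqRedis(pool)
        # Emits a "download_content send" span and injects _trace_context
        # into the job's keyword arguments before handing off to arq.
        await redis.enqueue_job("download_content", "https://example.com")


    asyncio.run(main())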
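
On the worker side, a sketch of how a job coroutine might be decorated with instrument_job so the context extracted from _trace_context becomes the parent of the CONSUMER span; the job body and WorkerSettings below are illustrative only.

    from typing import Any, Dict

    from arq.connections import RedisSettings
    from arq.opentelemetry import instrument_job


    @instrument_job
    async def download_content(ctx: Dict[str, Any], url: str) -> int:
        # Runs inside a "download_content process" span that continues the
        # trace started by the producer's "download_content send" span.
        return len(url)


    class WorkerSettings:
        functions = [download_content]
        redis_settings = RedisSettings()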
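
ArqJobTextMapPropagator is not registered by the patch itself (produce/consume go through the globally configured propagator), so the sketch below simply drives it directly against a dict carrier to show the arq-tracer-* keys it reads and writes. It assumes an SDK TracerProvider has been configured; with the default no-op tracer the span context is invalid and inject() writes nothing.

    from typing import Dict

    from opentelemetry import trace
    from arq.opentelemetry import ArqJobTextMapPropagator

    propagator = ArqJobTextMapPropagator()
    tracer = trace.get_tracer(__name__)

    with tracer.start_as_current_span("enqueue"):
        carrier: Dict[str, str] = {}
        propagator.inject(carrier)
        # carrier now holds hex-encoded "arq-tracer-trace-id" and
        # "arq-tracer-span-id" values plus "arq-tracer-trace-flags",
        # ready to travel inside the job's _trace_context parameter.
        restored_context = propagator.extract(carrier)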