Skip to content

Commit

Permalink
Allow concurrency limits to handle changing the default limit (#27112)
Browse files Browse the repository at this point in the history
## Summary & Motivation
This PR adds a default flag to the storage layer, to help track which
limits inherit from the default versus limits that are explicitly set.
This allows us to change the default and helps us to change those limit
values.

Currently, the `concurrency > default_op_concurrency_limit` setting is
only used to initialize the number of slots for "unconfigured" keys.
Now, we keep track of the slots that are initialized from a default
value and update them if the default value has changed.

## How I Tested These Changes
BK

## Changelog
- Adds the ability to distinguish between explicitly set pool limits and
default-set pool limits. Requires a schema migration using `dagster
instance migrate`.
  • Loading branch information
prha authored Jan 23, 2025
1 parent 1a8e8f2 commit 276b168
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import time
from collections import defaultdict
from types import TracebackType
from typing import Optional
from typing import TYPE_CHECKING, Optional

from typing_extensions import Self

from dagster._core.instance import DagsterInstance
from dagster._core.storage.dagster_run import DagsterRun
from dagster._core.storage.tags import PRIORITY_TAG

if TYPE_CHECKING:
from dagster._core.storage.event_log.base import PoolLimit

INITIAL_INTERVAL_VALUE = 1
STEP_UP_BASE = 1.1
MAX_CONCURRENCY_CLAIM_BLOCKED_INTERVAL = 15
Expand All @@ -33,10 +36,11 @@ class InstanceConcurrencyContext:
def __init__(self, instance: DagsterInstance, dagster_run: DagsterRun):
self._instance = instance
self._run_id = dagster_run.run_id
self._global_concurrency_keys = None
self._pools: Optional[dict[str, PoolLimit]] = None
self._pending_timeouts = defaultdict(float)
self._pending_claim_counts = defaultdict(int)
self._pending_claims = set()
self._default_limit = instance.global_op_concurrency_default_limit
self._claims = set()
try:
self._run_priority = int(dagster_run.tags.get(PRIORITY_TAG, "0"))
Expand Down Expand Up @@ -68,38 +72,39 @@ def __exit__(

self._context_guard = False

@property
def global_concurrency_keys(self) -> set[str]:
# lazily load the global concurrency keys, to avoid the DB fetch for plans that do not
# have global concurrency limited keys
if self._global_concurrency_keys is None:
if not self._instance.event_log_storage.supports_global_concurrency_limits:
self._global_concurrency_keys = set()
else:
self._global_concurrency_keys = (
self._instance.event_log_storage.get_concurrency_keys()
)
def get_pool_info(self, pool_name: str) -> Optional["PoolLimit"]:
if not self._instance.event_log_storage.supports_global_concurrency_limits:
return None

if self._pools is None:
self._sync_pools()

return self._global_concurrency_keys
assert self._pools is not None
return self._pools.get(pool_name)

def _sync_global_concurrency_keys(self) -> None:
self._global_concurrency_keys = self._instance.event_log_storage.get_concurrency_keys()
def _sync_pools(self) -> None:
pool_limits = self._instance.event_log_storage.get_pool_limits()
self._pools = {pool.name: pool for pool in pool_limits}

def claim(self, concurrency_key: str, step_key: str, step_priority: int = 0):
if not self._instance.event_log_storage.supports_global_concurrency_limits:
return True

if concurrency_key not in self.global_concurrency_keys:
# The initialization call will be a no-op if the limit is set by another process,
# mitigating any race condition concerns
if not self._instance.event_log_storage.initialize_concurrency_limit_to_default(
pool_info = self.get_pool_info(concurrency_key)
if (pool_info is None and self._default_limit is not None) or (
pool_info is not None
and pool_info.from_default
and pool_info.limit != self._default_limit
):
self._instance.event_log_storage.initialize_concurrency_limit_to_default(
concurrency_key
):
# still default open if the limit table has not been initialized
return True
else:
# sync the global concurrency keys to ensure we have the latest
self._sync_global_concurrency_keys()
)
self._sync_pools()
# refetch the pool info
pool_info = self.get_pool_info(concurrency_key)

if pool_info is None:
return True

if step_key in self._pending_claims:
if time.time() > self._pending_timeouts[step_key]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,19 @@ def __init__(
os.getenv("DAGSTER_OP_CONCURRENCY_KEYS_ALLOTTED_FOR_STARTED_RUN_SECONDS", "5")
)

queued_pool_names = self._get_queued_pool_names(runs)
# initialize all the pool limits to the default if necessary
self._initialize_pool_limits(instance, queued_pool_names)

# fetch all the concurrency info for all of the runs at once, so we can claim in the correct
# priority order
self._fetch_concurrency_info(instance, runs)
self._fetch_concurrency_info(instance, queued_pool_names)

# fetch all the outstanding pools for in-progress runs
self._process_in_progress_runs(in_progress_run_records)

def _fetch_concurrency_info(self, instance: DagsterInstance, queued_runs: Sequence[DagsterRun]):
# fetch all the concurrency slot information for all the queued runs
all_pools = set()

configured_pools = instance.event_log_storage.get_concurrency_keys()
def _get_queued_pool_names(self, queued_runs: Sequence[DagsterRun]) -> set[str]:
queued_pool_names = set()
for run in queued_runs:
if run.run_op_concurrency:
# if using run granularity, consider all the concurrency keys required by the run
Expand All @@ -96,17 +97,32 @@ def _fetch_concurrency_info(self, instance: DagsterInstance, queued_runs: Sequen
if self._pool_granularity == PoolGranularity.OP
else run.run_op_concurrency.all_pools or []
)
all_pools.update(run_pools)

for pool in all_pools:
if pool is None:
queued_pool_names.update(run_pools)
return queued_pool_names

def _initialize_pool_limits(self, instance: DagsterInstance, pool_names: set[str]):
default_limit = instance.global_op_concurrency_default_limit
pool_limits_by_name = {
pool.name: pool for pool in instance.event_log_storage.get_pool_limits()
}
for pool_name in pool_names:
if pool_name is None:
continue

if pool not in configured_pools:
instance.event_log_storage.initialize_concurrency_limit_to_default(pool)
if (pool_name not in pool_limits_by_name and default_limit) or (
pool_name in pool_limits_by_name
and pool_limits_by_name[pool_name].from_default
and pool_limits_by_name[pool_name].limit != default_limit
):
instance.event_log_storage.initialize_concurrency_limit_to_default(pool_name)

def _fetch_concurrency_info(self, instance: DagsterInstance, pool_names: set[str]):
for pool_name in pool_names:
if pool_name is None:
continue

self._concurrency_info_by_key[pool] = instance.event_log_storage.get_concurrency_info(
pool
self._concurrency_info_by_key[pool_name] = (
instance.event_log_storage.get_concurrency_info(pool_name)
)

def _should_allocate_slots_for_in_progress_run(self, record: RunRecord):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""add column concurrency default limit
Revision ID: 7e2f3204cf8e
Revises: 6b7fb194ff9c
Create Date: 2025-01-13 15:19:19.331752
"""

import sqlalchemy as sa
from alembic import op
from dagster._core.storage.migration.utils import has_column, has_table

# revision identifiers, used by Alembic.
revision = "7e2f3204cf8e"
down_revision = "6b7fb194ff9c"
branch_labels = None
depends_on = None


def upgrade():
if has_table("concurrency_limits"):
if not has_column("concurrency_limits", "using_default_limit"):
op.add_column(
"concurrency_limits",
sa.Column("using_default_limit", sa.Boolean(), nullable=False, default=False),
)


def downgrade():
if has_table("concurrency_limits"):
if has_column("concurrency_limits", "using_default_limit"):
op.drop_column("concurrency_limits", "using_default_limit")
13 changes: 13 additions & 0 deletions python_modules/dagster/dagster/_core/storage/event_log/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from dagster._core.storage.partition_status_cache import get_and_update_asset_status_cache_value
from dagster._core.storage.sql import AlembicVersion
from dagster._core.storage.tags import MULTIDIMENSIONAL_PARTITION_PREFIX
from dagster._record import record
from dagster._utils import PrintFn
from dagster._utils.concurrency import ConcurrencyClaimStatus, ConcurrencyKeyInfo
from dagster._utils.warnings import deprecation_warning
Expand Down Expand Up @@ -185,6 +186,13 @@ class PlannedMaterializationInfo(NamedTuple):
run_id: str


@record
class PoolLimit:
name: str
limit: int
from_default: bool


class EventLogStorage(ABC, MayHaveInstanceWeakref[T_DagsterInstance]):
"""Abstract base class for storing structured event logs from pipeline runs.
Expand Down Expand Up @@ -560,6 +568,11 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
"""Get concurrency info for key."""
raise NotImplementedError()

@abstractmethod
def get_pool_limits(self) -> Sequence[PoolLimit]:
"""Get the set of concurrency limited keys and limits."""
raise NotImplementedError()

@abstractmethod
def claim_concurrency_slot(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
),
db.Column("concurrency_key", MySQLCompatabilityTypes.UniqueText, nullable=False, unique=True),
db.Column("limit", db.Integer, nullable=False),
db.Column("using_default_limit", db.Boolean, nullable=False, default=False),
db.Column("update_timestamp", db.DateTime, server_default=get_sql_current_timestamp()),
db.Column("create_timestamp", db.DateTime, server_default=get_sql_current_timestamp()),
)
Expand Down
Loading

0 comments on commit 276b168

Please sign in to comment.