Skip to content

Commit

Permalink
separate configured limit from slot count (#26803)
Browse files Browse the repository at this point in the history
## Summary & Motivation
Currently, the event log storage returns a `ConcurrencyKeyInfo` to
represent the information about a particular pool. This info includes
the number of slots available, but it doesn't have enough information to
distinguish between the following cases:

1. no pool limit is explicitly set and there is no default value
2. no pool limit is explicitly set but there is a default value
3. a pool limit is explicitly set to 0

This PR adds more information to the returned object to help distinguish
between cases 1 and 2. Specifically, there's a new nullable `limit`
value, which returns the configured limit regardless of whether the pool
has been initialized to create the corresponding slot rows. A limit
value of `None` indicates an unconfigured pool, with no default limit.

## How I Tested These Changes
BK
  • Loading branch information
prha authored Jan 23, 2025
1 parent 276b168 commit 92ffc2d
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 66 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions js_modules/dagster-ui/packages/ui-core/src/graphql/types.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class GrapheneConcurrencyKeyInfo(graphene.ObjectType):
pendingStepRunIds = non_null_list(graphene.String)
assignedStepCount = graphene.NonNull(graphene.Int)
assignedStepRunIds = non_null_list(graphene.String)
limit = graphene.Int()
usingDefaultLimit = graphene.Boolean()

class Meta:
name = "ConcurrencyKeyInfo"
Expand Down Expand Up @@ -193,6 +195,12 @@ def resolve_assignedStepCount(self, graphene_info: ResolveInfo):
def resolve_assignedStepRunIds(self, graphene_info: ResolveInfo):
return list(self._get_concurrency_key_info(graphene_info).assigned_run_ids)

def resolve_limit(self, graphene_info: ResolveInfo):
return self._get_concurrency_key_info(graphene_info).limit

def resolve_usingDefaultLimit(self, graphene_info: ResolveInfo):
return self._get_concurrency_key_info(graphene_info).using_default_limit


class GrapheneRunQueueConfig(graphene.ObjectType):
maxConcurrentRuns = graphene.NonNull(graphene.Int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,35 @@ def _sqlite_asset_instance():

return MarkedManager(_sqlite_asset_instance, [Marks.asset_aware_instance])

@staticmethod
def default_concurrency_sqlite_instance():
@contextmanager
def _sqlite_with_default_concurrency_instance():
with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test(
temp_dir=temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
},
"run_coordinator": {
"module": "dagster._core.run_coordinator.queued_run_coordinator",
"class": "QueuedRunCoordinator",
},
"concurrency": {
"default_op_concurrency_limit": 1,
},
},
) as instance:
yield instance

return MarkedManager(
_sqlite_with_default_concurrency_instance,
[Marks.sqlite_instance, Marks.queued_run_coordinator],
)


class EnvironmentManagers:
@staticmethod
Expand Down Expand Up @@ -556,6 +585,16 @@ def sqlite_with_default_run_launcher_code_server_cli_env(
test_id="sqlite_with_default_run_launcher_code_server_cli_env",
)

@staticmethod
def sqlite_with_default_concurrency_managed_grpc_env(
target=None, location_name="test_location"
):
return GraphQLContextVariant(
InstanceManagers.default_concurrency_sqlite_instance(),
EnvironmentManagers.managed_grpc(target, location_name),
test_id="sqlite_with_default_concurrency_managed_grpc_env",
)

@staticmethod
def postgres_with_default_run_launcher_managed_grpc_env(
target=None, location_name="test_location"
Expand Down Expand Up @@ -662,6 +701,7 @@ def all_variants():
GraphQLContextVariant.non_launchable_postgres_instance_managed_grpc_env(),
GraphQLContextVariant.non_launchable_postgres_instance_lazy_repository(),
GraphQLContextVariant.consolidated_sqlite_instance_managed_grpc_env(),
GraphQLContextVariant.sqlite_with_default_concurrency_managed_grpc_env(),
]

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,34 @@
"""

GET_CONCURRENCY_LIMITS_QUERY = """
query InstanceConcurrencyLimitsQuery {
query InstanceConcurrencyLimitsQuery($concurrencyKey: String!) {
instance {
concurrencyLimit(concurrencyKey: $concurrencyKey) {
concurrencyKey
slotCount
activeSlotCount
activeRunIds
claimedSlots {
runId
stepKey
}
pendingSteps {
runId
stepKey
enqueuedTimestamp
assignedTimestamp
priority
}
limit
usingDefaultLimit
}
}
}
"""

ALL_CONCURRENCY_LIMITS_QUERY = """
query AllConcurrencyLimitsQuery {
instance {
concurrencyLimits {
concurrencyKey
Expand All @@ -37,6 +64,8 @@
assignedTimestamp
priority
}
limit
usingDefaultLimit
}
}
}
Expand Down Expand Up @@ -67,6 +96,40 @@
)


def fetch_concurrency_limit(graphql_context, key: str):
results = execute_dagster_graphql(
graphql_context,
GET_CONCURRENCY_LIMITS_QUERY,
{"concurrencyKey": key},
)
assert results.data
assert "instance" in results.data
assert "concurrencyLimit" in results.data["instance"]
return results.data["instance"]["concurrencyLimit"]


def set_concurrency_limit(graphql_context, key: str, limit: int):
execute_dagster_graphql(
graphql_context,
SET_CONCURRENCY_LIMITS_MUTATION,
variables={
"concurrencyKey": key,
"limit": limit,
},
)


def fetch_all_concurrency_limits(graphql_context):
results = execute_dagster_graphql(
graphql_context,
ALL_CONCURRENCY_LIMITS_QUERY,
)
assert results.data
assert "instance" in results.data
assert "concurrencyLimits" in results.data["instance"]
return [limit for limit in results.data["instance"]["concurrencyLimits"]]


class TestInstanceSettings(BaseTestSuite):
def test_instance_settings(self, graphql_context):
results = execute_dagster_graphql(graphql_context, INSTANCE_QUERY)
Expand All @@ -81,78 +144,56 @@ def test_instance_settings(self, graphql_context):
def test_concurrency_limits(self, graphql_context):
instance = graphql_context.instance

def _fetch_limits(key: str):
results = execute_dagster_graphql(
graphql_context,
GET_CONCURRENCY_LIMITS_QUERY,
)
assert results.data
assert "instance" in results.data
assert "concurrencyLimits" in results.data["instance"]
limit_info = results.data["instance"]["concurrencyLimits"]
return next(iter([info for info in limit_info if info["concurrencyKey"] == key]), None)

def _set_limits(key: str, limit: int):
execute_dagster_graphql(
graphql_context,
SET_CONCURRENCY_LIMITS_MUTATION,
variables={
"concurrencyKey": key,
"limit": limit,
},
)

# default limits are empty
assert _fetch_limits("foo") is None
all_limits = fetch_all_concurrency_limits(graphql_context)
assert len(all_limits) == 0

# set a limit
_set_limits("foo", 10)
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 10 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 0 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"] == [] # pyright: ignore[reportOptionalSubscript]
set_concurrency_limit(graphql_context, "foo", 10)
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 10
assert foo["activeSlotCount"] == 0
assert foo["activeRunIds"] == []
assert foo["claimedSlots"] == []
assert foo["pendingSteps"] == []

# claim a slot
run_id = make_new_run_id()
instance.event_log_storage.claim_concurrency_slot("foo", run_id, "fake_step_key")
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 10 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [run_id] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}] # pyright: ignore[reportOptionalSubscript]
assert len(foo["pendingSteps"]) == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["runId"] == run_id # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key" # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["priority"] == 0 # pyright: ignore[reportOptionalSubscript]

# set a new limit
_set_limits("foo", 5)
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 5 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [run_id] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}] # pyright: ignore[reportOptionalSubscript]
assert len(foo["pendingSteps"]) == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["runId"] == run_id # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key" # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["priority"] == 0 # pyright: ignore[reportOptionalSubscript]

# free a slot
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 10
assert foo["activeSlotCount"] == 1
assert foo["activeRunIds"] == [run_id]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]
assert len(foo["pendingSteps"]) == 1
assert foo["pendingSteps"][0]["runId"] == run_id
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None
assert foo["pendingSteps"][0]["priority"] == 0

set_concurrency_limit(graphql_context, "foo", 5)
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 5
assert foo["activeSlotCount"] == 1
assert foo["activeRunIds"] == [run_id]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]
assert len(foo["pendingSteps"]) == 1
assert foo["pendingSteps"][0]["runId"] == run_id
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None
assert foo["pendingSteps"][0]["priority"] == 0

instance.event_log_storage.free_concurrency_slots_for_run(run_id)
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 5 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 0 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"] == [] # pyright: ignore[reportOptionalSubscript]
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 5
assert foo["activeSlotCount"] == 0
assert foo["activeRunIds"] == []
assert foo["claimedSlots"] == []
assert foo["pendingSteps"] == []

def test_concurrency_free(self, graphql_context):
storage = graphql_context.instance.event_log_storage
Expand Down Expand Up @@ -243,3 +284,33 @@ def test_concurrency_free_run(self, graphql_context):
assert foo_info.pending_run_ids == set()
assert foo_info.assigned_step_count == 1
assert foo_info.assigned_run_ids == {run_id_2}


ConcurrencyTestSuite: Any = make_graphql_context_test_suite(
context_variants=[
GraphQLContextVariant.sqlite_with_default_concurrency_managed_grpc_env(),
]
)


class TestConcurrencyInstanceSettings(ConcurrencyTestSuite):
def test_default_concurrency(self, graphql_context):
# no limits
all_limits = fetch_all_concurrency_limits(graphql_context)
assert len(all_limits) == 0

# default limits are empty
limit = fetch_concurrency_limit(graphql_context, "foo")
assert limit is not None
assert limit["slotCount"] == 0
assert limit["limit"] == 1
assert limit["usingDefaultLimit"]

# set a limit
set_concurrency_limit(graphql_context, "foo", 0)

limit = fetch_concurrency_limit(graphql_context, "foo")
assert limit is not None
assert limit["slotCount"] == 0
assert limit["limit"] == 0
assert not limit["usingDefaultLimit"]
Loading

1 comment on commit 92ffc2d

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deploy preview for dagit-core-storybook ready!

✅ Preview
https://dagit-core-storybook-oqv571c6q-elementl.vercel.app

Built with commit 92ffc2d.
This pull request is being automatically deployed with vercel-action

Please sign in to comment.