From 25f4e5f9c2888cb2b5376d2f244ffd3d64eba316 Mon Sep 17 00:00:00 2001
From: "David H. Irving" <david.irving@noirlab.edu>
Date: Fri, 3 Jan 2025 14:57:43 -0700
Subject: [PATCH] Enable caching context for multiple dataset queries

Enable the caching context in a few places where we do repeated dataset
queries over the same collections.  This allows us to avoid repeated
collection summary lookups.
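
A minimal sketch of the pattern this change applies (not part of the
patch itself; the repo path, collection name, and dataset type names
are hypothetical):

    from lsst.daf.butler import Butler

    butler = Butler.from_config("/repo/example")  # hypothetical repo
    # One caching context around the whole loop: collection summaries
    # are looked up once and reused across the repeated per-type
    # queries, instead of being re-fetched on every datasets() call.
    with butler.registry.caching_context():
        with butler.query() as query:
            for dt in ["raw", "calexp"]:  # hypothetical dataset types
                refs = list(query.datasets(dt, collections="HSC/runs/demo"))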
---
 python/lsst/daf/butler/_query_all_datasets.py | 45 ++++++++--------
 .../butler/direct_butler/_direct_butler.py    | 25 ++++-----
 python/lsst/daf/butler/script/removeRuns.py   | 51 ++++++++++---------
 3 files changed, 65 insertions(+), 56 deletions(-)
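
The re-indented block in _query_all_datasets.py keeps the existing
limit accounting unchanged; a minimal sketch of that accounting (all
numbers hypothetical):

    limit = 10  # hypothetical user-supplied row limit
    pages = {"raw": [4, 4], "calexp": [4]}  # hypothetical page sizes
    for dt, sizes in sorted(pages.items()):
        for n in sizes:
            limit -= n  # each page consumes part of the shared budget
        if limit <= 0:
            break  # once spent, later dataset types are not queried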

diff --git a/python/lsst/daf/butler/_query_all_datasets.py b/python/lsst/daf/butler/_query_all_datasets.py
index 1432798dda..e8bed6c9e1 100644
--- a/python/lsst/daf/butler/_query_all_datasets.py
+++ b/python/lsst/daf/butler/_query_all_datasets.py
@@ -113,30 +113,33 @@ def query_all_datasets(
         raise InvalidQueryError("Can not use wildcards in collections when find_first=True")
 
     dataset_type_query = list(ensure_iterable(args.name))
-    dataset_type_collections = _filter_collections_and_dataset_types(
-        butler, args.collections, dataset_type_query
-    )
 
-    limit = args.limit
-    for dt, filtered_collections in sorted(dataset_type_collections.items()):
-        _LOG.debug("Querying dataset type %s", dt)
-        results = (
-            query.datasets(dt, filtered_collections, find_first=args.find_first)
-            .where(args.data_id, args.where, args.kwargs, bind=args.bind)
-            .limit(limit)
+    with butler.registry.caching_context():
+        dataset_type_collections = _filter_collections_and_dataset_types(
+            butler, args.collections, dataset_type_query
         )
-        if args.with_dimension_records:
-            results = results.with_dimension_records()
-
-        for page in results._iter_pages():
-            if limit is not None:
-                # Track how much of the limit has been used up by each query.
-                limit -= len(page)
-
-            yield DatasetsPage(dataset_type=dt, data=page)
 
-        if limit is not None and limit <= 0:
-            break
+        limit = args.limit
+        for dt, filtered_collections in sorted(dataset_type_collections.items()):
+            _LOG.debug("Querying dataset type %s", dt)
+            results = (
+                query.datasets(dt, filtered_collections, find_first=args.find_first)
+                .where(args.data_id, args.where, args.kwargs, bind=args.bind)
+                .limit(limit)
+            )
+            if args.with_dimension_records:
+                results = results.with_dimension_records()
+
+            for page in results._iter_pages():
+                if limit is not None:
+                    # Track how much of the limit has been used up by each
+                    # query.
+                    limit -= len(page)
+
+                yield DatasetsPage(dataset_type=dt, data=page)
+
+            if limit is not None and limit <= 0:
+                break
 
 
 def _filter_collections_and_dataset_types(
diff --git a/python/lsst/daf/butler/direct_butler/_direct_butler.py b/python/lsst/daf/butler/direct_butler/_direct_butler.py
index 69972d8856..4095fd2af5 100644
--- a/python/lsst/daf/butler/direct_butler/_direct_butler.py
+++ b/python/lsst/daf/butler/direct_butler/_direct_butler.py
@@ -1426,18 +1426,19 @@ def removeRuns(self, names: Iterable[str], unstore: bool = True) -> None:
         names = list(names)
         refs: list[DatasetRef] = []
         all_dataset_types = [dt.name for dt in self._registry.queryDatasetTypes(...)]
-        for name in names:
-            collectionType = self._registry.getCollectionType(name)
-            if collectionType is not CollectionType.RUN:
-                raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
-            with self.query() as query:
-                # Work out the dataset types that are relevant.
-                collections_info = self.collections.query_info(name, include_summary=True)
-                filtered_dataset_types = self.collections._filter_dataset_types(
-                    all_dataset_types, collections_info
-                )
-                for dt in filtered_dataset_types:
-                    refs.extend(query.datasets(dt, collections=name))
+        with self._caching_context():
+            for name in names:
+                collectionType = self._registry.getCollectionType(name)
+                if collectionType is not CollectionType.RUN:
+                    raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
+                with self.query() as query:
+                    # Work out the dataset types that are relevant.
+                    collections_info = self.collections.query_info(name, include_summary=True)
+                    filtered_dataset_types = self.collections._filter_dataset_types(
+                        all_dataset_types, collections_info
+                    )
+                    for dt in filtered_dataset_types:
+                        refs.extend(query.datasets(dt, collections=name))
         with self._datastore.transaction(), self._registry.transaction():
             if unstore:
                 self._datastore.trash(refs)
diff --git a/python/lsst/daf/butler/script/removeRuns.py b/python/lsst/daf/butler/script/removeRuns.py
index cf16d1812e..9c29172823 100644
--- a/python/lsst/daf/butler/script/removeRuns.py
+++ b/python/lsst/daf/butler/script/removeRuns.py
@@ -86,29 +86,34 @@ def _getCollectionInfo(
         The dataset types and how many will be removed.
     """
     butler = Butler.from_config(repo)
-    try:
-        collections = butler.collections.query_info(
-            collection, CollectionType.RUN, include_chains=False, include_parents=True, include_summary=True
-        )
-    except MissingCollectionError:
-        # Act as if no collections matched.
-        collections = []
-    dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(...)]
-    dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections))
-
-    runs = []
-    datasets: dict[str, int] = defaultdict(int)
-    for collection_info in collections:
-        assert collection_info.type == CollectionType.RUN and collection_info.parents is not None
-        runs.append(RemoveRun(collection_info.name, list(collection_info.parents)))
-        with butler.query() as query:
-            for dt in dataset_types:
-                results = query.datasets(dt, collections=collection_info.name)
-                count = results.count(exact=False)
-                if count:
-                    datasets[dt] += count
-
-    return runs, {k: datasets[k] for k in sorted(datasets.keys())}
+    with butler.registry.caching_context():
+        try:
+            collections = butler.collections.query_info(
+                collection,
+                CollectionType.RUN,
+                include_chains=False,
+                include_parents=True,
+                include_summary=True,
+            )
+        except MissingCollectionError:
+            # Act as if no collections matched.
+            collections = []
+        dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(...)]
+        dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections))
+
+        runs = []
+        datasets: dict[str, int] = defaultdict(int)
+        for collection_info in collections:
+            assert collection_info.type == CollectionType.RUN and collection_info.parents is not None
+            runs.append(RemoveRun(collection_info.name, list(collection_info.parents)))
+            with butler.query() as query:
+                for dt in dataset_types:
+                    results = query.datasets(dt, collections=collection_info.name)
+                    count = results.count(exact=False)
+                    if count:
+                        datasets[dt] += count
+
+        return runs, {k: datasets[k] for k in sorted(datasets.keys())}
 
 
 def removeRuns(
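
A hedged usage sketch of the DirectButler.removeRuns entry point
touched above (repo path and run names are hypothetical):

    from lsst.daf.butler import Butler

    butler = Butler.from_config("/repo/example", writeable=True)
    # Remove two RUN collections; unstore=True also trashes the
    # datasets from the datastore, inside the datastore and registry
    # transactions shown in the patch.
    butler.removeRuns(["u/someone/run1", "u/someone/run2"], unstore=True)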