Skip to content

Commit

Permalink
feat: add get_root
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Feb 13, 2025
1 parent a41287f commit a039414
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
25 changes: 24 additions & 1 deletion src/pystac/stac_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

if TYPE_CHECKING:
from .catalog import Catalog
from .container import Container
from .io import Read, Write


Expand Down Expand Up @@ -82,7 +83,7 @@ def from_dict(
d: dict[str, Any],
*,
href: str | None = None,
root: Catalog | None = None, # TODO deprecation warning
root: Catalog | None = None,
migrate: bool = False,
preserve_dict: bool = True, # TODO deprecation warning
reader: Read | None = None,
Expand Down Expand Up @@ -127,6 +128,15 @@ def from_dict(
raise StacError(f"unknown type field: {type_value}")

if isinstance(stac_object, cls):
if root:
warnings.warn(
"The `root` argument is deprecated in PySTAC v2 and "
"will be removed in a future version. Prefer to use "
"`stac_object.set_link(Link.root(catalog))` "
"after object creation.",
FutureWarning,
)
stac_object.set_link(Link.root(root))
return stac_object
else:
raise PystacError(f"Expected {cls} but got a {type(stac_object)}")
Expand Down Expand Up @@ -244,6 +254,19 @@ def save_object(
else:
raise PystacError("cannot save an object without an href")

def get_root(self) -> Container | None:
"""Returns the container at this object's root link, if there is one."""
from .container import Container

if link := self.get_link(ROOT_REL):
stac_object = link.get_stac_object()
if isinstance(stac_object, Container):
return stac_object
else:
return None
else:
return None

def get_link(self, rel: str) -> Link | None:
return next((link for link in self._links if link.rel == rel), None)

Expand Down
11 changes: 10 additions & 1 deletion tests/v1/test_item.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from copy import deepcopy
from typing import Any

from pystac import Item
import pytest

from pystac import Catalog, Item

from . import utils

Expand All @@ -26,3 +28,10 @@ def test_to_from_dict(sample_item_dict: dict[str, Any]) -> None:
# assert that the parameter is preserved regardless of preserve_dict
Item.from_dict(param_dict, preserve_dict=False)
assert param_dict == sample_item_dict


def test_from_dict_set_root(sample_item_dict: dict[str, Any]) -> None:
catalog = Catalog(id="test", description="test desc")
with pytest.warns(FutureWarning):
item = Item.from_dict(sample_item_dict, root=catalog)
assert item.get_root() is catalog

0 comments on commit a039414

Please sign in to comment.