diff --git a/pystac/__init__.py b/pystac/__init__.py index 7dc8b13fb..c4f4c54e4 100644 --- a/pystac/__init__.py +++ b/pystac/__init__.py @@ -13,6 +13,14 @@ class STACError(Exception): pass +class STACTypeError(Exception): + """A STACTypeError is raised when encountering a representation of + a STAC entity that is not correct for the context; for example, if + a Catalog JSON was read in as an Item. + """ + pass + + from typing import Any, Dict, Optional from pystac.version import (__version__, get_stac_version, set_stac_version) # type:ignore from pystac.stac_io import STAC_IO @@ -21,8 +29,12 @@ class STACError(Exception): from pystac.media_type import MediaType # type:ignore from pystac.link import (Link, HIERARCHICAL_LINKS) # type:ignore from pystac.catalog import (Catalog, CatalogType) # type:ignore -from pystac.collection import (Collection, Extent, SpatialExtent, TemporalExtent, # type:ignore - Provider) # type:ignore +from pystac.collection import ( + Collection, # type:ignore + Extent, # type:ignore + SpatialExtent, # type:ignore + TemporalExtent, # type:ignore + Provider) # type:ignore from pystac.item import (Item, Asset, CommonMetadata) # type:ignore from pystac.serialization import stac_object_from_dict diff --git a/pystac/cache.py b/pystac/cache.py index fabe4b18d..c8ea88277 100644 --- a/pystac/cache.py +++ b/pystac/cache.py @@ -1,13 +1,15 @@ from collections import ChainMap from copy import copy -from pystac.collection import Collection -from typing import Any, Dict, List, Optional, Tuple, Union, cast -from pystac.stac_object import STACObject +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast -import pystac +import pystac as ps +if TYPE_CHECKING: + from pystac.stac_object import STACObject + from pystac.collection import Collection -def get_cache_key(stac_object: STACObject) -> Tuple[str, bool]: + +def get_cache_key(stac_object: "STACObject") -> Tuple[str, bool]: """Produce a cache key for the given STAC object. If a self href is set, use that as the cache key. @@ -56,16 +58,16 @@ class ResolvedObjectCache: ids_to_collections (Dict[str, Collection]): Map of collection IDs to collections. """ def __init__(self, - id_keys_to_objects: Optional[Dict[str, STACObject]] = None, - hrefs_to_objects: Optional[Dict[str, STACObject]] = None, - ids_to_collections: Dict[str, Collection] = None): + id_keys_to_objects: Optional[Dict[str, "STACObject"]] = None, + hrefs_to_objects: Optional[Dict[str, "STACObject"]] = None, + ids_to_collections: Dict[str, "Collection"] = None): self.id_keys_to_objects = id_keys_to_objects or {} self.hrefs_to_objects = hrefs_to_objects or {} self.ids_to_collections = ids_to_collections or {} self._collection_cache = None - def get_or_cache(self, obj: STACObject) -> STACObject: + def get_or_cache(self, obj: "STACObject") -> "STACObject": """Gets the STACObject that is the cached version of the given STACObject; or, if none exists, sets the cached object to the given object. @@ -91,7 +93,7 @@ def get_or_cache(self, obj: STACObject) -> STACObject: self.cache(obj) return obj - def get(self, obj: STACObject) -> Optional[STACObject]: + def get(self, obj: "STACObject") -> Optional["STACObject"]: """Get the cached object that has the same cache key as the given object. Args: @@ -107,7 +109,7 @@ def get(self, obj: STACObject) -> Optional[STACObject]: else: return self.id_keys_to_objects.get(key) - def get_by_href(self, href: str) -> Optional[STACObject]: + def get_by_href(self, href: str) -> Optional["STACObject"]: """Gets the cached object at href. Args: @@ -118,7 +120,7 @@ def get_by_href(self, href: str) -> Optional[STACObject]: """ return self.hrefs_to_objects.get(href) - def get_collection_by_id(self, id: str) -> Optional[Collection]: + def get_collection_by_id(self, id: str) -> Optional["Collection"]: """Retrieved a cached Collection by its ID. Args: @@ -130,7 +132,7 @@ def get_collection_by_id(self, id: str) -> Optional[Collection]: """ return self.ids_to_collections.get(id) - def cache(self, obj: STACObject) -> None: + def cache(self, obj: "STACObject") -> None: """Set the given object into the cache. Args: @@ -142,10 +144,10 @@ def cache(self, obj: STACObject) -> None: else: self.id_keys_to_objects[key] = obj - if isinstance(obj, Collection): + if isinstance(obj, ps.Collection): self.ids_to_collections[obj.id] = obj - def remove(self, obj: STACObject) -> None: + def remove(self, obj: "STACObject") -> None: """Removes any cached object that matches the given object's cache key. Args: @@ -158,10 +160,10 @@ def remove(self, obj: STACObject) -> None: else: self.id_keys_to_objects.pop(key, None) - if obj.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION: + if obj.STAC_OBJECT_TYPE == ps.STACObjectType.COLLECTION: self.id_keys_to_objects.pop(obj.id, None) - def __contains__(self, obj: STACObject) -> bool: + def __contains__(self, obj: "STACObject") -> bool: key, is_href = get_cache_key(obj) return key in self.hrefs_to_objects if is_href else key in self.id_keys_to_objects @@ -213,23 +215,25 @@ class CollectionCache: and will set Collection JSON that it reads in order to merge in common properties. """ def __init__(self, - cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None, - cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None): + cached_ids: Dict[str, Union["Collection", Dict[str, Any]]] = None, + cached_hrefs: Dict[str, Union["Collection", Dict[str, Any]]] = None): self.cached_ids = cached_ids or {} self.cached_hrefs = cached_hrefs or {} - def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]: + def get_by_id(self, collection_id: str) -> Optional[Union["Collection", Dict[str, Any]]]: return self.cached_ids.get(collection_id) - def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]: + def get_by_href(self, href: str) -> Optional[Union["Collection", Dict[str, Any]]]: return self.cached_hrefs.get(href) def contains_id(self, collection_id: str) -> bool: return collection_id in self.cached_ids - def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[str] = None) -> None: + def cache(self, + collection: Union["Collection", Dict[str, Any]], + href: Optional[str] = None) -> None: """Caches a collection JSON.""" - if isinstance(collection, Collection): + if isinstance(collection, ps.Collection): self.cached_ids[collection.id] = collection else: self.cached_ids[collection['id']] = collection @@ -241,24 +245,24 @@ def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[st class ResolvedObjectCollectionCache(CollectionCache): def __init__(self, resolved_object_cache: ResolvedObjectCache, - cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None, - cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None): + cached_ids: Dict[str, Union["Collection", Dict[str, Any]]] = None, + cached_hrefs: Dict[str, Union["Collection", Dict[str, Any]]] = None): super().__init__(cached_ids, cached_hrefs) self.resolved_object_cache = resolved_object_cache - def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]: + def get_by_id(self, collection_id: str) -> Optional[Union["Collection", Dict[str, Any]]]: result = self.resolved_object_cache.get_collection_by_id(collection_id) if result is None: return super().get_by_id(collection_id) else: return result - def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]: + def get_by_href(self, href: str) -> Optional[Union["Collection", Dict[str, Any]]]: result = self.resolved_object_cache.get_by_href(href) if result is None: return super().get_by_href(href) else: - return cast(Collection, result) + return cast(ps.Collection, result) def contains_id(self, collection_id: str) -> bool: return (self.resolved_object_cache.contains_collection_id(collection_id) diff --git a/pystac/catalog.py b/pystac/catalog.py index 89bbe4f4b..ac91592d9 100644 --- a/pystac/catalog.py +++ b/pystac/catalog.py @@ -1,16 +1,16 @@ import os from copy import deepcopy from enum import Enum -from pystac.item import Asset, Item -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast -import pystac -from pystac import STACError +import pystac as ps from pystac.stac_object import STACObject from pystac.layout import (BestPracticesLayoutStrategy, HrefLayoutStrategy, LayoutTemplate) from pystac.link import Link from pystac.cache import ResolvedObjectCache from pystac.utils import (is_absolute_href, make_absolute_href) +if TYPE_CHECKING: + from pystac.item import Asset as AssetType, Item as ItemType class CatalogType(str, Enum): @@ -108,7 +108,7 @@ class Catalog(STACObject): catalog_type (str): The catalog type. Defaults to ABSOLUTE_PUBLISHED """ - STAC_OBJECT_TYPE = pystac.STACObjectType.CATALOG + STAC_OBJECT_TYPE = ps.STACObjectType.CATALOG DEFAULT_FILE_NAME = "catalog.json" """Default file name that will be given to this STAC object in a canonical format.""" @@ -169,8 +169,8 @@ def add_child(self, """ # Prevent typo confusion - if isinstance(child, pystac.Item): - raise STACError('Cannot add item as child. Use add_item instead.') + if isinstance(child, ps.Item): + raise ps.STACError('Cannot add item as child. Use add_item instead.') if strategy is None: strategy = BestPracticesLayoutStrategy() @@ -198,7 +198,7 @@ def add_children(self, children: Iterable["Catalog"]) -> None: self.add_child(child) def add_item(self, - item: Item, + item: "ItemType", title: Optional[str] = None, strategy: Optional[HrefLayoutStrategy] = None) -> None: """Adds a link to an :class:`~pystac.Item`. @@ -211,8 +211,8 @@ def add_item(self, """ # Prevent typo confusion - if isinstance(item, pystac.Catalog): - raise STACError('Cannot add catalog as item. Use add_child instead.') + if isinstance(item, ps.Catalog): + raise ps.STACError('Cannot add catalog as item. Use add_child instead.') if strategy is None: strategy = BestPracticesLayoutStrategy() @@ -228,7 +228,7 @@ def add_item(self, self.add_link(Link.item(item, title=title)) - def add_items(self, items: Iterable[Item]) -> None: + def add_items(self, items: Iterable["ItemType"]) -> None: """Adds links to multiple :class:`~pystac.Item` s. This method will set each item's parent to this object, and their root to this Catalog's root. @@ -266,7 +266,7 @@ def get_children(self) -> Iterable["Catalog"]: Iterable[Catalog]: Generator of children who's parent is this catalog. """ - return map(lambda x: cast(Catalog, x), self.get_stac_objects('child')) + return map(lambda x: cast(ps.Catalog, x), self.get_stac_objects('child')) def get_child_links(self): """Return all child links of this catalog. @@ -293,7 +293,7 @@ def remove_child(self, child_id: str) -> None: Args: child_id (str): The ID of the child to remove. """ - new_links: List[Link] = [] + new_links: List[ps.Link] = [] root = self.get_root() for link in self.links: if link.rel != 'child': @@ -308,7 +308,7 @@ def remove_child(self, child_id: str) -> None: child.set_root(None) self.links = new_links - def get_item(self, id: str, recursive: bool = False) -> Optional[Item]: + def get_item(self, id: str, recursive: bool = False) -> Optional["ItemType"]: """Returns an item with a given ID. Args: @@ -328,13 +328,13 @@ def get_item(self, id: str, recursive: bool = False) -> Optional[Item]: return item return None - def get_items(self) -> Iterable[Item]: + def get_items(self) -> Iterable["ItemType"]: """Return all items of this catalog. Return: Iterable[Item]: Generator of items who's parent is this catalog. """ - return map(lambda x: cast(Item, x), self.get_stac_objects('item')) + return map(lambda x: cast(ps.Item, x), self.get_stac_objects('item')) def clear_items(self) -> "Catalog": """Removes all items from this catalog. @@ -344,7 +344,7 @@ def clear_items(self) -> "Catalog": """ for link in self.get_item_links(): if link.is_resolved(): - item = cast(Item, link.target) + item = cast(ps.Item, link.target) item.set_parent(None) item.set_root(None) @@ -357,14 +357,14 @@ def remove_item(self, item_id: str) -> None: Args: item_id (str): The ID of the item to remove. """ - new_links: List[Link] = [] + new_links: List[ps.Link] = [] root = self.get_root() for link in self.links: if link.rel != 'item': new_links.append(link) else: link.resolve_stac_object(root=root) - item = cast(Item, link.target) + item = cast(ps.Item, link.target) if item.id != item_id: new_links.append(link) else: @@ -372,7 +372,7 @@ def remove_item(self, item_id: str) -> None: item.set_root(None) self.links = new_links - def get_all_items(self) -> Iterable[Item]: + def get_all_items(self) -> Iterable["ItemType"]: """Get all items from this catalog and all subcatalogs. Will traverse any subcatalogs recursively. @@ -400,7 +400,7 @@ def to_dict(self, include_self_link: bool = True) -> Dict[str, Any]: d: Dict[str, Any] = { 'id': self.id, - 'stac_version': pystac.get_stac_version(), + 'stac_version': ps.get_stac_version(), 'description': self.description, 'links': [link.to_dict() for link in links] } @@ -498,7 +498,7 @@ def normalize_hrefs(self, root_href: str, strategy: Optional[HrefLayoutStrategy] if not is_absolute_href(root_href): root_href = make_absolute_href(root_href, os.getcwd(), start_is_dir=True) - def process_item(item: Item, _root_href: str) -> Callable[[], None]: + def process_item(item: "ItemType", _root_href: str) -> Callable[[], None]: item.resolve_links() new_self_href = strategy.get_href(item, _root_href) @@ -579,7 +579,7 @@ def generate_subcatalogs(self, item_links = [lk for lk in self.links if lk.rel == 'item'] for link in item_links: link.resolve_stac_object(root=self.get_root()) - item = cast(Item, link.target) + item = cast(ps.Item, link.target) item_parts = layout_template.get_template_values(item) id_iter = reversed(parent_ids) if all(['{}'.format(id) == next(id_iter, None) @@ -596,7 +596,7 @@ def generate_subcatalogs(self, if subcat is None: subcat_desc = 'Catalog of items from {} with {} of {}'.format( curr_parent.id, k, v) - subcat = pystac.Catalog(id=subcat_id, description=subcat_desc) + subcat = ps.Catalog(id=subcat_id, description=subcat_desc) curr_parent.add_child(subcat) result.append(subcat) curr_parent = subcat @@ -647,7 +647,8 @@ def save(self, catalog_type: CatalogType = None) -> None: for item_link in self.get_item_links(): if item_link.is_resolved(): - cast(Item, item_link.target).save_object(include_self_link=items_include_self_link) + cast(ps.Item, + item_link.target).save_object(include_self_link=items_include_self_link) include_self_link = False # include a self link if this is the root catalog or if ABSOLUTE_PUBLISHED catalog @@ -660,7 +661,7 @@ def save(self, catalog_type: CatalogType = None) -> None: self.catalog_type = catalog_type - def walk(self) -> Iterable[Tuple["Catalog", Iterable["Catalog"], Iterable[Item]]]: + def walk(self) -> Iterable[Tuple["Catalog", Iterable["Catalog"], Iterable["ItemType"]]]: """Walks through children and items of catalogs. For each catalog in the STAC's tree rooted at this catalog (including this catalog @@ -699,9 +700,9 @@ def validate_all(self) -> None: item.validate() def _object_links(self) -> List[str]: - return ['child', 'item'] + (pystac.STAC_EXTENSIONS.get_extended_object_links(self)) + return ['child', 'item'] + (ps.STAC_EXTENSIONS.get_extended_object_links(self)) - def map_items(self, item_mapper: Callable[[Item], Union[Item, List[Item]]]): + def map_items(self, item_mapper: Callable[["ItemType"], Union["ItemType", List["ItemType"]]]): """Creates a copy of a catalog, with each item passed through the item_mapper function. @@ -724,10 +725,10 @@ def process_catalog(catalog: Catalog): item_links: List[Link] = [] for item_link in catalog.get_item_links(): item_link.resolve_stac_object(root=self.get_root()) - mapped = item_mapper(cast(Item, item_link.target)) + mapped = item_mapper(cast(ps.Item, item_link.target)) if mapped is None: raise Exception('item_mapper cannot return None.') - if isinstance(mapped, Item): + if isinstance(mapped, ps.Item): item_link.target = mapped item_links.append(item_link) else: @@ -741,8 +742,9 @@ def process_catalog(catalog: Catalog): process_catalog(new_cat) return new_cat - def map_assets(self, asset_mapper: Callable[[str, Asset], Union[Asset, Tuple[str, Asset], - Dict[str, Asset]]]): + def map_assets(self, asset_mapper: Callable[[str, "AssetType"], + Union["AssetType", Tuple[str, "AssetType"], + Dict[str, "AssetType"]]]): """Creates a copy of a catalog, with each Asset for each Item passed through the asset_mapper function. @@ -756,12 +758,12 @@ def map_assets(self, asset_mapper: Callable[[str, Asset], Union[Asset, Tuple[str Catalog: A full copy of this catalog, with assets manipulated according to the asset_mapper function. """ - def apply_asset_mapper(tup: Tuple[str, Asset]): + def apply_asset_mapper(tup: Tuple[str, "AssetType"]): k, v = tup result = asset_mapper(k, v) if result is None: raise Exception('asset_mapper cannot return None.') - if isinstance(result, pystac.Asset): + if isinstance(result, ps.Asset): return [(k, result)] elif isinstance(result, tuple): return [result] @@ -771,7 +773,7 @@ def apply_asset_mapper(tup: Tuple[str, Asset]): raise Exception('asset_mapper must return a non-empty list') return assets - def item_mapper(item: Item): + def item_mapper(item: ps.Item): new_assets = [ x for result in map(apply_asset_mapper, item.assets.items()) for x in result ] @@ -804,7 +806,14 @@ def describe(self, include_hrefs: bool = False, _indent: int = 0): def from_dict(cls, d: Dict[str, Any], href: Optional[str] = None, - root: Optional["Catalog"] = None) -> "Catalog": + root: Optional["Catalog"] = None, + migrate: bool = False) -> "Catalog": + if migrate: + result = ps.read_dict(d, href=href, root=root) + if not isinstance(result, Catalog): + raise ps.STACError(f"{result} is not a Catalog") + return result + catalog_type = CatalogType.determine_type(d) d = deepcopy(d) @@ -834,3 +843,15 @@ def from_dict(cls, cat.add_link(Link.from_dict(link)) return cat + + def full_copy(self, + root: Optional["Catalog"] = None, + parent: Optional["Catalog"] = None) -> "Catalog": + return cast(Catalog, super().full_copy(root, parent)) + + @classmethod + def from_file(cls, href: str) -> "Catalog": + result = super().from_file(href) + if not isinstance(result, Catalog): + raise ps.STACTypeError(f"{result} is not a {Catalog}.") + return result diff --git a/pystac/collection.py b/pystac/collection.py index e9cb73b44..2f8fbf7a0 100644 --- a/pystac/collection.py +++ b/pystac/collection.py @@ -1,14 +1,19 @@ +from copy import (copy, deepcopy) from datetime import datetime as Datetime -from pystac.item import Item -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast + import dateutil.parser from dateutil import tz -from copy import (copy, deepcopy) + +import pystac as ps from pystac import (STACObjectType, CatalogType) from pystac.catalog import Catalog from pystac.link import Link from pystac.utils import datetime_to_str +if TYPE_CHECKING: + from pystac.item import Item as ItemType + class SpatialExtent: """Describes the spatial extent of a Collection. @@ -251,7 +256,7 @@ def from_dict(d: Dict[str, Any]) -> "Extent": TemporalExtent.from_dict(temporal_extent_dict)) @staticmethod - def from_items(items: Iterable[Item]) -> "Extent": + def from_items(items: Iterable["ItemType"]) -> "Extent": """Create an Extent based on the datetimes and bboxes of a list of items. Args: @@ -261,11 +266,11 @@ def from_items(items: Iterable[Item]) -> "Extent": Extent: An Extent that spatially and temporally covers all of the given items. """ - def extract_extent_props(item: Item) -> List[Any]: + def extract_extent_props(item: ps.Item) -> List[Any]: return item.bbox + [ item.datetime, item.common_metadata.start_datetime, item.common_metadata.end_datetime - ] # type:ignore + ] # type:ignore xmins, ymins, xmaxs, ymaxs, datetimes, starts, ends = zip(*map(extract_extent_props, items)) @@ -418,23 +423,21 @@ def __init__(self, license: str = 'proprietary', keywords: Optional[List[str]] = None, providers: Optional[List[Provider]] = None, - properties: Optional[Dict[str, Any]] = None, summaries: Optional[Dict[str, Any]] = None): super(Collection, self).__init__(id, description, title, stac_extensions, extra_fields, href, catalog_type or CatalogType.ABSOLUTE_PUBLISHED) self.extent = extent self.license = license - self.stac_extensions = stac_extensions + self.stac_extensions: List[str] = stac_extensions or [] self.keywords = keywords self.providers = providers - self.properties = properties self.summaries = summaries def __repr__(self) -> str: return ''.format(self.id) - def add_item(self, item: Item, title: Optional[str] = None) -> None: + def add_item(self, item: "ItemType", title: Optional[str] = None) -> None: super(Collection, self).add_item(item, title) item.set_collection(self) @@ -448,8 +451,6 @@ def to_dict(self, include_self_link: bool = True) -> Dict[str, Any]: d['keywords'] = self.keywords if self.providers is not None: d['providers'] = list(map(lambda x: x.to_dict(), self.providers)) - if self.properties is not None: - d['properties'] = self.properties if self.summaries is not None: d['summaries'] = self.summaries @@ -466,7 +467,6 @@ def clone(self): license=self.license, keywords=self.keywords, providers=self.providers, - properties=self.properties, summaries=self.summaries) clone._resolved_objects.cache(clone) @@ -488,7 +488,14 @@ def clone(self): def from_dict(cls, d: Dict[str, Any], href: Optional[str] = None, - root: Optional[Catalog] = None) -> "Collection": + root: Optional[Catalog] = None, + migrate: bool = False) -> "Collection": + if migrate: + result = ps.read_dict(d, href=href, root=root) + if not isinstance(result, Collection): + raise ps.STACError(f"{result} is not a Catalog") + return result + catalog_type = CatalogType.determine_type(d) d = deepcopy(d) @@ -502,7 +509,6 @@ def from_dict(cls, providers = d.get('providers') if providers is not None: providers = list(map(lambda x: Provider.from_dict(x), providers)) - properties = d.get('properties') summaries = d.get('summaries') links = d.pop('links') @@ -517,7 +523,6 @@ def from_dict(cls, license=license, keywords=keywords, providers=providers, - properties=properties, summaries=summaries, href=href, catalog_type=catalog_type) @@ -537,3 +542,15 @@ def update_extent_from_items(self): Update datetime and bbox based on all items to a single bbox and time window. """ self.extent = Extent.from_items(self.get_all_items()) + + def full_copy(self, + root: Optional["Catalog"] = None, + parent: Optional["Catalog"] = None) -> "Collection": + return cast(Collection, super().full_copy(root, parent)) + + @classmethod + def from_file(cls, href: str) -> "Collection": + result = super().from_file(href) + if not isinstance(result, Collection): + raise ps.STACTypeError(f"{result} is not a {Collection}.") + return result diff --git a/pystac/extensions/base.py b/pystac/extensions/base.py index a2fd8513d..e647464f6 100644 --- a/pystac/extensions/base.py +++ b/pystac/extensions/base.py @@ -1,16 +1,18 @@ from abc import (ABC, abstractmethod) -from typing import Any, Iterable, List, Optional, Type -from pystac.stac_object import STACObject +from typing import Any, Iterable, List, Optional, TYPE_CHECKING, Type from pystac.catalog import Catalog from pystac.collection import Collection from pystac.item import Asset, Item from pystac.extensions import ExtensionError +if TYPE_CHECKING: + from pystac.stac_object import STACObject + class STACObjectExtension(ABC): @classmethod - def _from_object(cls, stac_object: STACObject) -> "STACObjectExtension": + def _from_object(cls, stac_object: "STACObject") -> "STACObjectExtension": ... @classmethod @@ -19,7 +21,7 @@ def _object_links(cls) -> List[str]: raise NotImplementedError("_object_links") @classmethod - def enable_extension(cls, stac_object: STACObject) -> None: + def enable_extension(cls, stac_object: "STACObject") -> None: """Enables the extension for the given stac_object. Child classes can choose to override this method in order to modify the stac_object when an extension is enabled. @@ -39,7 +41,7 @@ class ExtendedObject: stac_object_class: The STAC object class that is being extended. extension_class: The class of the extension, e.g. LabelItemExt """ - def __init__(self, stac_object_class: Type[STACObject], + def __init__(self, stac_object_class: Type["STACObject"], extension_class: Type[STACObjectExtension]): if stac_object_class is Catalog: if not issubclass(extension_class, CatalogExtension): @@ -74,7 +76,7 @@ def __init__(self, extension_id: str, extended_objects: List[ExtendedObject]): class CatalogExtension(STACObjectExtension): @classmethod - def _from_object(cls, stac_object: STACObject) -> "CatalogExtension": + def _from_object(cls, stac_object: "STACObject") -> "CatalogExtension": if not isinstance(stac_object, Catalog): raise ValueError(f"This extension applies to Catalogs, not {cls}") return cls.from_catalog(stac_object) @@ -87,7 +89,7 @@ def from_catalog(cls, catalog: Catalog) -> "CatalogExtension": class CollectionExtension(STACObjectExtension): @classmethod - def _from_object(cls, stac_object: STACObject) -> "CollectionExtension": + def _from_object(cls, stac_object: "STACObject") -> "CollectionExtension": if not isinstance(stac_object, Collection): raise ValueError(f"This extension applies to Collections, not {cls}") return cls.from_collection(stac_object) @@ -102,7 +104,7 @@ class ItemExtension(STACObjectExtension): item: Item @classmethod - def _from_object(cls, stac_object: STACObject) -> "ItemExtension": + def _from_object(cls, stac_object: "STACObject") -> "ItemExtension": if not isinstance(stac_object, Item): raise ValueError(f"This extension applies to Items, not {cls}") return cls.from_item(stac_object) @@ -174,7 +176,7 @@ def remove_extension(self, extension_id: str) -> None: def get_extension_class( self, extension_id: str, - stac_object_class: Type[STACObject]) -> Optional[Type[STACObjectExtension]]: + stac_object_class: Type["STACObject"]) -> Optional[Type[STACObjectExtension]]: """Gets the extension class for a given stac object class if one exists, otherwise returns None """ @@ -209,7 +211,7 @@ def get_extension_class( return ext_class - def extend_object(self, extension_id: str, stac_object: STACObject) -> STACObjectExtension: + def extend_object(self, extension_id: str, stac_object: "STACObject") -> STACObjectExtension: """Returns the extension object for the given STACObject and the given extension_id """ @@ -221,7 +223,7 @@ def extend_object(self, extension_id: str, stac_object: STACObject) -> STACObjec return ext_class._from_object(stac_object) - def get_extended_object_links(self, stac_object: STACObject) -> List[str]: + def get_extended_object_links(self, stac_object: "STACObject") -> List[str]: if stac_object.stac_extensions is None: return [] return [ @@ -231,7 +233,7 @@ def get_extended_object_links(self, stac_object: STACObject) -> List[str]: for link_rel in e_obj.extension_class._object_links() ] - def can_extend(self, extension_id: str, stac_object_class: Type[STACObject]) -> bool: + def can_extend(self, extension_id: str, stac_object_class: Type["STACObject"]) -> bool: """Returns True if the extension can extend the given object type. Args: @@ -254,7 +256,7 @@ def can_extend(self, extension_id: str, stac_object_class: Type[STACObject]) -> if issubclass(stac_object_class, e.stac_object_class) ]) - def enable_extension(self, extension_id: str, stac_object: STACObject) -> None: + def enable_extension(self, extension_id: str, stac_object: "STACObject") -> None: """Enables the extension for the given object. This will at least ensure the extension ID is in the object's "stac_extensions" diff --git a/pystac/extensions/file.py b/pystac/extensions/file.py index 35cf94b00..c86861035 100644 --- a/pystac/extensions/file.py +++ b/pystac/extensions/file.py @@ -126,7 +126,7 @@ def size(self) -> Optional[int]: def size(self, v: Optional[int]) -> None: self.set_size(v) - def get_size(self, asset: Optional[Asset]=None) -> Optional[int]: + def get_size(self, asset: Optional[Asset] = None) -> Optional[int]: """Gets an Item or an Asset file size. If an Asset is supplied and the Item property exists on the Asset, @@ -140,7 +140,7 @@ def get_size(self, asset: Optional[Asset]=None) -> Optional[int]: else: return asset.properties.get('file:size') - def set_size(self, size: Optional[int], asset: Optional[Asset]=None) -> None: + def set_size(self, size: Optional[int], asset: Optional[Asset] = None) -> None: """Set an Item or an Asset size. If an Asset is supplied, sets the property on the Asset. @@ -154,10 +154,10 @@ def nodata(self) -> Optional[List[Any]]: return self.get_nodata() @nodata.setter - def nodata(self, v: Optional[List[Any]])-> None: + def nodata(self, v: Optional[List[Any]]) -> None: self.set_nodata(v) - def get_nodata(self, asset: Optional[Asset]=None) -> Optional[List[Any]]: + def get_nodata(self, asset: Optional[Asset] = None) -> Optional[List[Any]]: """Gets an Item or an Asset nodata values. If an Asset is supplied and the Item property exists on the Asset, @@ -171,7 +171,7 @@ def get_nodata(self, asset: Optional[Asset]=None) -> Optional[List[Any]]: else: return asset.properties.get('file:nodata') - def set_nodata(self, nodata: Optional[List[Any]], asset: Optional[Asset]=None) -> None: + def set_nodata(self, nodata: Optional[List[Any]], asset: Optional[Asset] = None) -> None: """Set an Item or an Asset nodata values. If an Asset is supplied, sets the property on the Asset. @@ -192,7 +192,7 @@ def checksum(self) -> Optional[str]: def checksum(self, v: Optional[str]) -> None: self.set_checksum(v) - def get_checksum(self, asset: Optional[Asset]=None) -> Optional[str]: + def get_checksum(self, asset: Optional[Asset] = None) -> Optional[str]: """Gets an Item or an Asset checksum. If an Asset is supplied and the Item property exists on the Asset, @@ -203,7 +203,7 @@ def get_checksum(self, asset: Optional[Asset]=None) -> Optional[str]: else: return asset.properties.get('file:checksum') - def set_checksum(self, checksum: Optional[str], asset: Optional[Asset]=None) -> None: + def set_checksum(self, checksum: Optional[str], asset: Optional[Asset] = None) -> None: """Set an Item or an Asset checksum. If an Asset is supplied, sets the property on the Asset. diff --git a/pystac/extensions/label.py b/pystac/extensions/label.py index 6abfa3e2a..92d33b123 100644 --- a/pystac/extensions/label.py +++ b/pystac/extensions/label.py @@ -257,7 +257,7 @@ def __init__(self, properties: Dict[str, Any]): self.properties = properties def apply(self, - property_key: str, + property_key: Optional[str], counts: Optional[List[LabelCount]] = None, statistics: Optional[List[LabelStatistics]] = None): """Sets the properties for this LabelOverview. @@ -266,7 +266,9 @@ def apply(self, at least one is required. Args: - property_key (str): The property key within the asset corresponding to class labels. + property_key (str): The property key within the asset corresponding to class labels + that these counts or statistics are referencing. If the label data is raster data, + this should be None. counts: Optional list of LabelCounts containing counts for categorical data. statistics: Optional list of statistics containing statistics for @@ -278,7 +280,7 @@ def apply(self, @classmethod def create(cls, - property_key: str, + property_key: Optional[str], counts: Optional[List[LabelCount]] = None, statistics: Optional[List[LabelStatistics]] = None) -> "LabelOverview": """Creates a new LabelOverview. @@ -298,19 +300,16 @@ def create(cls, return x @property - def property_key(self) -> str: + def property_key(self) -> Optional[str]: """Get or sets the property key within the asset corresponding to class labels. Returns: str """ - result = self.properties.get('property_key') - if result is None: - raise STACError(f"Label overview has no property_key: {self.properties}") - return result + return self.properties.get('property_key') @property_key.setter - def property_key(self, v: str) -> None: + def property_key(self, v: Optional[str]) -> None: self.properties['property_key'] = v @property @@ -651,7 +650,11 @@ def get_sources(self) -> Iterable[Item]: """ return map(lambda x: cast(Item, x), self.item.get_stac_objects('source')) - def add_labels(self, href: str, title: Optional[str]=None, media_type: Optional[str]=None, properties: Optional[Dict[str, Any]]=None): + def add_labels(self, + href: str, + title: Optional[str] = None, + media_type: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None): """Adds a label asset to this LabelItem. Args: @@ -667,7 +670,10 @@ def add_labels(self, href: str, title: Optional[str]=None, media_type: Optional[ self.item.add_asset( "labels", Asset(href=href, title=title, media_type=media_type, properties=properties)) - def add_geojson_labels(self, href: str, title: Optional[str]=None, properties:Optional[Dict[str, Any]]=None): + def add_geojson_labels(self, + href: str, + title: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None): """Adds a GeoJSON label asset to this LabelItem. Args: diff --git a/pystac/extensions/projection.py b/pystac/extensions/projection.py index 9984796b3..39e28af83 100644 --- a/pystac/extensions/projection.py +++ b/pystac/extensions/projection.py @@ -314,7 +314,9 @@ def get_centroid(self, asset: Optional[Asset] = None) -> Optional[Dict[str, floa else: return asset.properties.get('proj:centroid') - def set_centroid(self, centroid: Optional[Dict[str, float]], asset: Optional[Asset] = None) -> None: + def set_centroid(self, + centroid: Optional[Dict[str, float]], + asset: Optional[Asset] = None) -> None: """Set an Item or an Asset centroid. If an Asset is supplied, sets the property on the Asset. @@ -397,7 +399,9 @@ def get_transform(self, asset: Optional[Asset] = None) -> Optional[List[float]]: else: return asset.properties.get('proj:transform') - def set_transform(self, transform: Optional[List[float]], asset: Optional[Asset] = None) -> None: + def set_transform(self, + transform: Optional[List[float]], + asset: Optional[Asset] = None) -> None: """Set an Item or an Asset transform. If an Asset is supplied, sets the property on the Asset. diff --git a/pystac/extensions/sar.py b/pystac/extensions/sar.py index 56c40a071..97f81b297 100644 --- a/pystac/extensions/sar.py +++ b/pystac/extensions/sar.py @@ -154,9 +154,8 @@ def instrument_mode(self) -> str: """ result = self.item.properties.get(INSTRUMENT_MODE) if result is None: - raise STACError( - f"Item with sar extension does not have property {INSTRUMENT_MODE}, id {self.item.id}" - ) + raise STACError(f"Item with sar extension does not have property {INSTRUMENT_MODE}, " + f"id {self.item.id}") return result @instrument_mode.setter @@ -172,9 +171,8 @@ def frequency_band(self) -> FrequencyBand: """ result = self.item.properties.get(FREQUENCY_BAND) if result is None: - raise STACError( - f"Item with sar extension does not have property {FREQUENCY_BAND}, id {self.item.id}" - ) + raise STACError(f"Item with sar extension does not have property {FREQUENCY_BAND}, " + f"id {self.item.id}") return FrequencyBand(result) @frequency_band.setter diff --git a/pystac/extensions/single_file_stac.py b/pystac/extensions/single_file_stac.py index f940e67b4..ecc211a2a 100644 --- a/pystac/extensions/single_file_stac.py +++ b/pystac/extensions/single_file_stac.py @@ -2,7 +2,6 @@ from typing import List, Optional, cast from pystac.catalog import Catalog -import pystac from pystac import (STACError, Extensions) from pystac.collection import Collection from pystac.extensions.base import (CatalogExtension, ExtensionDefinition, ExtendedObject) diff --git a/pystac/item.py b/pystac/item.py index affa49ce2..d0c881c4b 100644 --- a/pystac/item.py +++ b/pystac/item.py @@ -5,7 +5,7 @@ import dateutil.parser -import pystac +import pystac as ps from pystac import (STACError, STACObjectType) from pystac.link import Link from pystac.stac_object import STACObject @@ -747,8 +747,8 @@ class Item(STACObject): def __init__(self, id: str, - geometry: Dict[str, Any], - bbox: List[float], + geometry: Optional[Dict[str, Any]], + bbox: Optional[List[float]], datetime: Optional[Datetime], properties: Dict[str, Any], stac_extensions: Optional[List[str]] = None, @@ -913,7 +913,7 @@ def make_asset_hrefs_absolute(self) -> "Item": return self - def set_collection(self, collection: Collection) -> "Item": + def set_collection(self, collection: Optional[Collection]) -> "Item": """Set the collection of this item. This method will replace any existing Collection link and attribute for @@ -961,7 +961,7 @@ def to_dict(self, include_self_link: bool = True) -> Dict[str, Any]: d: Dict[str, Any] = { 'type': 'Feature', - 'stac_version': pystac.get_stac_version(), + 'stac_version': ps.get_stac_version(), 'id': self.id, 'properties': self.properties, 'geometry': self.geometry, @@ -1000,18 +1000,25 @@ def clone(self) -> "Item": return clone def _object_links(self) -> List[str]: - return ['collection'] + (pystac.STAC_EXTENSIONS.get_extended_object_links(self)) + return ['collection'] + (ps.STAC_EXTENSIONS.get_extended_object_links(self)) @classmethod def from_dict(cls, d: Dict[str, Any], href: Optional[str] = None, - root: Optional[Catalog] = None) -> "Item": + root: Optional[Catalog] = None, + migrate: bool = False) -> "Item": + if migrate: + result = ps.read_dict(d, href=href, root=root) + if not isinstance(result, Item): + raise ps.STACError(f"{result} is not a Catalog") + return result + d = deepcopy(d) id = d.pop('id') geometry = d.pop('geometry') properties = d.pop('properties') - bbox = d.pop('bbox') # TODO: Ensure this shouldn't pop with a default + bbox = d.pop('bbox', None) stac_extensions = d.get('stac_extensions') collection_id = d.pop('collection', None) @@ -1056,3 +1063,15 @@ def common_metadata(self) -> CommonMetadata: CommonMetada: contains all common metadata fields in the items properties """ return CommonMetadata(self.properties) + + def full_copy(self, + root: Optional["Catalog"] = None, + parent: Optional["Catalog"] = None) -> "Item": + return cast(Item, super().full_copy(root, parent)) + + @classmethod + def from_file(cls, href: str) -> "Item": + result = super().from_file(href) + if not isinstance(result, Item): + raise ps.STACTypeError(f"{result} is not a {Item}.") + return result diff --git a/pystac/layout.py b/pystac/layout.py index d97aff543..01cc9ef99 100644 --- a/pystac/layout.py +++ b/pystac/layout.py @@ -1,14 +1,16 @@ from abc import (abstractmethod, ABC) from collections import OrderedDict import os -from pystac.collection import Collection from string import Formatter -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union -import pystac -from pystac.catalog import Catalog -from pystac.item import Item -from pystac.stac_object import STACObject +import pystac as ps + +if TYPE_CHECKING: + from pystac.stac_object import STACObject + from pystac.catalog import Catalog + from pystac.collection import Collection + from pystac.item import Item class TemplateError(Exception): @@ -85,9 +87,9 @@ def __init__(self, template: str, defaults: Dict[str, str] = None) -> None: template_vars.append(v) self.template_vars = template_vars - def _get_template_value(self, stac_object: STACObject, template_var: str) -> Any: + def _get_template_value(self, stac_object: "STACObject", template_var: str) -> Any: if template_var in self.ITEM_TEMPLATE_VARS: - if isinstance(stac_object, Item): + if isinstance(stac_object, ps.Item): # Datetime dt = stac_object.datetime if dt is None: @@ -119,25 +121,29 @@ def _get_template_value(self, stac_object: STACObject, template_var: str) -> Any # Allow dot-notation properties for arbitrary object values. props = template_var.split('.') - prop_location = None + prop_source: Optional[Union[STACObject, Dict[str, Any]]] = None error = TemplateError('Cannot find property {} on {} for template {}'.format( template_var, stac_object, self.template)) try: + if hasattr(stac_object, props[0]): - prop_location = stac_object - elif hasattr(stac_object, "properties"): + prop_source = stac_object + + if prop_source is None and hasattr(stac_object, "properties"): obj_props: Optional[Dict[str, Any]] = stac_object.properties # type:ignore if obj_props is not None and props[0] in obj_props: - prop_location = obj_props - elif hasattr(stac_object, "extra_fields"): + prop_source = obj_props + + if prop_source is None and hasattr(stac_object, "extra_fields"): extra_fields: Optional[Dict[str, Any]] = stac_object.extra_fields # type:ignore if extra_fields is not None and props[0] in extra_fields: - prop_location = extra_fields - else: + prop_source = extra_fields + + if prop_source is None: raise error - v: Any = prop_location + v: Any = prop_source for prop in template_var.split('.'): if type(v) is dict: if prop not in v: @@ -154,7 +160,7 @@ def _get_template_value(self, stac_object: STACObject, template_var: str) -> Any return v - def get_template_values(self, stac_object: STACObject) -> Dict[str, Any]: + def get_template_values(self, stac_object: "STACObject") -> Dict[str, Any]: """Gets a dictionary of template variables to values derived from the given stac_object. If the template vars cannot be found in the stac object, and defaults was supplied to this template, a default @@ -177,7 +183,7 @@ def get_template_values(self, stac_object: STACObject) -> Dict[str, Any]: return OrderedDict([(k, self._get_template_value(stac_object, k)) for k in self.template_vars]) - def substitute(self, stac_object: STACObject) -> str: + def substitute(self, stac_object: "STACObject") -> str: """Substitutes the values derived from :meth:`~pystac.layout.LayoutTemplate.get_template_values` into the template string for this template. @@ -206,26 +212,26 @@ def substitute(self, stac_object: STACObject) -> str: class HrefLayoutStrategy(ABC): """Base class for HREF Layout strategies.""" - def get_href(self, stac_object: STACObject, parent_dir: str, is_root: bool = False) -> str: - if isinstance(stac_object, Catalog): - return self.get_catalog_href(stac_object, parent_dir, is_root) - elif isinstance(stac_object, Collection): - return self.get_collection_href(stac_object, parent_dir, is_root) - elif isinstance(stac_object, Item): + def get_href(self, stac_object: "STACObject", parent_dir: str, is_root: bool = False) -> str: + if isinstance(stac_object, ps.Item): return self.get_item_href(stac_object, parent_dir) + elif isinstance(stac_object, ps.Collection): + return self.get_collection_href(stac_object, parent_dir, is_root) + elif isinstance(stac_object, ps.Catalog): + return self.get_catalog_href(stac_object, parent_dir, is_root) else: - raise pystac.STACError('Unknown STAC object type {}'.format(stac_object)) + raise ps.STACError('Unknown STAC object type {}'.format(stac_object)) @abstractmethod - def get_catalog_href(self, cat: Catalog, parent_dir: str, is_root: bool) -> str: + def get_catalog_href(self, cat: "Catalog", parent_dir: str, is_root: bool) -> str: pass @abstractmethod - def get_collection_href(self, col: Collection, parent_dir: str, is_root: bool) -> str: + def get_collection_href(self, col: "Collection", parent_dir: str, is_root: bool) -> str: pass @abstractmethod - def get_item_href(self, item: Item, parent_dir: str) -> str: + def get_item_href(self, item: "Item", parent_dir: str) -> str: pass @@ -250,9 +256,9 @@ class CustomLayoutStrategy(HrefLayoutStrategy): :class:`~pystac.layout.BestPracticesLayoutStrategy` """ def __init__(self, - catalog_func: Optional[Callable[[Catalog, str, bool], str]] = None, - collection_func: Optional[Callable[[Collection, str, bool], str]] = None, - item_func: Optional[Callable[[Item, str], str]] = None, + catalog_func: Optional[Callable[["Catalog", str, bool], str]] = None, + collection_func: Optional[Callable[["Collection", str, bool], str]] = None, + item_func: Optional[Callable[["Item", str], str]] = None, fallback_strategy: Optional[HrefLayoutStrategy] = None): self.item_func = item_func self.collection_func = collection_func @@ -261,21 +267,21 @@ def __init__(self, fallback_strategy = BestPracticesLayoutStrategy() self.fallback_strategy: HrefLayoutStrategy = fallback_strategy - def get_catalog_href(self, cat: Catalog, parent_dir: str, is_root: bool) -> str: + def get_catalog_href(self, cat: "Catalog", parent_dir: str, is_root: bool) -> str: if self.catalog_func is not None: result = self.catalog_func(cat, parent_dir, is_root) if result is not None: return result return self.fallback_strategy.get_catalog_href(cat, parent_dir, is_root) - def get_collection_href(self, col: Collection, parent_dir: str, is_root: bool) -> str: + def get_collection_href(self, col: "Collection", parent_dir: str, is_root: bool) -> str: if self.collection_func is not None: result = self.collection_func(col, parent_dir, is_root) if result is not None: return result return self.fallback_strategy.get_collection_href(col, parent_dir, is_root) - def get_item_href(self, item: Item, parent_dir: str) -> str: + def get_item_href(self, item: "Item", parent_dir: str) -> str: if self.item_func is not None: result = self.item_func(item, parent_dir) if result is not None: @@ -322,7 +328,7 @@ def __init__(self, fallback_strategy = BestPracticesLayoutStrategy() self.fallback_strategy: HrefLayoutStrategy = fallback_strategy - def get_catalog_href(self, cat: Catalog, parent_dir: str, is_root: bool) -> str: + def get_catalog_href(self, cat: "Catalog", parent_dir: str, is_root: bool) -> str: if is_root or self.catalog_template is None: return self.fallback_strategy.get_catalog_href(cat, parent_dir, is_root) else: @@ -332,7 +338,7 @@ def get_catalog_href(self, cat: Catalog, parent_dir: str, is_root: bool) -> str: return os.path.join(parent_dir, template_path) - def get_collection_href(self, col: Collection, parent_dir: str, is_root: bool) -> str: + def get_collection_href(self, col: "Collection", parent_dir: str, is_root: bool) -> str: if is_root or self.collection_template is None: return self.fallback_strategy.get_collection_href(col, parent_dir, is_root) else: @@ -342,7 +348,7 @@ def get_collection_href(self, col: Collection, parent_dir: str, is_root: bool) - return os.path.join(parent_dir, template_path) - def get_item_href(self, item: Item, parent_dir: str) -> str: + def get_item_href(self, item: "Item", parent_dir: str) -> str: if self.item_template is None: return self.fallback_strategy.get_item_href(item, parent_dir) else: @@ -367,7 +373,7 @@ class BestPracticesLayoutStrategy(HrefLayoutStrategy): All paths are appended to the parent directory. """ - def get_catalog_href(self, cat: Catalog, parent_dir: str, is_root: bool) -> str: + def get_catalog_href(self, cat: "Catalog", parent_dir: str, is_root: bool) -> str: if is_root: cat_root = parent_dir else: @@ -375,7 +381,7 @@ def get_catalog_href(self, cat: Catalog, parent_dir: str, is_root: bool) -> str: return os.path.join(cat_root, cat.DEFAULT_FILE_NAME) - def get_collection_href(self, col: Collection, parent_dir: str, is_root: bool) -> str: + def get_collection_href(self, col: "Collection", parent_dir: str, is_root: bool) -> str: if is_root: col_root = parent_dir else: @@ -383,7 +389,7 @@ def get_collection_href(self, col: Collection, parent_dir: str, is_root: bool) - return os.path.join(col_root, col.DEFAULT_FILE_NAME) - def get_item_href(self, item: Item, parent_dir: str) -> str: + def get_item_href(self, item: "Item", parent_dir: str) -> str: item_root = os.path.join(parent_dir, '{}'.format(item.id)) return os.path.join(item_root, '{}.json'.format(item.id)) diff --git a/pystac/link.py b/pystac/link.py index b7a45515e..0490ecaa7 100644 --- a/pystac/link.py +++ b/pystac/link.py @@ -1,15 +1,16 @@ from copy import copy -from pystac.item import Item -from pystac.catalog import Catalog -from pystac.collection import Collection -from typing import Any, Dict, Optional, Union, cast - -import pystac -from pystac.stac_object import STACObject -from pystac import STACError +from typing import Any, Dict, Optional, TYPE_CHECKING, Union, cast + +import pystac as ps from pystac.stac_io import STAC_IO from pystac.utils import (make_absolute_href, make_relative_href, is_absolute_href) +if TYPE_CHECKING: + from pystac.stac_object import STACObject + from pystac.item import Item + from pystac.catalog import Catalog + from pystac.collection import Collection + HIERARCHICAL_LINKS = ['root', 'child', 'parent', 'collection', 'item', 'items'] @@ -57,18 +58,18 @@ class Link: """ def __init__(self, rel: str, - target: Union[str, STACObject], + target: Union[str, "STACObject"], media_type: Optional[str] = None, title: Optional[str] = None, properties: Optional[Dict[str, Any]] = None) -> None: self.rel = rel - self.target: Union[str, STACObject] = target # An object or an href + self.target: Union[str, "STACObject"] = target # An object or an href self.media_type = media_type self.title = title self.properties = properties self.owner = None - def set_owner(self, owner: STACObject) -> "Link": + def set_owner(self, owner: "STACObject") -> "Link": """Sets the owner of this link. Args: @@ -77,6 +78,18 @@ def set_owner(self, owner: STACObject) -> "Link": self.owner = owner return self + @property + def href(self) -> str: + """Returns the HREF for this link. + + If the href is None, this will throw an exception. + Use get_href if there may not be an href. + """ + result = self.get_href() + if result is None: + raise ValueError(f'{self} does not have an HREF set.') + return result + def get_href(self) -> Optional[str]: """Gets the HREF for this link. @@ -88,14 +101,14 @@ def get_href(self) -> Optional[str]: """ # get the self href if self.is_resolved(): - href = cast(STACObject, self.target).get_self_href() + href = cast(ps.STACObject, self.target).get_self_href() else: href = cast(Optional[str], self.target) if href and is_absolute_href(href) and self.owner and self.owner.get_root(): root = self.owner.get_root() rel_links = HIERARCHICAL_LINKS + \ - pystac.STAC_EXTENSIONS.get_extended_object_links(self.owner) + ps.STAC_EXTENSIONS.get_extended_object_links(self.owner) # if a hierarchical link with an owner and root, and relative catalog if root.is_relative() and self.rel in rel_links: owner_href = self.owner.get_self_href() @@ -104,6 +117,18 @@ def get_href(self) -> Optional[str]: return href + @property + def absolute_href(self) -> str: + """Returns the absolute HREF for this link. + + If the href is None, this will throw an exception. + Use get_absolute_href if there may not be an href set. + """ + result = self.get_absolute_href() + if result is None: + raise ValueError(f'{self} does not have an HREF set.') + return result + def get_absolute_href(self) -> Optional[str]: """Gets the absolute href for this link, if possible. @@ -113,7 +138,7 @@ def get_absolute_href(self) -> Optional[str]: and has an unresolved target, this will return a relative HREF. """ if self.is_resolved(): - href = cast(STACObject, self.target).get_self_href() + href = cast(ps.STACObject, self.target).get_self_href() else: href = cast(Optional[str], self.target) @@ -125,7 +150,7 @@ def get_absolute_href(self) -> Optional[str]: def __repr__(self): return ''.format(self.rel, self.target) - def resolve_stac_object(self, root: Optional[Catalog]=None) -> "Link": + def resolve_stac_object(self, root: Optional["Catalog"] = None) -> "Link": """Resolves a STAC object from the HREF of this link, if the link is not already resolved. @@ -140,13 +165,13 @@ def resolve_stac_object(self, root: Optional[Catalog]=None) -> "Link": # If it's a relative link, base it off the parent. if not is_absolute_href(target_href): if self.owner is None: - raise STACError('Relative path {} encountered ' - 'without owner or start_href.'.format(target_href)) + raise ps.STACError('Relative path {} encountered ' + 'without owner or start_href.'.format(target_href)) start_href = self.owner.get_self_href() if start_href is None: - raise STACError('Relative path {} encountered ' - 'without owner "self" link set.'.format(target_href)) + raise ps.STACError('Relative path {} encountered ' + 'without owner "self" link set.'.format(target_href)) target_href = make_absolute_href(target_href, start_href) obj = None @@ -165,7 +190,7 @@ def resolve_stac_object(self, root: Optional[Catalog]=None) -> "Link": self.target = obj - if self.owner and self.rel in ['child', 'item'] and isinstance(self.owner, Catalog): + if self.owner and self.rel in ['child', 'item'] and isinstance(self.owner, ps.Catalog): self.target.set_parent(self.owner) return self @@ -236,17 +261,17 @@ def from_dict(d: Dict[str, Any]) -> "Link": return Link(rel=rel, target=href, media_type=media_type, title=title, properties=properties) @staticmethod - def root(c: Catalog) -> "Link": + def root(c: "Catalog") -> "Link": """Creates a link to a root Catalog or Collection.""" return Link('root', c, media_type='application/json') @staticmethod - def parent(c: Catalog) -> "Link": + def parent(c: "Catalog") -> "Link": """Creates a link to a parent Catalog or Collection.""" return Link('parent', c, media_type='application/json') @staticmethod - def collection(c: Collection) -> "Link": + def collection(c: "Collection") -> "Link": """Creates a link to an item's Collection.""" return Link('collection', c, media_type='application/json') @@ -256,11 +281,11 @@ def self_href(href: str) -> "Link": return Link('self', href, media_type='application/json') @staticmethod - def child(c: Catalog, title: Optional[str]=None) -> "Link": + def child(c: "Catalog", title: Optional[str] = None) -> "Link": """Creates a link to a child Catalog or Collection.""" return Link('child', c, title=title, media_type='application/json') @staticmethod - def item(item: Item, title: Optional[str]=None) -> "Link": + def item(item: "Item", title: Optional[str] = None) -> "Link": """Creates a link to an Item.""" return Link('item', item, title=title, media_type='application/json') diff --git a/pystac/serialization/__init__.py b/pystac/serialization/__init__.py index 9ada41486..96b33ceff 100644 --- a/pystac/serialization/__init__.py +++ b/pystac/serialization/__init__.py @@ -1,15 +1,22 @@ # flake8: noqa -from pystac.stac_object import STACObject -from typing import Any, Dict, Optional -from pystac import (Catalog, Collection, Item, STACObjectType) +from typing import Any, Dict, Optional, TYPE_CHECKING -from pystac.serialization.identify import (STACJSONDescription, STACVersionRange, STACVersionID, # type:ignore - identify_stac_object, identify_stac_object_type) +import pystac as ps +from pystac.serialization.identify import ( + STACVersionRange, # type:ignore + identify_stac_object, + identify_stac_object_type) from pystac.serialization.common_properties import merge_common_properties from pystac.serialization.migrate import migrate_to_latest +if TYPE_CHECKING: + from pystac.stac_object import STACObject + from pystac.catalog import Catalog -def stac_object_from_dict(d: Dict[str, Any], href: Optional[str]=None, root: Optional[Catalog]=None) -> STACObject: + +def stac_object_from_dict(d: Dict[str, Any], + href: Optional[str] = None, + root: Optional["Catalog"] = None) -> "STACObject": """Determines how to deserialize a dictionary into a STAC object. Args: @@ -23,7 +30,7 @@ def stac_object_from_dict(d: Dict[str, Any], href: Optional[str]=None, root: Opt Note: This is used internally in STAC_IO to deserialize STAC Objects. It is in the top level __init__ in order to avoid circular dependencies. """ - if identify_stac_object_type(d) == STACObjectType.ITEM: + if identify_stac_object_type(d) == ps.STACObjectType.ITEM: collection_cache = None if root is not None: collection_cache = root._resolved_objects.as_collection_cache() @@ -35,13 +42,13 @@ def stac_object_from_dict(d: Dict[str, Any], href: Optional[str]=None, root: Opt d, info = migrate_to_latest(d, info) - if info.object_type == STACObjectType.CATALOG: - return Catalog.from_dict(d, href=href, root=root) + if info.object_type == ps.STACObjectType.CATALOG: + return ps.Catalog.from_dict(d, href=href, root=root, migrate=False) - if info.object_type == STACObjectType.COLLECTION: - return Collection.from_dict(d, href=href, root=root) + if info.object_type == ps.STACObjectType.COLLECTION: + return ps.Collection.from_dict(d, href=href, root=root, migrate=False) - if info.object_type == STACObjectType.ITEM: - return Item.from_dict(d, href=href, root=root) + if info.object_type == ps.STACObjectType.ITEM: + return ps.Item.from_dict(d, href=href, root=root, migrate=False) raise ValueError(f"Unknown STAC object type {info.object_type}") diff --git a/pystac/serialization/common_properties.py b/pystac/serialization/common_properties.py index 4c3677084..05a249762 100644 --- a/pystac/serialization/common_properties.py +++ b/pystac/serialization/common_properties.py @@ -1,10 +1,11 @@ -from pystac.cache import CollectionCache -from typing import Any, Dict, Optional, Union, cast -from pystac import Collection +from typing import Any, Dict, Iterable, Optional, Union, cast from pystac.utils import make_absolute_href from pystac.stac_io import STAC_IO from pystac.serialization.identify import STACVersionID +import pystac as ps +from pystac.cache import CollectionCache + def merge_common_properties(item_dict: Dict[str, Any], collection_cache: Optional[CollectionCache] = None, @@ -24,7 +25,7 @@ def merge_common_properties(item_dict: Dict[str, Any], """ properties_merged = False - collection: Optional[Union[Collection, Dict[str, Any]]] = None + collection: Optional[Union[ps.Collection, Dict[str, Any]]] = None collection_id: Optional[str] = None collection_href: Optional[str] = None @@ -40,7 +41,7 @@ def merge_common_properties(item_dict: Dict[str, Any], # we don't have to merge. if stac_version is not None and stac_version == '0.9.0': stac_extensions = item_dict.get('stac_extensions') - if type(stac_extensions) is list: + if isinstance(stac_extensions, list): if 'commons' not in stac_extensions: return False else: @@ -54,11 +55,11 @@ def merge_common_properties(item_dict: Dict[str, Any], # Next, try the collection link. if collection is None: - links = item_dict['links'] - # Account for 0.5 links, which were dicts - if isinstance(links, Dict[str, Dict[str, Any]]): - links = list(links.values()) + if isinstance(item_dict['links'], dict): + links = list(cast(Iterable[Dict[str, Any]], item_dict['links'].values())) + else: + links = cast(Iterable[Dict[str, Any]], item_dict['links']) collection_link = next((link for link in links if link['rel'] == 'collection'), None) if collection_link is not None: @@ -72,13 +73,12 @@ def merge_common_properties(item_dict: Dict[str, Any], if collection is None: collection = STAC_IO.read_json(collection_href) - # TODO: Remove properties from Collection, it would be in extra_fields if collection is not None: collection_id = None collection_props: Optional[Dict[str, Any]] = None - if isinstance(collection, Collection): + if isinstance(collection, ps.Collection): collection_id = collection.id - collection_props = collection.properties + collection_props = collection.extra_fields.get("properties") elif isinstance(collection, dict): collection_id = collection['id'] if 'properties' in collection: diff --git a/pystac/serialization/identify.py b/pystac/serialization/identify.py index 641366725..8c4216830 100644 --- a/pystac/serialization/identify.py +++ b/pystac/serialization/identify.py @@ -1,7 +1,7 @@ from functools import total_ordering from typing import Any, Dict, List, Optional, Tuple, Union, cast -from pystac import STACObjectType +import pystac from pystac.version import STACVersion from pystac.extensions import Extensions @@ -132,7 +132,8 @@ def __repr__(self) -> str: ','.join(self.custom_extensions)) -def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range: STACVersionRange) -> List[str]: +def _identify_stac_extensions(object_type: str, d: Dict[str, Any], + version_range: STACVersionRange) -> List[str]: """Identifies extensions for STAC Objects that don't list their extensions in a 'stac_extensions' property. @@ -143,7 +144,7 @@ def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range # assets (collection assets) - if object_type == STACObjectType.ITEMCOLLECTION: + if object_type == pystac.STACObjectType.ITEMCOLLECTION: if 'assets' in d: stac_extensions.add('assets') version_range.set_min(STACVersionID('0.8.0')) @@ -152,7 +153,12 @@ def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range if 'links' in d: found_checksum = False for link in d['links']: - link_props = cast(Dict[str, Any], link).keys() + # Account for old links as dicts + if isinstance(link, str): + link_props = cast(Dict[str, Any], d['links'][link]).keys() + else: + link_props = cast(Dict[str, Any], link).keys() + if any(prop.startswith('checksum:') for prop in link_props): found_checksum = True stac_extensions.add(Extensions.CHECKSUM) @@ -167,19 +173,19 @@ def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range version_range.set_min(STACVersionID('0.6.2')) # datacube - if object_type == STACObjectType.ITEM: + if object_type == pystac.STACObjectType.ITEM: if any(k.startswith('cube:') for k in cast(Dict[str, Any], d['properties'])): stac_extensions.add(Extensions.DATACUBE) version_range.set_min(STACVersionID('0.6.1')) # datetime-range (old extension) - if object_type == STACObjectType.ITEM: + if object_type == pystac.STACObjectType.ITEM: if 'dtr:start_datetime' in d['properties']: stac_extensions.add('datetime-range') version_range.set_min(STACVersionID('0.6.0')) # eo - if object_type == STACObjectType.ITEM: + if object_type == pystac.STACObjectType.ITEM: if any(k.startswith('eo:') for k in cast(Dict[str, Any], d['properties'])): stac_extensions.add(Extensions.EO) if 'eo:epsg' in d['properties']: @@ -194,13 +200,13 @@ def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range version_range.set_max(STACVersionID('0.5.2')) # pointcloud - if object_type == STACObjectType.ITEM: + if object_type == pystac.STACObjectType.ITEM: if any(k.startswith('pc:') for k in cast(Dict[str, Any], d['properties'])): stac_extensions.add(Extensions.POINTCLOUD) version_range.set_min(STACVersionID('0.6.2')) # sar - if object_type == STACObjectType.ITEM: + if object_type == pystac.STACObjectType.ITEM: if any(k.startswith('sar:') for k in cast(Dict[str, Any], d['properties'])): stac_extensions.add(Extensions.SAR) version_range.set_min(STACVersionID('0.6.2')) @@ -227,7 +233,7 @@ def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range version_range.set_max(STACVersionID('0.6.2')) # scientific - if object_type == STACObjectType.ITEM or object_type == STACObjectType.COLLECTION: + if object_type == pystac.STACObjectType.ITEM or object_type == pystac.STACObjectType.COLLECTION: if 'properties' in d: prop_keys = cast(Dict[str, Any], d['properties']).keys() if any(k.startswith('sci:') for k in prop_keys): @@ -235,7 +241,7 @@ def _identify_stac_extensions(object_type: str, d: Dict[str, Any], version_range version_range.set_min(STACVersionID('0.6.0')) # Single File STAC - if object_type == STACObjectType.ITEMCOLLECTION: + if object_type == pystac.STACObjectType.ITEMCOLLECTION: if 'collections' in d: stac_extensions.add(Extensions.SINGLE_FILE_STAC) version_range.set_min(STACVersionID('0.8.0')) @@ -275,14 +281,14 @@ def identify_stac_object_type(json_dict: Dict[str, Any]): if 'type' in json_dict and 'assets' not in json_dict: if 'stac_version' in json_dict and cast(str, json_dict['stac_version']).startswith('0'): if json_dict['type'] == 'FeatureCollection': - object_type = STACObjectType.ITEMCOLLECTION + object_type = pystac.STACObjectType.ITEMCOLLECTION if 'extent' in json_dict: - object_type = STACObjectType.COLLECTION + object_type = pystac.STACObjectType.COLLECTION elif 'assets' in json_dict: - object_type = STACObjectType.ITEM + object_type = pystac.STACObjectType.ITEM else: - object_type = STACObjectType.CATALOG + object_type = pystac.STACObjectType.CATALOG return object_type @@ -305,9 +311,10 @@ def identify_stac_object(json_dict: Dict[str, Any]) -> STACJSONDescription: stac_extensions = json_dict.get('stac_extensions', None) if stac_version is None: - if object_type == STACObjectType.CATALOG or object_type == STACObjectType.COLLECTION: + if (object_type == pystac.STACObjectType.CATALOG + or object_type == pystac.STACObjectType.COLLECTION): version_range.set_max(STACVersionID('0.5.2')) - elif object_type == STACObjectType.ITEM: + elif object_type == pystac.STACObjectType.ITEM: version_range.set_max(STACVersionID('0.7.0')) else: # ItemCollection version_range.set_min(STACVersionID('0.8.0')) @@ -323,7 +330,7 @@ def identify_stac_object(json_dict: Dict[str, Any]) -> STACJSONDescription: # but ItemCollection (except after 0.9.0, when ItemCollection also got # the stac_extensions property). if version_range.is_earlier_than('0.8.0') or \ - (object_type == STACObjectType.ITEMCOLLECTION and not version_range.is_later_than( + (object_type == pystac.STACObjectType.ITEMCOLLECTION and not version_range.is_later_than( '0.8.1')): stac_extensions = _identify_stac_extensions(object_type, json_dict, version_range) else: diff --git a/pystac/serialization/migrate.py b/pystac/serialization/migrate.py index e08bea976..3442be738 100644 --- a/pystac/serialization/migrate.py +++ b/pystac/serialization/migrate.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Set, Tuple -from pystac import STACObjectType +import pystac as ps from pystac.version import STACVersion from pystac.extensions import Extensions from pystac.serialization.identify import (STACJSONDescription, STACVersionID, STACVersionRange) @@ -46,7 +46,7 @@ def _migrate_itemcollection(d: Dict[str, Any], version: STACVersionID, def _migrate_item_assets(d: Dict[str, Any], version: STACVersionID, info: STACJSONDescription) -> Optional[Set[str]]: if version < '1.0.0-beta.2': - if info.object_type == STACObjectType.COLLECTION: + if info.object_type == ps.STACObjectType.COLLECTION: if 'assets' in d: d['item_assets'] = d['assets'] del d['assets'] @@ -138,7 +138,7 @@ def _migrate_eo(d: Dict[str, Any], version: STACVersionID, d['properties']['eo:{}'.format(field)] del d['properties']['eo:{}'.format(field)] - if version < '1.0.0-beta.1' and info.object_type == STACObjectType.ITEM: + if version < '1.0.0-beta.1' and info.object_type == ps.STACObjectType.ITEM: # gsd moved from eo to common metadata if 'eo:gsd' in d['properties']: d['properties']['gsd'] = d['properties']['eo:gsd'] @@ -160,7 +160,7 @@ def _migrate_eo(d: Dict[str, Any], version: STACVersionID, def _migrate_label(d: Dict[str, Any], version: STACVersionID, info: STACJSONDescription) -> None: - if info.object_type == STACObjectType.ITEM and version < '1.0.0': + if info.object_type == ps.STACObjectType.ITEM and version < '1.0.0': props = d['properties'] # Migrate 0.8.0-rc1 non-pluralized forms # As it's a common mistake, convert for any pre-1.0.0 version. @@ -213,37 +213,43 @@ def _migrate_single_file_stac(d: Dict[str, Any], version: STACVersionID, pass -_object_migrations: Dict[str, - Callable[[Dict[str, Any], STACVersionID, STACJSONDescription], None]] = { - STACObjectType.CATALOG: _migrate_catalog, - STACObjectType.COLLECTION: _migrate_collection, - STACObjectType.ITEM: _migrate_item, - STACObjectType.ITEMCOLLECTION: _migrate_itemcollection - } - -_extension_migrations: Dict[str, - Callable[[Dict[str, Any], STACVersionID, STACJSONDescription], - Optional[Set[str]]]] = { - Extensions.CHECKSUM: _migrate_checksum, - Extensions.DATACUBE: _migrate_datacube, - Extensions.EO: _migrate_eo, - Extensions.ITEM_ASSETS: _migrate_item_assets, - Extensions.LABEL: _migrate_label, - Extensions.POINTCLOUD: _migrate_pointcloud, - Extensions.SAR: _migrate_sar, - Extensions.SCIENTIFIC: _migrate_scientific, - Extensions.SINGLE_FILE_STAC: _migrate_single_file_stac - } - -_removed_extension_migrations: Dict[str, Callable[ - [Dict[str, Any], STACVersionID, STACJSONDescription], Optional[Set[str]]]] = { +def _get_object_migrations( +) -> Dict[str, Callable[[Dict[str, Any], STACVersionID, STACJSONDescription], None]]: + return { + ps.STACObjectType.CATALOG: _migrate_catalog, + ps.STACObjectType.COLLECTION: _migrate_collection, + ps.STACObjectType.ITEM: _migrate_item, + ps.STACObjectType.ITEMCOLLECTION: _migrate_itemcollection + } + + +def _get_extension_migrations( +) -> Dict[str, Callable[[Dict[str, Any], STACVersionID, STACJSONDescription], Optional[Set[str]]]]: + return { + Extensions.CHECKSUM: _migrate_checksum, + Extensions.DATACUBE: _migrate_datacube, + Extensions.EO: _migrate_eo, + Extensions.ITEM_ASSETS: _migrate_item_assets, + Extensions.LABEL: _migrate_label, + Extensions.POINTCLOUD: _migrate_pointcloud, + Extensions.SAR: _migrate_sar, + Extensions.SCIENTIFIC: _migrate_scientific, + Extensions.SINGLE_FILE_STAC: _migrate_single_file_stac + } + + +def _get_removed_extension_migrations( +) -> Dict[str, Callable[[Dict[str, Any], STACVersionID, STACJSONDescription], Optional[Set[str]]]]: + return { # Removed in 0.9.0 'dtr': _migrate_datetime_range, 'datetime-range': _migrate_datetime_range, 'commons': lambda a, b, c: None # No changes needed, just remove the extension_id } -_extension_renames: Dict[str, str] = {'asset': 'item-assets'} + +def _get_extension_renames() -> Dict[str, str]: + return {'asset': 'item-assets'} def migrate_to_latest(json_dict: Dict[str, Any], @@ -263,23 +269,28 @@ def migrate_to_latest(json_dict: Dict[str, Any], result = deepcopy(json_dict) version = info.version_range.latest_valid_version() + object_migrations = _get_object_migrations() + extension_migrations = _get_extension_migrations() + extension_renames = _get_extension_renames() + removed_extension_migrations = _get_removed_extension_migrations() + if version != STACVersion.DEFAULT_STAC_VERSION: - _object_migrations[info.object_type](result, version, info) + object_migrations[info.object_type](result, version, info) extensions_to_add = set([]) for ext in info.common_extensions: - if ext in _extension_renames: + if ext in extension_renames: result['stac_extensions'].remove(ext) - ext = _extension_renames[ext] + ext = extension_renames[ext] extensions_to_add.add(ext) - if ext in _extension_migrations: - added_extensions = _extension_migrations[ext](result, version, info) + if ext in extension_migrations: + added_extensions = extension_migrations[ext](result, version, info) if added_extensions: extensions_to_add |= added_extensions - if ext in _removed_extension_migrations: - _removed_extension_migrations[ext](result, version, info) + if ext in removed_extension_migrations: + removed_extension_migrations[ext](result, version, info) result['stac_extensions'].remove(ext) for ext in extensions_to_add: @@ -287,8 +298,8 @@ def migrate_to_latest(json_dict: Dict[str, Any], migrated_extensions = set(info.common_extensions) migrated_extensions = migrated_extensions | set(extensions_to_add) - migrated_extensions = migrated_extensions - set(_removed_extension_migrations.keys()) - migrated_extensions = migrated_extensions - set(_extension_renames.keys()) + migrated_extensions = migrated_extensions - set(removed_extension_migrations.keys()) + migrated_extensions = migrated_extensions - set(extension_renames.keys()) common_extensions = list(migrated_extensions) else: common_extensions = info.common_extensions diff --git a/pystac/stac_io.py b/pystac/stac_io.py index 49956df6a..c213f0a0d 100644 --- a/pystac/stac_io.py +++ b/pystac/stac_io.py @@ -1,13 +1,15 @@ import os import json -from pystac.stac_object import STACObject -from pystac.catalog import Catalog -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING from urllib.parse import urlparse from urllib.request import urlopen from urllib.error import HTTPError +if TYPE_CHECKING: + from pystac.stac_object import STACObject + from pystac.catalog import Catalog + class STAC_IO: """Methods used to read and save STAC json. @@ -54,7 +56,8 @@ def default_write_text_method(uri: str, txt: str) -> None: """ # Replaced in __init__ to account for extension objects. - stac_object_from_dict: Optional[Callable[[Dict[str, Any], Optional[str], Optional[Catalog]], STACObject]] = None + stac_object_from_dict: Optional[Callable[[Dict[str, Any], Optional[str], Optional["Catalog"]], + "STACObject"]] = None # This is set in __init__.py _STAC_OBJECT_CLASSES = None @@ -113,7 +116,7 @@ def read_json(cls, uri: str) -> Dict[str, Any]: return json.loads(STAC_IO.read_text(uri)) @classmethod - def read_stac_object(cls, uri: str, root: Optional[Catalog]=None) -> STACObject: + def read_stac_object(cls, uri: str, root: Optional["Catalog"] = None) -> "STACObject": """Read a STACObject from a JSON file at the given URI. Args: diff --git a/pystac/stac_object.py b/pystac/stac_object.py index cc578e24d..ed0f8d9fb 100644 --- a/pystac/stac_object.py +++ b/pystac/stac_object.py @@ -1,16 +1,17 @@ from abc import (ABC, abstractmethod) from enum import Enum -from pystac.catalog import Catalog -from typing import Any, Dict, Generator, List, Optional, cast +from typing import Any, Dict, Generator, List, Optional, cast, TYPE_CHECKING -import pystac -import pystac.validation +import pystac as ps from pystac import STACError from pystac.link import Link from pystac.stac_io import STAC_IO from pystac.utils import (is_absolute_href, make_absolute_href) from pystac.extensions import ExtensionError -from pystac.extensions.base import STACObjectExtension + +if TYPE_CHECKING: + from pystac.catalog import Catalog as CatalogType + from pystac.extensions.base import STACObjectExtension as STACObjectExtensionType class STACObjectType(str, Enum): @@ -110,6 +111,20 @@ def get_root_link(self): """ return self.get_single_link('root') + @property + def self_href(self) -> str: + """Gets the absolute HREF that is represented by the ``rel == 'self'`` + :class:`~pystac.Link`. + + Raises: + ValueError: If the self_href is not set, this method will throw a ValueError. + Use get_self_href if there may not be an href set. + """ + result = self.get_self_href() + if result is None: + raise ValueError(f"{self} does not have a self_href set.") + return result + def get_self_href(self) -> Optional[str]: """Gets the absolute HREF that is represented by the ``rel == 'self'`` :class:`~pystac.Link`. @@ -143,14 +158,14 @@ def set_self_href(self, href: str) -> "LinkMixin": """ root_link = self.get_root_link() if root_link is not None and root_link.is_resolved(): - cast(Catalog, root_link.target)._resolved_objects.remove(cast(STACObject, self)) + cast(ps.Catalog, root_link.target)._resolved_objects.remove(cast(STACObject, self)) self.remove_links('self') if href is not None: self.add_link(Link.self_href(href)) if root_link is not None and root_link.is_resolved(): - cast(Catalog, root_link.target)._resolved_objects.cache(cast(STACObject, self)) + cast(ps.Catalog, root_link.target)._resolved_objects.cache(cast(STACObject, self)) return self @@ -173,15 +188,20 @@ def __init__(self, stac_extensions: List[str]): self.links = [] self.stac_extensions = stac_extensions - def validate(self): + def validate(self) -> List[Any]: """Validate this STACObject. + Returns a list of validation results, which depends on the validation + implementation. For JSON Schema validation, this will be a list + of schema URIs that were used during validation. + Raises: STACValidationError """ - return pystac.validation.validate(self) + import pystac.validation + return pystac.validation.validate(self) # type:ignore - def get_root(self) -> Optional[Catalog]: + def get_root(self) -> Optional["CatalogType"]: """Get the :class:`~pystac.Catalog` or :class:`~pystac.Collection` to the root for this object. The root is represented by a :class:`~pystac.Link` with ``rel == 'root'``. @@ -195,12 +215,12 @@ def get_root(self) -> Optional[Catalog]: if not root_link.is_resolved(): root_link.resolve_stac_object() # Use set_root, so Catalogs can merge ResolvedObjectCache instances. - self.set_root(cast(Catalog, root_link.target)) - return cast(Catalog, root_link.target) + self.set_root(cast(ps.Catalog, root_link.target)) + return cast("CatalogType", root_link.target) else: return None - def set_root(self, root: Optional[Catalog]) -> "STACObject": + def set_root(self, root: Optional["CatalogType"]) -> "STACObject": """Sets the root :class:`~pystac.Catalog` or :class:`~pystac.Collection` for this object. @@ -215,7 +235,7 @@ def set_root(self, root: Optional[Catalog]) -> "STACObject": if root_link_index is not None: root_link = self.links[root_link_index] if root_link.is_resolved(): - cast(Catalog, root_link.target)._resolved_objects.remove(self) + cast(ps.Catalog, root_link.target)._resolved_objects.remove(self) if root is None: self.remove_links('root') @@ -230,7 +250,7 @@ def set_root(self, root: Optional[Catalog]) -> "STACObject": return self - def get_parent(self) -> Optional["STACObject"]: + def get_parent(self) -> Optional["CatalogType"]: """Get the :class:`~pystac.Catalog` or :class:`~pystac.Collection` to the parent for this object. The root is represented by a :class:`~pystac.Link` with ``rel == 'parent'``. @@ -242,11 +262,11 @@ def get_parent(self) -> Optional["STACObject"]: """ parent_link = self.get_single_link('parent') if parent_link: - return cast(Catalog, parent_link.resolve_stac_object().target) + return cast(ps.Catalog, parent_link.resolve_stac_object().target) else: return None - def set_parent(self, parent: Optional[Catalog]) -> "STACObject": + def set_parent(self, parent: Optional["CatalogType"]) -> "STACObject": """Sets the parent :class:`~pystac.Catalog` or :class:`~pystac.Collection` for this object. @@ -306,8 +326,8 @@ def save_object(self, include_self_link: bool = True, dest_href: Optional[str] = STAC_IO.save_json(dest_href, self.to_dict(include_self_link=include_self_link)) def full_copy(self, - root: Optional[Catalog] = None, - parent: Optional[Catalog] = None) -> "STACObject": + root: Optional["CatalogType"] = None, + parent: Optional["CatalogType"] = None) -> "STACObject": """Create a full copy of this STAC object and any stac objects linked to by this object. @@ -323,10 +343,10 @@ def full_copy(self, """ clone = self.clone() - if root is None and isinstance(clone, Catalog): + if root is None and isinstance(clone, ps.Catalog): root = clone - clone.set_root(cast(Catalog, root)) + clone.set_root(cast(ps.Catalog, root)) if parent: clone.set_parent(parent) @@ -340,14 +360,14 @@ def full_copy(self, assert target is not None else: target_parent = None - if link.rel in ['child', 'item'] and isinstance(clone, Catalog): + if link.rel in ['child', 'item'] and isinstance(clone, ps.Catalog): target_parent = clone copied_target = target.full_copy(root=root, parent=target_parent) root._resolved_objects.cache(copied_target) target = copied_target if link.rel in ['child', 'item']: target.set_root(root) - if isinstance(clone, Catalog): + if isinstance(clone, ps.Catalog): target.set_parent(clone) link.target = target @@ -433,10 +453,7 @@ def from_file(cls, href: str) -> "STACObject": href = make_absolute_href(href) d = STAC_IO.read_json(href) - if cls == STACObject: - o = STAC_IO.stac_object_from_dict(d, href, None) - else: - o = cls.from_dict(d, href, None) + o = STAC_IO.stac_object_from_dict(d, href, None) # Set the self HREF, if it's not already set to something else. if o.get_self_href() is None: @@ -447,7 +464,7 @@ def from_file(cls, href: str) -> "STACObject": if root_link is not None: if not root_link.is_resolved(): if root_link.get_absolute_href() == href: - o.set_root(cast(Catalog, o)) + o.set_root(cast(ps.Catalog, o)) return o @classmethod @@ -455,7 +472,8 @@ def from_file(cls, href: str) -> "STACObject": def from_dict(cls, d: Dict[str, Any], href: Optional[str] = None, - root: Optional[Catalog] = None) -> "STACObject": + root: Optional["CatalogType"] = None, + migrate: bool = False) -> "STACObject": """Parses this STACObject from the passed in dictionary. Args: @@ -465,6 +483,8 @@ def from_dict(cls, root (Catalog or Collection): Optional root of the catalog for this object. If provided, the root's resolved object cache can be used to search for previously resolved instances of the STAC object. + migrate: Use True if this dict represents JSON from an older STAC object, + so that migrations are run against it. Returns: STACObject: The STACObject parsed from this dict. @@ -484,8 +504,7 @@ class ExtensionIndex: def __init__(self, stac_object: STACObject) -> None: self.stac_object = stac_object - def __getitem__( - self, extension_id: str) -> STACObjectExtension: + def __getitem__(self, extension_id: str) -> "STACObjectExtensionType": """Gets the extension object for the given extension. Returns: @@ -494,7 +513,7 @@ def __getitem__( by the extension_id. """ # Check to make sure this is a registered extension. - if not pystac.STAC_EXTENSIONS.is_registered_extension(extension_id): + if not ps.STAC_EXTENSIONS.is_registered_extension(extension_id): raise ExtensionError("'{}' is not an extension " "registered with PySTAC".format(extension_id)) @@ -503,10 +522,9 @@ def __getitem__( "Use the 'ext.enable' method to enable this extension " "first.".format(self.stac_object, extension_id)) - return pystac.STAC_EXTENSIONS.extend_object(extension_id, self.stac_object) + return ps.STAC_EXTENSIONS.extend_object(extension_id, self.stac_object) - def __getattr__( - self, extension_id: str) -> STACObjectExtension: + def __getattr__(self, extension_id: str) -> "STACObjectExtensionType": """Gets an extension based on a dynamic attribute. This takes the attribute name and passes it to __getitem__. @@ -530,7 +548,7 @@ def enable(self, extension_id: str) -> None: the object should implement """ - pystac.STAC_EXTENSIONS.enable_extension(extension_id, self.stac_object) + ps.STAC_EXTENSIONS.enable_extension(extension_id, self.stac_object) def implements(self, extension_id: str) -> bool: """Returns true if the associated object implements the given extension. diff --git a/pystac/utils.py b/pystac/utils.py index 4dc627972..f0b30f709 100644 --- a/pystac/utils.py +++ b/pystac/utils.py @@ -40,7 +40,7 @@ def _join(is_path: bool, *args: str) -> str: return posixpath.join(*args) -def make_relative_href(source_href: str, start_href: str, start_is_dir: bool=False) -> str: +def make_relative_href(source_href: str, start_href: str, start_is_dir: bool = False) -> str: """Makes a given HREF relative to the given starting HREF. Args: @@ -94,9 +94,6 @@ def make_absolute_href(source_href: str, then it will be returned unchanged. If the source_href it None, it will return None. """ - if source_href is None: - return None # TODO: Remove the None case - if start_href is None: start_href = os.getcwd() start_is_dir = True diff --git a/pystac/validation/__init__.py b/pystac/validation/__init__.py index 4d4e80215..3f504ad2f 100644 --- a/pystac/validation/__init__.py +++ b/pystac/validation/__init__.py @@ -1,10 +1,12 @@ # flake8: noqa -from typing import Dict, List, Any, Optional, cast -from pystac.stac_object import STACObject +from typing import Dict, List, Any, Optional, cast, TYPE_CHECKING import pystac from pystac.serialization.identify import identify_stac_object from pystac.utils import make_absolute_href +if TYPE_CHECKING: + from pystac.stac_object import STACObject + class STACValidationError(Exception): """Represents a validation error. Thrown by validation calls if the STAC JSON @@ -15,7 +17,7 @@ class STACValidationError(Exception): validation implementation. For the default JsonSchemaValidator this will a the ``jsonschema.ValidationError``. """ - def __init__(self, message: str, source: Optional[Any]=None): + def __init__(self, message: str, source: Optional[Any] = None): super().__init__(message) self.source = source @@ -24,7 +26,7 @@ def __init__(self, message: str, source: Optional[Any]=None): from pystac.validation.stac_validator import (STACValidator, JsonSchemaSTACValidator) -def validate(stac_object: STACObject) -> List[Any]: +def validate(stac_object: "STACObject") -> List[Any]: """Validates a :class:`~pystac.STACObject`. Args: @@ -39,13 +41,17 @@ def validate(stac_object: STACObject) -> List[Any]: STACValidationError """ return validate_dict(stac_dict=stac_object.to_dict(), - stac_object_type=stac_object.STAC_OBJECT_TYPE, - stac_version=pystac.get_stac_version(), - extensions=stac_object.stac_extensions, - href=stac_object.get_self_href()) + stac_object_type=stac_object.STAC_OBJECT_TYPE, + stac_version=pystac.get_stac_version(), + extensions=stac_object.stac_extensions, + href=stac_object.get_self_href()) -def validate_dict(stac_dict: Dict[str, Any], stac_object_type: Optional[str]=None, stac_version: Optional[str]=None, extensions: Optional[List[str]]=None, href: Optional[str]=None) -> List[Any]: +def validate_dict(stac_dict: Dict[str, Any], + stac_object_type: Optional[str] = None, + stac_version: Optional[str] = None, + extensions: Optional[List[str]] = None, + href: Optional[str] = None) -> List[Any]: """Validate a stac object serialized as JSON into a dict. This method delegates to the call to :meth:`pystac.validation.STACValidator.validate` @@ -116,7 +122,7 @@ def validate_all(stac_dict: Dict[str, Any], href: str) -> None: if info.object_type != pystac.STACObjectType.ITEM: if 'links' in stac_dict: # Account for 0.6 links - if isinstance(stac_dict['links'], Dict[str, Dict[str, Any]]): + if isinstance(stac_dict['links'], dict): links: List[Dict[str, Any]] = list(stac_dict['links'].values()) else: links: List[Dict[str, Any]] = cast(List[Dict[str, Any]], stac_dict.get('links')) diff --git a/pystac/validation/schema_uri_map.py b/pystac/validation/schema_uri_map.py index 13a247c27..c9f9a7d1f 100644 --- a/pystac/validation/schema_uri_map.py +++ b/pystac/validation/schema_uri_map.py @@ -193,7 +193,8 @@ def get_core_schema_uri(self, object_type: STACObjectType, stac_version: str): return self._append_base_uri_if_needed(uri, stac_version) - def get_extension_schema_uri(self, extension_id: str, object_type: STACObjectType, stac_version: str): + def get_extension_schema_uri(self, extension_id: str, object_type: STACObjectType, + stac_version: str): uri = None is_latest = stac_version == pystac.get_stac_version() diff --git a/pystac/validation/stac_validator.py b/pystac/validation/stac_validator.py index dfdb7e43c..2106040fd 100644 --- a/pystac/validation/stac_validator.py +++ b/pystac/validation/stac_validator.py @@ -201,7 +201,7 @@ def validate_extension(self, stac_object_type: STACObjectType, stac_version: str, extension_id: str, - href: Optional[str]=None): + href: Optional[str] = None): """Validate an extension stac object. Return value can be None or specific to the implementation. diff --git a/tests/data-files/examples/example-info.csv b/tests/data-files/examples/example-info.csv index 6c26c82e5..21be594f0 100644 --- a/tests/data-files/examples/example-info.csv +++ b/tests/data-files/examples/example-info.csv @@ -1,8 +1,138 @@ -"1.0.0-RC1/catalog.json","CATALOG","1.0.0-beta.2","","" -"1.0.0-RC1/collection-only/collection.json","COLLECTION","1.0.0-beta.2","","" -"1.0.0-RC1/collection.json","COLLECTION","1.0.0-beta.2","","" -"1.0.0-RC1/collectionless-item.json","COLLECTION","1.0.0-beta.2","eo|view","" -"1.0.0-RC1/core-item.json","ITEM","1.0.0-beta.2","","" -"1.0.0-RC1/extended-item.json","ITEM","1.0.0-beta.2","eo|projection|scientific|view","" -"1.0.0-RC1/extensions-collection/collection.json","COLLECTION","1.0.0-beta.2","","" -"1.0.0-RC1/extensions-collection/proj-example/proj-example.json","COLLECTION","1.0.0-beta.2","eo|projection","" \ No newline at end of file +"0.4.1/extensions/examples/landsat8-merged.json","ITEM","0.4.1","eo","" +"0.5.2/extensions/examples/landsat8-merged.json","ITEM","0.5.2","eo","" +"0.7.0/extensions/sar/examples/sentinel1.json","ITEM","0.7.0","sar|datetime-range|checksum","" +"0.8.1/catalog-spec/examples/catalog.json","CATALOG","0.8.1","","" +"0.8.1/catalog-spec/examples/summaries-s2.json","CATALOG","0.8.1","","" +"0.8.1/collection-spec/examples/landsat-collection.json","COLLECTION","0.8.1","","" +"0.8.1/collection-spec/examples/landsat-item.json","ITEM","0.8.1","eo","" +"0.8.1/collection-spec/examples/sentinel2.json","COLLECTION","0.8.1","","" +"0.8.1/extensions/asset/examples/example-landsat8.json","COLLECTION","0.8.1","asset","" +"0.8.1/extensions/checksum/examples/example-sentinel1.json","ITEM","0.8.1","checksum","" +"0.8.1/extensions/datacube/examples/example.json","ITEM","0.8.1","datacube","" +"0.8.1/extensions/datetime-range/examples/example-video.json","ITEM","0.8.1","datetime-range","" +"0.8.1/extensions/eo/examples/example-landsat8.json","ITEM","0.8.1","eo","https://example.com/stac/landsat-extension/1.0/schema.json" +"0.8.1/extensions/label/examples/multidataset/catalog.json","CATALOG","0.8.1","","" +"0.8.1/extensions/label/examples/multidataset/spacenet-buildings/AOI_2_Vegas_img2636.json","ITEM","0.8.1","label","" +"0.8.1/extensions/label/examples/multidataset/spacenet-buildings/AOI_3_Paris_img1648.json","ITEM","0.8.1","label","" +"0.8.1/extensions/label/examples/multidataset/spacenet-buildings/AOI_4_Shanghai_img3344.json","ITEM","0.8.1","label","" +"0.8.1/extensions/label/examples/multidataset/spacenet-buildings/collection.json","COLLECTION","0.8.1","","" +"0.8.1/extensions/label/examples/multidataset/zanzibar/collection.json","COLLECTION","0.8.1","","" +"0.8.1/extensions/label/examples/multidataset/zanzibar/znz001.json","ITEM","0.8.1","label","" +"0.8.1/extensions/label/examples/multidataset/zanzibar/znz029.json","ITEM","0.8.1","label","" +"0.8.1/extensions/label/examples/spacenet-roads/roads_collection.json","COLLECTION","0.8.1","","" +"0.8.1/extensions/label/examples/spacenet-roads/roads_item.json","ITEM","0.8.1","label","" +"0.8.1/extensions/label/examples/spacenet-roads/roads_source.json","ITEM","0.8.1","","" +"0.8.1/extensions/pointcloud/examples/example-autzen.json","ITEM","0.8.1","pointcloud","" +"0.8.1/extensions/sar/examples/envisat.json","ITEM","0.8.1","sar|datetime-range","" +"0.8.1/extensions/sar/examples/sentinel1.json","ITEM","0.8.1","checksum|sar|datetime-range","" +"0.8.1/extensions/scientific/examples/collection.json","COLLECTION","0.8.1","scientific","" +"0.8.1/extensions/scientific/examples/item.json","ITEM","0.8.1","datetime-range|checksum|scientific","" +"0.8.1/item-spec/examples/digitalglobe-sample.json","ITEM","0.8.1","eo","https://example.digitalglobe.com/stac/1.0/schema.json" +"0.8.1/item-spec/examples/landsat8-sample.json","ITEM","0.8.1","eo","https://example.com/stac/landsat-extension/1.0/schema.json" +"0.8.1/item-spec/examples/planet-sample.json","ITEM","0.8.1","eo","https://example.planet.com/stac/1.0/schema.json" +"0.8.1/item-spec/examples/sample-full.json","ITEM","0.8.1","eo","https://example.com/cs-extension/1.0/schema.json" +"0.8.1/item-spec/examples/sample.json","ITEM","0.8.1","","" +"0.8.1/item-spec/examples/sentinel2-sample.json","ITEM","0.8.1","eo","" +"0.9.0/catalog-spec/examples/catalog.json","CATALOG","0.9.0","","" +"0.9.0/collection-spec/examples/landsat-collection.json","COLLECTION","0.9.0","commons|view|eo","" +"0.9.0/collection-spec/examples/landsat-item.json","ITEM","0.9.0","commons|eo|view","https://example.com/stac/landsat-extension/1.0/schema.json","INVALID" +"0.9.0/collection-spec/examples/sentinel2.json","COLLECTION","0.9.0","","" +"0.9.0/extensions/asset/examples/example-landsat8.json","COLLECTION","0.9.0","asset|commons","" +"0.9.0/extensions/checksum/examples/sentinel1.json","ITEM","0.9.0","checksum","" +"0.9.0/extensions/commons/examples/landsat-collection.json","COLLECTION","0.9.0","commons","" +"0.9.0/extensions/commons/examples/landsat-item.json","ITEM","0.9.0","commons|eo|sat","https://example.com/stac/landsat-extension/1.0/schema.json" +"0.9.0/extensions/datacube/examples/example-collection.json","COLLECTION","0.9.0","datacube","" +"0.9.0/extensions/datacube/examples/example-item.json","ITEM","0.9.0","datacube","" +"0.9.0/extensions/eo/examples/example-landsat8.json","ITEM","0.9.0","eo|view|commons","https://example.com/stac/landsat-extension/1.0/schema.json" +"0.9.0/extensions/label/examples/multidataset/catalog.json","CATALOG","0.9.0","","" +"0.9.0/extensions/label/examples/multidataset/spacenet-buildings/AOI_2_Vegas_img2636.json","ITEM","0.9.0","label|version","","INVALID" +"0.9.0/extensions/label/examples/multidataset/spacenet-buildings/AOI_3_Paris_img1648.json","ITEM","0.9.0","label|version","","INVALID" +"0.9.0/extensions/label/examples/multidataset/spacenet-buildings/AOI_4_Shanghai_img3344.json","ITEM","0.9.0","label|version","","INVALID" +"0.9.0/extensions/label/examples/multidataset/spacenet-buildings/collection.json","COLLECTION","0.9.0","","" +"0.9.0/extensions/label/examples/multidataset/zanzibar/collection.json","COLLECTION","0.9.0","","" +"0.9.0/extensions/label/examples/multidataset/zanzibar/znz001.json","ITEM","0.9.0","label|version","","INVALID" +"0.9.0/extensions/label/examples/spacenet-roads/roads_collection.json","COLLECTION","0.9.0","","" +"0.9.0/extensions/label/examples/spacenet-roads/roads_item.json","ITEM","0.9.0","label|version","" +"0.9.0/extensions/label/examples/spacenet-roads/roads_source.json","ITEM","0.9.0","","" +"0.9.0/extensions/pointcloud/examples/example-autzen.json","ITEM","0.9.0","pointcloud","" +"0.9.0/extensions/projection/examples/example-landsat8.json","ITEM","0.9.0","proj|commons","" +"0.9.0/extensions/sar/examples/envisat.json","ITEM","0.9.0","sat|sar","" +"0.9.0/extensions/sar/examples/sentinel1.json","ITEM","0.9.0","checksum|sar|sat","" +"0.9.0/extensions/sat/examples/example-landsat8.json","ITEM","0.9.0","sat|view","" +"0.9.0/extensions/scientific/examples/collection.json","COLLECTION","0.9.0","scientific","" +"0.9.0/extensions/scientific/examples/item.json","ITEM","0.9.0","scientific|checksum","" +"0.9.0/extensions/version/examples/collection.json","COLLECTION","0.9.0","version","" +"0.9.0/extensions/version/examples/item.json","ITEM","0.9.0","version","","INVALID" +"0.9.0/extensions/view/examples/example-landsat8.json","ITEM","0.9.0","sat|view","" +"0.9.0/item-spec/examples/datetimerange.json","ITEM","0.9.0","","" +"0.9.0/item-spec/examples/digitalglobe-sample.json","ITEM","0.9.0","eo|proj|view","https://example.digitalglobe.com/stac/1.0/schema.json" +"0.9.0/item-spec/examples/landsat8-sample.json","ITEM","0.9.0","eo|view","https://example.com/stac/landsat-extension/1.0/schema.json" +"0.9.0/item-spec/examples/planet-sample.json","ITEM","0.9.0","eo|view","https://example.planet.com/stac/1.0/schema.json" +"0.9.0/item-spec/examples/sample-full.json","ITEM","0.9.0","eo|view","https://example.com/cs-extension/1.0/schema.json" +"0.9.0/item-spec/examples/sample.json","ITEM","0.9.0","","" +"0.9.0/item-spec/examples/sentinel2-sample.json","ITEM","0.9.0","eo|view|proj|commons","" +"1.0.0-beta.2/catalog-spec/examples/catalog-items.json","CATALOG","1.0.0-beta.2","","" +"1.0.0-beta.2/catalog-spec/examples/catalog.json","CATALOG","1.0.0-beta.2","","" +"1.0.0-beta.2/collection-spec/examples/landsat-collection.json","COLLECTION","1.0.0-beta.2","","" +"1.0.0-beta.2/collection-spec/examples/sentinel2.json","COLLECTION","1.0.0-beta.2","","" +"1.0.0-beta.2/extensions/checksum/examples/sentinel1.json","ITEM","1.0.0-beta.2","checksum","" +"1.0.0-beta.2/extensions/collection-assets/examples/example-esm.json","COLLECTION","1.0.0-beta.2","collection-assets","https://github.com/NCAR/esm-collection-spec/tree/v0.2.0/schema.json" +"1.0.0-beta.2/extensions/datacube/examples/example-collection.json","COLLECTION","1.0.0-beta.2","datacube","" +"1.0.0-beta.2/extensions/datacube/examples/example-item.json","ITEM","1.0.0-beta.2","datacube","" +"1.0.0-beta.2/extensions/eo/examples/example-landsat8.json","ITEM","1.0.0-beta.2","eo|view","https://example.com/stac/landsat-extension/1.0/schema.json" +"1.0.0-beta.2/extensions/item-assets/examples/example-landsat8.json","COLLECTION","1.0.0-beta.2","item-assets","" +"1.0.0-beta.2/extensions/label/examples/multidataset/catalog.json","CATALOG","1.0.0-beta.2","","" +"1.0.0-beta.2/extensions/label/examples/multidataset/spacenet-buildings/AOI_2_Vegas_img2636.json","ITEM","1.0.0-beta.2","label|version","" +"1.0.0-beta.2/extensions/label/examples/multidataset/spacenet-buildings/AOI_3_Paris_img1648.json","ITEM","1.0.0-beta.2","label|version","" +"1.0.0-beta.2/extensions/label/examples/multidataset/spacenet-buildings/AOI_4_Shanghai_img3344.json","ITEM","1.0.0-beta.2","label|version","" +"1.0.0-beta.2/extensions/label/examples/multidataset/spacenet-buildings/collection.json","COLLECTION","1.0.0-beta.2","","" +"1.0.0-beta.2/extensions/label/examples/multidataset/zanzibar/collection.json","COLLECTION","1.0.0-beta.2","","" +"1.0.0-beta.2/extensions/label/examples/multidataset/zanzibar/znz001.json","ITEM","1.0.0-beta.2","label|version","" +"1.0.0-beta.2/extensions/label/examples/multidataset/zanzibar/znz029.json","ITEM","1.0.0-beta.2","label|version","" +"1.0.0-beta.2/extensions/label/examples/spacenet-roads/roads_collection.json","COLLECTION","1.0.0-beta.2","","" +"1.0.0-beta.2/extensions/label/examples/spacenet-roads/roads_item.json","ITEM","1.0.0-beta.2","label|version","" +"1.0.0-beta.2/extensions/label/examples/spacenet-roads/roads_source.json","ITEM","1.0.0-beta.2","","" +"1.0.0-beta.2/extensions/pointcloud/examples/example-autzen.json","ITEM","1.0.0-beta.2","pointcloud","" +"1.0.0-beta.2/extensions/projection/examples/example-landsat8.json","ITEM","1.0.0-beta.2","eo|projection","" +"1.0.0-beta.2/extensions/sar/examples/envisat.json","ITEM","1.0.0-beta.2","sat|sar","" +"1.0.0-beta.2/extensions/sar/examples/sentinel1.json","ITEM","1.0.0-beta.2","checksum|sar|sat","" +"1.0.0-beta.2/extensions/sat/examples/example-landsat8.json","ITEM","1.0.0-beta.2","sat|view","" +"1.0.0-beta.2/extensions/scientific/examples/collection.json","COLLECTION","1.0.0-beta.2","scientific","" +"1.0.0-beta.2/extensions/scientific/examples/item.json","ITEM","1.0.0-beta.2","scientific|checksum","" +"1.0.0-beta.2/extensions/tiled-assets/examples/example-dimension.json","ITEM","1.0.0-beta.2","datacube|eo|tiled-assets","" +"1.0.0-beta.2/extensions/tiled-assets/examples/example-tiled.json","ITEM","1.0.0-beta.2","eo|tiled-assets","" +"1.0.0-beta.2/extensions/timestamps/examples/example-landsat8.json","ITEM","1.0.0-beta.2","timestamps","" +"1.0.0-beta.2/extensions/version/examples/collection.json","COLLECTION","1.0.0-beta.2","version","" +"1.0.0-beta.2/extensions/version/examples/item.json","ITEM","1.0.0-beta.2","version","" +"1.0.0-beta.2/extensions/view/examples/example-landsat8.json","ITEM","1.0.0-beta.2","sat|view","" +"1.0.0-beta.2/item-spec/examples/CBERS_4_MUX_20181029_177_106_L4.json","ITEM","1.0.0-beta.2","projection|view","https://example.com/stac/cbers-extension/1.0/schema.json" +"1.0.0-beta.2/item-spec/examples/datetimerange.json","ITEM","1.0.0-beta.2","","" +"1.0.0-beta.2/item-spec/examples/digitalglobe-sample.json","ITEM","1.0.0-beta.2","eo|projection|view","https://example.digitalglobe.com/stac/1.0/schema.json" +"1.0.0-beta.2/item-spec/examples/landsat8-sample.json","ITEM","1.0.0-beta.2","eo|view","https://example.com/stac/landsat-extension/1.0/schema.json" +"1.0.0-beta.2/item-spec/examples/planet-sample.json","ITEM","1.0.0-beta.2","eo|view","https://example.planet.com/stac/1.0/schema.json" +"1.0.0-beta.2/item-spec/examples/sample-full.json","ITEM","1.0.0-beta.2","eo|view","https://example.com/cs-extension/1.0/schema.json" +"1.0.0-beta.2/item-spec/examples/sample.json","ITEM","1.0.0-beta.2","","" +"1.0.0-beta.2/item-spec/examples/sentinel2-sample.json","ITEM","1.0.0-beta.2","view|projection","" +"gee-0.6.2/CIESIN_GPWv411_GPW_National_Identifier_Grid.json","COLLECTION","0.6.2","scientific","" +"gee-0.6.2/LANDSAT_LT05_C01_T1_ANNUAL_NDWI.json","COLLECTION","0.6.2","","" +"gee-0.6.2/catalog.json","CATALOG","0.6.2","","" +"iserv-0.6.1/2013/03/27/IP0201303271418280967S05834W.json","ITEM","0.6.1","eo","" +"iserv-0.6.1/2013/03/27/catalog.json","CATALOG","0.6.1","","" +"iserv-0.6.1/2013/03/catalog.json","CATALOG","0.6.1","","" +"iserv-0.6.1/2013/catalog.json","CATALOG","0.6.1","","" +"iserv-0.6.1/catalog.json","COLLECTION","0.6.1","","" +"landsat-0.6.0/010/117/2015-01-02/LC80101172015002LGN00.json","ITEM","0.6.0","eo","" +"landsat-0.6.0/010/117/catalog.json","CATALOG","0.6.0","","" +"landsat-0.6.0/010/catalog.json","COLLECTION","0.6.0","","" +"landsat-0.6.0/156/029/2015-01-01/LC81560292015001LGN00.json","ITEM","0.6.0","eo","" +"landsat-0.6.0/156/029/catalog.json","CATALOG","0.6.0","","" +"landsat-0.6.0/156/catalog.json","COLLECTION","0.6.0","","" +"landsat-0.6.0/catalog.json","CATALOG","0.6.0","","" +"sentinel-0.6.0/catalog.json","CATALOG","0.6.0","","" +"sentinel-0.6.0/sentinel-2-l1c/9/V/XK/2017-10-13/S2B_9VXK_20171013_0.json","ITEM","0.6.0","eo","" +"sentinel-0.6.0/sentinel-2-l1c/9/V/XK/catalog.json","CATALOG","0.6.0","","" +"sentinel-0.6.0/sentinel-2-l1c/9/V/catalog.json","CATALOG","0.6.0","","" +"sentinel-0.6.0/sentinel-2-l1c/9/catalog.json","CATALOG","0.6.0","","" +"sentinel-0.6.0/sentinel-2-l1c/catalog.json","COLLECTION","0.6.0","","" +"hand-0.9.0/collection.json","COLLECTION","0.9.0","","" +"hand-0.8.1/collection.json","COLLECTION","0.8.1","","" \ No newline at end of file diff --git a/tests/extensions/test_label.py b/tests/extensions/test_label.py index 30603b3c1..b3a007f2a 100644 --- a/tests/extensions/test_label.py +++ b/tests/extensions/test_label.py @@ -3,9 +3,10 @@ import unittest from tempfile import TemporaryDirectory -import pystac +import pystac as ps from pystac import (Catalog, Item, CatalogType, STAC_IO) from pystac.extensions import label +import pystac.validation from tests.utils import (TestCases, test_to_from_dict) @@ -44,7 +45,7 @@ def test_from_file_pre_081(self): d['properties'].pop('label:methods') d['properties']['label:task'] = d['properties']['label:tasks'] d['properties'].pop('label:tasks') - label_example_1 = STAC_IO.stac_object_from_dict(d) + label_example_1 = ps.Item.from_dict(d, migrate=True) self.assertEqual(len(label_example_1.ext.label.label_tasks), 2) @@ -81,7 +82,7 @@ def test_read_label_item_owns_asset(self): self.assertEqual(item.assets[asset_key].owner, item) def test_label_description(self): - label_item = pystac.read_file(self.label_example_1_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) # Get self.assertIn("label:description", label_item.properties) @@ -94,7 +95,7 @@ def test_label_description(self): label_item.validate() def test_label_type(self): - label_item = pystac.read_file(self.label_example_1_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) # Get self.assertIn("label:type", label_item.properties) @@ -107,8 +108,8 @@ def test_label_type(self): label_item.validate() def test_label_properties(self): - label_item = pystac.read_file(self.label_example_1_uri) - label_item2 = pystac.read_file(self.label_example_2_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) + label_item2 = ps.Item.from_file(self.label_example_2_uri) # Get self.assertIn("label:properties", label_item.properties) @@ -124,7 +125,7 @@ def test_label_properties(self): def test_label_classes(self): # Get - label_item = pystac.read_file(self.label_example_1_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) label_classes = label_item.ext.label.label_classes self.assertEqual(len(label_classes), 2) @@ -145,7 +146,7 @@ def test_label_classes(self): label_item.validate() def test_label_tasks(self): - label_item = pystac.read_file(self.label_example_1_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) # Get self.assertIn("label:tasks", label_item.properties) @@ -158,7 +159,7 @@ def test_label_tasks(self): label_item.validate() def test_label_methods(self): - label_item = pystac.read_file(self.label_example_1_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) # Get self.assertIn("label:methods", label_item.properties) @@ -172,10 +173,10 @@ def test_label_methods(self): def test_label_overviews(self): # Get - label_item = pystac.read_file(self.label_example_1_uri) + label_item = ps.Item.from_file(self.label_example_1_uri) label_overviews = label_item.ext.label.label_overviews - label_item2 = pystac.read_file(self.label_example_2_uri) + label_item2 = ps.Item.from_file(self.label_example_2_uri) label_overviews2 = label_item2.ext.label.label_overviews self.assertEqual(len(label_overviews), 2) diff --git a/tests/serialization/test_identify.py b/tests/serialization/test_identify.py index e854d90b6..d1a61279d 100644 --- a/tests/serialization/test_identify.py +++ b/tests/serialization/test_identify.py @@ -1,10 +1,11 @@ import unittest from urllib.error import HTTPError +import pystac as ps from pystac import STAC_IO from pystac.cache import CollectionCache from pystac.serialization import (identify_stac_object, identify_stac_object_type, - merge_common_properties, STACObjectType) + merge_common_properties) from pystac.serialization.identify import (STACVersionRange, STACVersionID) from tests.utils import TestCases @@ -17,30 +18,33 @@ def setUp(self): def test_identify(self): collection_cache = CollectionCache() for example in self.examples: - path = example['path'] - d = STAC_IO.read_json(path) - if identify_stac_object_type(d) == STACObjectType.ITEM: - try: - merge_common_properties(d, json_href=path, collection_cache=collection_cache) - except HTTPError: - pass + with self.subTest(example['path']): + path = example['path'] + d = STAC_IO.read_json(path) + if identify_stac_object_type(d) == ps.STACObjectType.ITEM: + try: + merge_common_properties(d, + json_href=path, + collection_cache=collection_cache) + except HTTPError: + pass - actual = identify_stac_object(d) - # Explicitly cover __repr__ functions in tests - str_info = str(actual) - self.assertIsInstance(str_info, str) + actual = identify_stac_object(d) + # Explicitly cover __repr__ functions in tests + str_info = str(actual) + self.assertIsInstance(str_info, str) - msg = 'Failed {}:'.format(path) + msg = 'Failed {}:'.format(path) - self.assertEqual(actual.object_type, example['object_type'], msg=msg) - version_contained_in_range = actual.version_range.contains(example['stac_version']) - self.assertTrue(version_contained_in_range, msg=msg) - self.assertEqual(set(actual.common_extensions), - set(example['common_extensions']), - msg=msg) - self.assertEqual(set(actual.custom_extensions), - set(example['custom_extensions']), - msg=msg) + self.assertEqual(actual.object_type, example['object_type'], msg=msg) + version_contained_in_range = actual.version_range.contains(example['stac_version']) + self.assertTrue(version_contained_in_range, msg=msg) + self.assertEqual(set(actual.common_extensions), + set(example['common_extensions']), + msg=msg) + self.assertEqual(set(actual.custom_extensions), + set(example['custom_extensions']), + msg=msg) class VersionTest(unittest.TestCase): @@ -50,10 +54,11 @@ def test_version_ordering(self): self.assertFalse(STACVersionID('0.9.0') != STACVersionID('0.9.0')) self.assertFalse(STACVersionID('0.9.0') > STACVersionID('0.9.0')) self.assertTrue(STACVersionID('1.0.0-beta.2') < '1.0.0') - self.assertTrue(STACVersionID('0.9.1') > '0.9.0') - self.assertFalse(STACVersionID('0.9.0') > '0.9.0') - self.assertTrue(STACVersionID('0.9.0') <= '0.9.0') - self.assertTrue(STACVersionID('1.0.0-beta.1') <= STACVersionID('1.0.0-beta.2')) + self.assertTrue(STACVersionID('0.9.1') > '0.9.0') # type:ignore + self.assertFalse(STACVersionID('0.9.0') > '0.9.0') # type:ignore + self.assertTrue(STACVersionID('0.9.0') <= '0.9.0') # type:ignore + self.assertTrue( + STACVersionID('1.0.0-beta.1') <= STACVersionID('1.0.0-beta.2')) # type:ignore self.assertFalse(STACVersionID('1.0.0') < STACVersionID('1.0.0-beta.2')) def test_version_range_ordering(self): diff --git a/tests/serialization/test_migrate.py b/tests/serialization/test_migrate.py index 56f9aa05d..faffc2d17 100644 --- a/tests/serialization/test_migrate.py +++ b/tests/serialization/test_migrate.py @@ -1,10 +1,10 @@ import unittest -import pystac +import pystac as ps from pystac import (STAC_IO, STACObject) from pystac.cache import CollectionCache from pystac.serialization import (identify_stac_object, identify_stac_object_type, - merge_common_properties, migrate_to_latest, STACObjectType) + merge_common_properties, migrate_to_latest) from pystac.utils import str_to_datetime from tests.utils import TestCases @@ -21,7 +21,7 @@ def test_migrate(self): path = example['path'] d = STAC_IO.read_json(path) - if identify_stac_object_type(d) == STACObjectType.ITEM: + if identify_stac_object_type(d) == ps.STACObjectType.ITEM: merge_common_properties(d, json_href=path, collection_cache=collection_cache) info = identify_stac_object(d) @@ -32,17 +32,16 @@ def test_migrate(self): self.assertEqual(migrated_info.object_type, info.object_type) self.assertEqual(migrated_info.version_range.latest_valid_version(), - pystac.get_stac_version()) + ps.get_stac_version()) self.assertEqual(set(migrated_info.common_extensions), set(info.common_extensions)) self.assertEqual(set(migrated_info.custom_extensions), set(info.custom_extensions)) # Test that PySTAC can read it without errors. - if info.object_type != STACObjectType.ITEMCOLLECTION: - self.assertIsInstance(STAC_IO.stac_object_from_dict(migrated_d, href=path), - STACObject) + if info.object_type != ps.STACObjectType.ITEMCOLLECTION: + self.assertIsInstance(ps.read_dict(migrated_d, href=path), STACObject) def test_migrates_removed_extension(self): - item = pystac.read_file( + item = ps.Item.from_file( TestCases.get_path('data-files/examples/0.7.0/extensions/sar/' 'examples/sentinel1.json')) self.assertFalse('dtr' in item.stac_extensions) @@ -50,7 +49,7 @@ def test_migrates_removed_extension(self): str_to_datetime("2018-11-03T23:58:55.121559Z")) def test_migrates_added_extension(self): - item = pystac.read_file( + item = ps.Item.from_file( TestCases.get_path('data-files/examples/0.8.1/item-spec/' 'examples/planet-sample.json')) self.assertTrue('view' in item.stac_extensions) @@ -59,7 +58,7 @@ def test_migrates_added_extension(self): self.assertEqual(item.ext.view.off_nadir, 1) def test_migrates_renamed_extension(self): - collection = pystac.read_file( + collection = ps.Collection.from_file( TestCases.get_path('data-files/examples/0.9.0/extensions/asset/' 'examples/example-landsat8.json')) diff --git a/tests/test_catalog.py b/tests/test_catalog.py index d5b9293ed..19caab8b8 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -1,11 +1,12 @@ import os import json +from typing import Any, Dict, List, Tuple, Union, cast import unittest from tempfile import TemporaryDirectory from datetime import datetime from collections import defaultdict -import pystac +import pystac as ps from pystac import (Catalog, Collection, CatalogType, Item, Asset, MediaType, Extensions, HIERARCHICAL_LINKS) from pystac.extensions.label import LabelClasses @@ -19,7 +20,7 @@ def test_determine_type_for_absolute_published(self): cat = TestCases.test_case_1() with TemporaryDirectory() as tmp_dir: cat.normalize_and_save(tmp_dir, catalog_type=CatalogType.ABSOLUTE_PUBLISHED) - cat_json = pystac.STAC_IO.read_json(os.path.join(tmp_dir, 'catalog.json')) + cat_json = ps.STAC_IO.read_json(os.path.join(tmp_dir, 'catalog.json')) catalog_type = CatalogType.determine_type(cat_json) self.assertEqual(catalog_type, CatalogType.ABSOLUTE_PUBLISHED) @@ -28,13 +29,13 @@ def test_determine_type_for_relative_published(self): cat = TestCases.test_case_2() with TemporaryDirectory() as tmp_dir: cat.normalize_and_save(tmp_dir, catalog_type=CatalogType.RELATIVE_PUBLISHED) - cat_json = pystac.STAC_IO.read_json(os.path.join(tmp_dir, 'catalog.json')) + cat_json = ps.STAC_IO.read_json(os.path.join(tmp_dir, 'catalog.json')) catalog_type = CatalogType.determine_type(cat_json) self.assertEqual(catalog_type, CatalogType.RELATIVE_PUBLISHED) def test_determine_type_for_self_contained(self): - cat_json = pystac.STAC_IO.read_json( + cat_json = ps.STAC_IO.read_json( TestCases.get_path('data-files/catalogs/test-case-1/catalog.json')) catalog_type = CatalogType.determine_type(cat_json) self.assertEqual(catalog_type, CatalogType.SELF_CONTAINED) @@ -164,15 +165,15 @@ def test_clear_children_sets_parent_and_root_to_None(self): def test_add_child_throws_if_item(self): cat = TestCases.test_case_1() - item = next(cat.get_all_items()) - with self.assertRaises(pystac.STACError): - cat.add_child(item) + item = next(iter(cat.get_all_items())) + with self.assertRaises(ps.STACError): + cat.add_child(item) # type:ignore def test_add_item_throws_if_child(self): cat = TestCases.test_case_1() - child = next(cat.get_children()) - with self.assertRaises(pystac.STACError): - cat.add_item(child) + child = next(iter(cat.get_children())) + with self.assertRaises(ps.STACError): + cat.add_item(child) # type:ignore def test_get_child_returns_none_if_not_found(self): cat = TestCases.test_case_1() @@ -190,7 +191,7 @@ def test_sets_catalog_type(self): self.assertEqual(cat.catalog_type, CatalogType.SELF_CONTAINED) def test_walk_iterates_correctly(self): - def test_catalog(cat): + def test_catalog(cat: Catalog): expected_catalog_iterations = 1 actual_catalog_iterations = 0 with self.subTest(title='Testing catalog {}'.format(cat.id)): @@ -212,8 +213,8 @@ def test_clone_generates_correct_links(self): catalogs = TestCases.all_test_catalogs() for catalog in catalogs: - expected_link_types_to_counts = {} - actual_link_types_to_counts = {} + expected_link_types_to_counts: Any = {} + actual_link_types_to_counts: Any = {} for root, _, items in catalog.walk(): expected_link_types_to_counts[root.id] = defaultdict(int) @@ -251,10 +252,10 @@ def test_save_uses_previous_catalog_type(self): assert catalog.catalog_type == CatalogType.SELF_CONTAINED with TemporaryDirectory() as tmp_dir: catalog.normalize_hrefs(tmp_dir) - href = catalog.get_self_href() + href = catalog.self_href catalog.save() - cat2 = pystac.read_file(href) + cat2 = ps.Catalog.from_file(href) self.assertEqual(cat2.catalog_type, CatalogType.SELF_CONTAINED) def test_clone_uses_previous_catalog_type(self): @@ -270,15 +271,15 @@ def test_normalize_hrefs_sets_all_hrefs(self): self.assertTrue(root.get_self_href().startswith('http://example.com')) for link in root.links: if link.is_resolved(): - target_href = link.target.get_self_href() + target_href = cast(ps.STACObject, link.target).self_href else: - target_href = link.get_absolute_href() + target_href = link.absolute_href self.assertTrue( 'http://example.com' in target_href, '[{}] {} does not contain "{}"'.format(link.rel, target_href, 'http://example.com')) for item in items: - self.assertIn('http://example.com', item.get_self_href()) + self.assertIn('http://example.com', item.self_href) def test_normalize_hrefs_makes_absolute_href(self): catalog = TestCases.test_case_1() @@ -314,9 +315,9 @@ def test_generate_subcatalogs_does_not_change_item_count(self): with TemporaryDirectory() as tmp_dir: catalog.normalize_hrefs(tmp_dir) - catalog.save(pystac.CatalogType.SELF_CONTAINED) + catalog.save(ps.CatalogType.SELF_CONTAINED) - cat2 = pystac.read_file(os.path.join(tmp_dir, 'catalog.json')) + cat2 = ps.Catalog.from_file(os.path.join(tmp_dir, 'catalog.json')) for child in cat2.get_children(): actual = len(list(child.get_all_items())) expected = item_counts[child.id] @@ -402,13 +403,13 @@ def test_generate_subcatalogs_works_for_subcatalogs_with_same_ids(self): catalog.normalize_hrefs('/') for item in catalog.get_all_items(): - parent_href = item.get_parent().get_self_href() + parent_href = item.get_parent().self_href path_to_parent, _ = os.path.split(parent_href) subcats = [el for el in path_to_parent.split('/') if el] self.assertEqual(len(subcats), 2, msg=" for item '{}'".format(item.id)) def test_map_items(self): - def item_mapper(item): + def item_mapper(item: ps.Item) -> ps.Item: item.properties['ITEM_MAPPER'] = 'YEP' return item @@ -429,7 +430,7 @@ def item_mapper(item): self.assertFalse('ITEM_MAPPER' in item.properties) def test_map_items_multiple(self): - def item_mapper(item): + def item_mapper(item: ps.Item) -> List[ps.Item]: item2 = item.clone() item2.id = item2.id + '_2' item.properties['ITEM_MAPPER_1'] = 'YEP' @@ -485,11 +486,11 @@ def test_map_items_multiple_2(self): item2.add_asset('ortho', Asset(href='/some/other/ortho.tif')) kitten.add_item(item2) - def modify_item_title(item): + def modify_item_title(item: ps.Item) -> ps.Item: item.title = 'Some new title' return item - def create_label_item(item): + def create_label_item(item: ps.Item) -> List[ps.Item]: # Assumes the GEOJSON labels are in the # same location as the image img_href = item.assets['ortho'].href @@ -522,7 +523,7 @@ def create_label_item(item): def test_map_assets_single(self): changed_asset = 'd43bead8-e3f8-4c51-95d6-e24e750a402b' - def asset_mapper(key, asset): + def asset_mapper(key: str, asset: ps.Asset) -> ps.Asset: if key == changed_asset: asset.title = 'NEW TITLE' @@ -549,10 +550,10 @@ def asset_mapper(key, asset): self.assertTrue(found) def test_map_assets_tup(self): - changed_assets = [] + changed_assets: List[str] = [] - def asset_mapper(key, asset): - if 'geotiff' in asset.media_type: + def asset_mapper(key: str, asset: ps.Asset) -> Union[ps.Asset, Tuple[str, ps.Asset]]: + if asset.media_type and 'geotiff' in asset.media_type: asset.title = 'NEW TITLE' changed_assets.append(key) return ('{}-modified'.format(key), asset) @@ -586,8 +587,8 @@ def asset_mapper(key, asset): def test_map_assets_multi(self): changed_assets = [] - def asset_mapper(key, asset): - if 'geotiff' in asset.media_type: + def asset_mapper(key: str, asset: ps.Asset) -> Union[ps.Asset, Dict[str, ps.Asset]]: + if asset.media_type and 'geotiff' in asset.media_type: changed_assets.append(key) mod1 = asset.clone() mod1.title = 'NEW TITLE 1' @@ -649,23 +650,23 @@ def test_make_all_asset_hrefs_relative(self): self.assertEqual(asset.href, original_href) def test_make_all_links_relative_or_absolute(self): - def check_all_relative(cat): + def check_all_relative(cat: Catalog): for root, catalogs, items in cat.walk(): for link in root.links: if link.rel in HIERARCHICAL_LINKS: - self.assertFalse(is_absolute_href(link.get_href())) + self.assertFalse(is_absolute_href(link.href)) for item in items: for link in item.links: if link.rel in HIERARCHICAL_LINKS: - self.assertFalse(is_absolute_href(link.get_href())) + self.assertFalse(is_absolute_href(link.href)) - def check_all_absolute(cat): + def check_all_absolute(cat: Catalog): for root, catalogs, items in cat.walk(): for link in root.links: - self.assertTrue(is_absolute_href(link.get_href())) + self.assertTrue(is_absolute_href(link.href)) for item in items: for link in item.links: - self.assertTrue(is_absolute_href(link.get_href())) + self.assertTrue(is_absolute_href(link.href)) test_cases = TestCases.all_test_catalogs() @@ -704,7 +705,7 @@ def test_extra_fields(self): self.assertTrue('type' in cat_json) self.assertEqual(cat_json['type'], 'FeatureCollection') - read_cat = pystac.read_file(p) + read_cat = ps.Catalog.from_file(p) self.assertTrue('type' in read_cat.extra_fields) self.assertEqual(read_cat.extra_fields['type'], 'FeatureCollection') @@ -722,9 +723,9 @@ def test_validate_all(self): item.geometry = {'type': 'INVALID', 'coordinates': 'NONE'} with TemporaryDirectory() as tmp_dir: cat.normalize_hrefs(tmp_dir) - cat.save(catalog_type=pystac.CatalogType.SELF_CONTAINED) + cat.save(catalog_type=ps.CatalogType.SELF_CONTAINED) - cat2 = pystac.read_file(os.path.join(tmp_dir, 'catalog.json')) + cat2 = ps.Catalog.from_file(os.path.join(tmp_dir, 'catalog.json')) with self.assertRaises(STACValidationError): cat2.validate_all() @@ -748,7 +749,7 @@ def test_set_hrefs_manually(self): if parent is None: root_dir = tmp_dir else: - d = os.path.dirname(parent.get_self_href()) + d = os.path.dirname(parent.self_href) root_dir = os.path.join(d, root.id) root_href = os.path.join(root_dir, root.DEFAULT_FILE_NAME) root.set_self_href(root_href) @@ -768,7 +769,7 @@ def test_set_hrefs_manually(self): if parent is None: self.assertEqual(root.get_self_href(), os.path.join(tmp_dir, 'catalog.json')) else: - d = os.path.dirname(parent.get_self_href()) + d = os.path.dirname(parent.self_href) self.assertEqual(root.get_self_href(), os.path.join(d, root.id, root.DEFAULT_FILE_NAME)) for item in items: @@ -780,7 +781,7 @@ def test_collections_cache_correctly(self): for cat in catalogs: with MockStacIO() as mock_io: expected_collection_reads = set([]) - for root, children, items in cat.walk(): + for root, _, items in cat.walk(): if isinstance(root, Collection) and root != cat: expected_collection_reads.add(root.get_self_href()) @@ -834,7 +835,7 @@ def test_resolve_planet(self): def test_handles_children_with_same_id(self): # This catalog has the root and child collection share an ID. - cat = pystac.read_file(TestCases.get_path('data-files/invalid/shared-id/catalog.json')) + cat = ps.Catalog.from_file(TestCases.get_path('data-files/invalid/shared-id/catalog.json')) items = list(cat.get_all_items()) self.assertEqual(len(items), 1) @@ -849,19 +850,19 @@ def test_catalog_with_href_caches_by_href(self): class FullCopyTest(unittest.TestCase): - def check_link(self, link, tag): + def check_link(self, link: ps.Link, tag: str): if link.is_resolved(): - target_href = link.target.get_self_href() + target_href: str = cast(ps.STACObject, link.target).self_href else: - target_href = link.target + target_href = str(link.target) self.assertTrue(tag in target_href, '[{}] {} does not contain "{}"'.format(link.rel, target_href, tag)) - def check_item(self, item, tag): + def check_item(self, item: Item, tag: str): for link in item.links: self.check_link(link, tag) - def check_catalog(self, c, tag): + def check_catalog(self, c: Catalog, tag: str): self.assertEqual(len(c.get_links('root')), 1) for link in c.links: diff --git a/tests/test_collection.py b/tests/test_collection.py index 80a5519fd..b9588fd72 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -5,11 +5,9 @@ from datetime import datetime from dateutil import tz -import pystac +import pystac as ps from pystac.validation import validate_dict -from pystac.serialization.identify import STACObjectType from pystac import (Collection, Item, Extent, SpatialExtent, TemporalExtent, CatalogType) -from pystac.extensions.eo import Band from pystac.utils import datetime_to_str from tests.utils import (TestCases, RANDOM_GEOM, RANDOM_BBOX) @@ -26,82 +24,22 @@ def test_spatial_extent_from_coordinates(self): for x in bbox: self.assertTrue(type(x) is float) - def test_eo_items_are_heritable(self): - item1 = Item(id='test-item-1', - geometry=RANDOM_GEOM, - bbox=RANDOM_BBOX, - datetime=TEST_DATETIME, - properties={'key': 'one'}, - stac_extensions=['eo', 'commons']) - - item2 = Item(id='test-item-2', - geometry=RANDOM_GEOM, - bbox=RANDOM_BBOX, - datetime=TEST_DATETIME, - properties={'key': 'two'}, - stac_extensions=['eo', 'commons']) - - wv3_bands = [ - Band.create(name='Coastal', description='Coastal: 400 - 450 nm', common_name='coastal'), - Band.create(name='Blue', description='Blue: 450 - 510 nm', common_name='blue'), - Band.create(name='Green', description='Green: 510 - 580 nm', common_name='green'), - Band.create(name='Yellow', description='Yellow: 585 - 625 nm', common_name='yellow'), - Band.create(name='Red', description='Red: 630 - 690 nm', common_name='red'), - Band.create(name='Red Edge', - description='Red Edge: 705 - 745 nm', - common_name='rededge'), - Band.create(name='Near-IR1', description='Near-IR1: 770 - 895 nm', common_name='nir08'), - Band.create(name='Near-IR2', description='Near-IR2: 860 - 1040 nm', common_name='nir09') - ] - - spatial_extent = SpatialExtent(bboxes=[RANDOM_BBOX]) - temporal_extent = TemporalExtent(intervals=[[item1.datetime, None]]) - - collection_extent = Extent(spatial=spatial_extent, temporal=temporal_extent) - - common_properties = { - 'eo:bands': [b.to_dict() for b in wv3_bands], - 'gsd': 0.3, - 'eo:platform': 'Maxar', - 'eo:instrument': 'WorldView3' - } - - collection = Collection(id='test', - description='test', - extent=collection_extent, - properties=common_properties, - stac_extensions=['commons'], - license='CC-BY-SA-4.0') - - collection.add_items([item1, item2]) - - with TemporaryDirectory() as tmp_dir: - collection.normalize_hrefs(tmp_dir) - collection.save(catalog_type=CatalogType.SELF_CONTAINED) - - read_col = Collection.from_file('{}/collection.json'.format(tmp_dir)) - items = list(read_col.get_all_items()) - - self.assertEqual(len(items), 2) - self.assertTrue(items[0].ext.implements('eo')) - self.assertTrue(items[1].ext.implements('eo')) - def test_read_eo_items_are_heritable(self): cat = TestCases.test_case_5() - item = next(cat.get_all_items()) + item = next(iter(cat.get_all_items())) self.assertTrue(item.ext.implements('eo')) def test_save_uses_previous_catalog_type(self): collection = TestCases.test_case_8() - assert collection.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION + assert collection.STAC_OBJECT_TYPE == ps.STACObjectType.COLLECTION self.assertEqual(collection.catalog_type, CatalogType.SELF_CONTAINED) with TemporaryDirectory() as tmp_dir: collection.normalize_hrefs(tmp_dir) - href = collection.get_self_href() + href = collection.self_href collection.save() - collection2 = pystac.read_file(href) + collection2 = ps.Collection.from_file(href) self.assertEqual(collection2.catalog_type, CatalogType.SELF_CONTAINED) def test_clone_uses_previous_catalog_type(self): @@ -115,12 +53,12 @@ def test_multiple_extents(self): col1 = cat1.get_child('country-1').get_child('area-1-1') col1.validate() self.assertIsInstance(col1, Collection) - validate_dict(col1.to_dict(), STACObjectType.COLLECTION) + validate_dict(col1.to_dict(), ps.STACObjectType.COLLECTION) multi_ext_uri = TestCases.get_path('data-files/collections/multi-extent.json') with open(multi_ext_uri) as f: multi_ext_dict = json.load(f) - validate_dict(multi_ext_dict, STACObjectType.COLLECTION) + validate_dict(multi_ext_dict, ps.STACObjectType.COLLECTION) self.assertIsInstance(Collection.from_dict(multi_ext_dict), Collection) multi_ext_col = Collection.from_file(multi_ext_uri) @@ -149,7 +87,7 @@ def test_extra_fields(self): self.assertTrue('test' in col_json) self.assertEqual(col_json['test'], 'extra') - read_col = pystac.read_file(p) + read_col = ps.Collection.from_file(p) self.assertTrue('test' in read_col.extra_fields) self.assertEqual(read_col.extra_fields['test'], 'extra') @@ -157,6 +95,7 @@ def test_update_extents(self): catalog = TestCases.test_case_2() base_collection = catalog.get_child('1a8c1632-fa91-4a62-b33e-3a87c2ebdf16') + assert isinstance(base_collection, Collection) base_extent = base_collection.extent collection = base_collection.clone() @@ -205,19 +144,12 @@ def test_supplying_href_in_init_does_not_fail(self): collection = Collection(id='test', description='test desc', extent=collection_extent, - properties={}, href=test_href) self.assertEqual(collection.get_self_href(), test_href) - def test_reading_0_8_1_collection_as_catalog_throws_correct_exception(self): - cat = pystac.Catalog.from_file( - TestCases.get_path('data-files/examples/hand-0.8.1/collection.json')) - with self.assertRaises(ValueError): - list(cat.get_all_items()) - def test_collection_with_href_caches_by_href(self): - collection = pystac.read_file( + collection = ps.Collection.from_file( TestCases.get_path('data-files/examples/hand-0.8.1/collection.json')) cache = collection._resolved_objects diff --git a/tests/test_item.py b/tests/test_item.py index c309adca0..da85a8d90 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -4,12 +4,11 @@ import unittest from tempfile import TemporaryDirectory -import pystac +import pystac as ps from pystac import Asset, Item, Provider from pystac.validation import validate_dict from pystac.item import CommonMetadata from pystac.utils import (str_to_datetime, is_absolute_href) -from pystac.serialization.identify import STACObjectType from tests.utils import (TestCases, test_to_from_dict) @@ -66,7 +65,7 @@ def test_asset_absolute_href(self): self.assertEqual(expected_href, actual_href) def test_extra_fields(self): - item = pystac.read_file(TestCases.get_path('data-files/item/sample-item.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item.json')) item.extra_fields['test'] = 'extra' @@ -78,7 +77,7 @@ def test_extra_fields(self): self.assertTrue('test' in item_json) self.assertEqual(item_json['test'], 'extra') - read_item = pystac.read_file(p) + read_item = ps.read_file(p) self.assertTrue('test' in read_item.extra_fields) self.assertEqual(read_item.extra_fields['test'], 'extra') @@ -103,9 +102,9 @@ def test_datetime_ISO8601_format(self): self.assertEqual('2016-05-03T13:22:30.040000Z', formatted_time) def test_null_datetime(self): - item = pystac.read_file(TestCases.get_path('data-files/item/sample-item.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item.json')) - with self.assertRaises(pystac.STACError): + with self.assertRaises(ps.STACError): Item('test', geometry=item.geometry, bbox=item.bbox, datetime=None, properties={}) null_dt_item = Item('test', @@ -113,15 +112,14 @@ def test_null_datetime(self): bbox=item.bbox, datetime=None, properties={ - 'start_datetime': pystac.utils.datetime_to_str(item.datetime), - 'end_datetime': pystac.utils.datetime_to_str(item.datetime) + 'start_datetime': ps.utils.datetime_to_str(item.datetime), + 'end_datetime': ps.utils.datetime_to_str(item.datetime) }) null_dt_item.validate() def test_get_set_asset_datetime(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) item_datetime = item.datetime # No property on asset @@ -156,7 +154,7 @@ def test_null_geometry(self): with open(m) as f: item_dict = json.load(f) - validate_dict(item_dict, STACObjectType.ITEM) + validate_dict(item_dict, ps.STACObjectType.ITEM) item = Item.from_dict(item_dict) self.assertIsInstance(item, Item) @@ -168,12 +166,12 @@ def test_null_geometry(self): item_dict['bbox'] def test_0_9_item_with_no_extensions_does_not_read_collection_data(self): - item_json = pystac.STAC_IO.read_json( + item_json = ps.STAC_IO.read_json( TestCases.get_path('data-files/examples/hand-0.9.0/010100/010100.json')) assert item_json.get('stac_extensions') is None assert item_json.get('stac_version') == '0.9.0' - did_merge = pystac.serialization.common_properties.merge_common_properties(item_json) + did_merge = ps.serialization.common_properties.merge_common_properties(item_json) self.assertFalse(did_merge) def test_clone_sets_asset_owner(self): @@ -406,8 +404,7 @@ def test_common_metadata_basics(self): self.assertEqual(x.properties['gsd'], example_gsd) def test_asset_start_datetime(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.start_datetime @@ -428,8 +425,7 @@ def test_asset_start_datetime(self): self.assertEqual(cm.start_datetime, item_value) def test_asset_end_datetime(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.end_datetime @@ -450,8 +446,7 @@ def test_asset_end_datetime(self): self.assertEqual(cm.end_datetime, item_value) def test_asset_license(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.license @@ -472,15 +467,14 @@ def test_asset_license(self): self.assertEqual(cm.license, item_value) def test_asset_providers(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.providers a2_known_value = [ - pystac.Provider(name="USGS", - url="https://landsat.usgs.gov/", - roles=["producer", "licensor"]) + ps.Provider(name="USGS", + url="https://landsat.usgs.gov/", + roles=["producer", "licensor"]) ] # Get @@ -491,17 +485,14 @@ def test_asset_providers(self): self.assertEqual(a2_value[0].to_dict(), a2_known_value[0].to_dict()) # Set - set_value = [ - pystac.Provider(name="John Snow", url="https://cholera.com/", roles=["producer"]) - ] + set_value = [ps.Provider(name="John Snow", url="https://cholera.com/", roles=["producer"])] cm.set_providers(set_value, item.assets['analytic']) new_a1_value = cm.get_providers(item.assets['analytic']) self.assertEqual(new_a1_value[0].to_dict(), set_value[0].to_dict()) self.assertEqual(cm.providers[0].to_dict(), item_value[0].to_dict()) def test_asset_platform(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.platform @@ -522,8 +513,7 @@ def test_asset_platform(self): self.assertEqual(cm.platform, item_value) def test_asset_instruments(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.instruments @@ -544,8 +534,7 @@ def test_asset_instruments(self): self.assertEqual(cm.instruments, item_value) def test_asset_constellation(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.constellation @@ -566,8 +555,7 @@ def test_asset_constellation(self): self.assertEqual(cm.constellation, item_value) def test_asset_mission(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.mission @@ -588,8 +576,7 @@ def test_asset_mission(self): self.assertEqual(cm.mission, item_value) def test_asset_gsd(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.gsd @@ -610,8 +597,7 @@ def test_asset_gsd(self): self.assertEqual(cm.gsd, item_value) def test_asset_created(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.created @@ -632,8 +618,7 @@ def test_asset_created(self): self.assertEqual(cm.created, item_value) def test_asset_updated(self): - item = pystac.read_file( - TestCases.get_path('data-files/item/sample-item-asset-properties.json')) + item = ps.read_file(TestCases.get_path('data-files/item/sample-item-asset-properties.json')) cm = item.common_metadata item_value = cm.updated diff --git a/tests/test_layout.py b/tests/test_layout.py index c8995589c..a5a682dd2 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -1,8 +1,10 @@ from datetime import (datetime, timedelta) import os +from typing import Callable +from pystac.collection import Collection import unittest -import pystac +import pystac as ps from pystac.layout import (LayoutTemplate, CustomLayoutStrategy, TemplateLayoutStrategy, BestPracticesLayoutStrategy, TemplateError) from tests.utils import (TestCases, RANDOM_GEOM, RANDOM_BBOX) @@ -18,11 +20,7 @@ def test_templates_item_datetime(self): template = LayoutTemplate('${year}/${month}/${day}/${date}/item.json') - item = pystac.Item('test', - geometry=RANDOM_GEOM, - bbox=RANDOM_BBOX, - datetime=dt, - properties={}) + item = ps.Item('test', geometry=RANDOM_GEOM, bbox=RANDOM_BBOX, datetime=dt, properties={}) parts = template.get_template_values(item) @@ -45,14 +43,14 @@ def test_templates_item_start_datetime(self): template = LayoutTemplate('${year}/${month}/${day}/${date}/item.json') - item = pystac.Item('test', - geometry=RANDOM_GEOM, - bbox=RANDOM_BBOX, - datetime=None, - properties={ - 'start_datetime': dt.isoformat(), - 'end_datetime': (dt + timedelta(days=1)).isoformat() - }) + item = ps.Item('test', + geometry=RANDOM_GEOM, + bbox=RANDOM_BBOX, + datetime=None, + properties={ + 'start_datetime': dt.isoformat(), + 'end_datetime': (dt + timedelta(days=1)).isoformat() + }) parts = template.get_template_values(item) @@ -70,7 +68,7 @@ def test_templates_item_collection(self): template = LayoutTemplate('${collection}/item.json') collection = TestCases.test_case_4().get_child('acc') - item = next(collection.get_all_items()) + item = next(iter(collection.get_all_items())) assert item.collection_id is not None parts = template.get_template_values(item) @@ -85,7 +83,7 @@ def test_throws_for_no_collection(self): template = LayoutTemplate('${collection}/item.json') collection = TestCases.test_case_4().get_child('acc') - item = next(collection.get_all_items()) + item = next(iter(collection.get_all_items())) item.set_collection(None) assert item.collection_id is None @@ -97,18 +95,18 @@ def test_nested_properties(self): template = LayoutTemplate('${test.prop}/${ext:extra.test.prop}/item.json') - item = pystac.Item('test', - geometry=RANDOM_GEOM, - bbox=RANDOM_BBOX, - datetime=dt, - properties={'test': { - 'prop': 4326 - }}, - extra_fields={'ext:extra': { - 'test': { - 'prop': 3857 - } - }}) + item = ps.Item('test', + geometry=RANDOM_GEOM, + bbox=RANDOM_BBOX, + datetime=dt, + properties={'test': { + 'prop': 4326 + }}, + extra_fields={'ext:extra': { + 'test': { + 'prop': 3857 + } + }}) parts = template.get_template_values(item) @@ -126,11 +124,11 @@ def test_substitute_with_colon_properties(self): template = LayoutTemplate('${ext:prop}/item.json') - item = pystac.Item('test', - geometry=RANDOM_GEOM, - bbox=RANDOM_BBOX, - datetime=dt, - properties={'ext:prop': 1}) + item = ps.Item('test', + geometry=RANDOM_GEOM, + bbox=RANDOM_BBOX, + datetime=dt, + properties={'ext:prop': 1}) path = template.substitute(item) @@ -140,15 +138,15 @@ def test_defaults(self): template = LayoutTemplate('${doesnotexist}/collection.json', defaults={'doesnotexist': 'yes'}) - collection = TestCases.test_case_4().get_child('acc') - collection.properties = {'up': 'down'} - collection.extra_fields = {'one': 'two'} - path = template.substitute(collection) + catalog = TestCases.test_case_4().get_child('acc') + assert catalog is not None + catalog.extra_fields = {'one': 'two'} + path = template.substitute(catalog) self.assertEqual(path, 'yes/collection.json') def test_docstring_examples(self): - item = pystac.read_file( + item = ps.Item.from_file( TestCases.get_path( "data-files/examples/1.0.0-beta.2/item-spec/examples/landsat8-sample.json")) item.common_metadata.license = "CC-BY-3.0" @@ -169,27 +167,27 @@ def test_docstring_examples(self): class CustomLayoutStrategyTest(unittest.TestCase): - def get_custom_catalog_func(self): - def fn(cat, parent_dir, is_root): + def get_custom_catalog_func(self) -> Callable[[ps.Catalog, str, bool], str]: + def fn(cat: ps.Catalog, parent_dir: str, is_root: bool): return os.path.join(parent_dir, 'cat/{}/{}.json'.format(is_root, cat.id)) return fn - def get_custom_collection_func(self): - def fn(col, parent_dir, is_root): + def get_custom_collection_func(self) -> Callable[[ps.Collection, str, bool], str]: + def fn(col: ps.Collection, parent_dir: str, is_root: bool): return os.path.join(parent_dir, 'col/{}/{}.json'.format(is_root, col.id)) return fn - def get_custom_item_func(self): - def fn(item, parent_dir): + def get_custom_item_func(self) -> Callable[[ps.Item, str], str]: + def fn(item: ps.Item, parent_dir: str): return os.path.join(parent_dir, 'item/{}.json'.format(item.id)) return fn def test_produces_layout_for_catalog(self): strategy = CustomLayoutStrategy(catalog_func=self.get_custom_catalog_func()) - cat = pystac.Catalog(id='test', description='test desc') + cat = ps.Catalog(id='test', description='test desc') href = strategy.get_href(cat, parent_dir='http://example.com', is_root=True) self.assertEqual(href, 'http://example.com/cat/True/test.json') @@ -198,7 +196,7 @@ def test_produces_fallback_layout_for_catalog(self): strategy = CustomLayoutStrategy(collection_func=self.get_custom_collection_func(), item_func=self.get_custom_item_func(), fallback_strategy=fallback) - cat = pystac.Catalog(id='test', description='test desc') + cat = ps.Catalog(id='test', description='test desc') href = strategy.get_href(cat, parent_dir='http://example.com') expected = fallback.get_href(cat, parent_dir='http://example.com') self.assertEqual(href, expected) @@ -222,17 +220,17 @@ def test_produces_fallback_layout_for_collection(self): def test_produces_layout_for_item(self): strategy = CustomLayoutStrategy(item_func=self.get_custom_item_func()) collection = TestCases.test_case_8() - item = next(collection.get_all_items()) + item = next(iter(collection.get_all_items())) href = strategy.get_href(item, parent_dir='http://example.com') self.assertEqual(href, 'http://example.com/item/{}.json'.format(item.id)) def test_produces_fallback_layout_for_item(self): fallback = BestPracticesLayoutStrategy() - strategy = CustomLayoutStrategy(catalog_func=self.get_custom_item_func(), + strategy = CustomLayoutStrategy(catalog_func=self.get_custom_catalog_func(), collection_func=self.get_custom_collection_func(), fallback_strategy=fallback) collection = TestCases.test_case_8() - item = next(collection.get_all_items()) + item = next(iter(collection.get_all_items())) href = strategy.get_href(item, parent_dir='http://example.com') expected = fallback.get_href(item, parent_dir='http://example.com') self.assertEqual(href, expected) @@ -243,16 +241,21 @@ class TemplateLayoutStrategyTest(unittest.TestCase): TEST_COLLECTION_TEMPLATE = 'col/${id}/${license}' TEST_ITEM_TEMPLATE = 'item/${collection}/${id}.json' + def _get_collection(self) -> Collection: + result = TestCases.test_case_4().get_child('acc') + assert isinstance(result, Collection) + return result + def test_produces_layout_for_catalog(self): strategy = TemplateLayoutStrategy(catalog_template=self.TEST_CATALOG_TEMPLATE) - cat = pystac.Catalog(id='test', description='test-desc') + cat = ps.Catalog(id='test', description='test-desc') href = strategy.get_href(cat, parent_dir='http://example.com') self.assertEqual(href, 'http://example.com/cat/test/test-desc/catalog.json') def test_produces_layout_for_catalog_with_filename(self): template = 'cat/${id}/${description}/${id}.json' strategy = TemplateLayoutStrategy(catalog_template=template) - cat = pystac.Catalog(id='test', description='test-desc') + cat = ps.Catalog(id='test', description='test-desc') href = strategy.get_href(cat, parent_dir='http://example.com') self.assertEqual(href, 'http://example.com/cat/test/test-desc/test.json') @@ -261,14 +264,14 @@ def test_produces_fallback_layout_for_catalog(self): strategy = TemplateLayoutStrategy(collection_template=self.TEST_COLLECTION_TEMPLATE, item_template=self.TEST_ITEM_TEMPLATE, fallback_strategy=fallback) - cat = pystac.Catalog(id='test', description='test desc') + cat = ps.Catalog(id='test', description='test desc') href = strategy.get_href(cat, parent_dir='http://example.com') expected = fallback.get_href(cat, parent_dir='http://example.com') self.assertEqual(href, expected) def test_produces_layout_for_collection(self): strategy = TemplateLayoutStrategy(collection_template=self.TEST_COLLECTION_TEMPLATE) - collection = TestCases.test_case_4().get_child('acc') + collection = self._get_collection() href = strategy.get_href(collection, parent_dir='http://example.com') self.assertEqual( href, @@ -278,7 +281,7 @@ def test_produces_layout_for_collection(self): def test_produces_layout_for_collection_with_filename(self): template = 'col/${id}/${license}/col.json' strategy = TemplateLayoutStrategy(collection_template=template) - collection = TestCases.test_case_4().get_child('acc') + collection = self._get_collection() href = strategy.get_href(collection, parent_dir='http://example.com') self.assertEqual( href, 'http://example.com/col/{}/{}/col.json'.format(collection.id, collection.license)) @@ -288,15 +291,15 @@ def test_produces_fallback_layout_for_collection(self): strategy = TemplateLayoutStrategy(catalog_template=self.TEST_CATALOG_TEMPLATE, item_template=self.TEST_ITEM_TEMPLATE, fallback_strategy=fallback) - collection = TestCases.test_case_4().get_child('acc') + collection = self._get_collection() href = strategy.get_href(collection, parent_dir='http://example.com') expected = fallback.get_href(collection, parent_dir='http://example.com') self.assertEqual(href, expected) def test_produces_layout_for_item(self): strategy = TemplateLayoutStrategy(item_template=self.TEST_ITEM_TEMPLATE) - collection = TestCases.test_case_4().get_child('acc') - item = next(collection.get_all_items()) + collection = self._get_collection() + item = next(iter(collection.get_all_items())) href = strategy.get_href(item, parent_dir='http://example.com') self.assertEqual(href, 'http://example.com/item/{}/{}.json'.format(item.collection_id, item.id)) @@ -304,8 +307,8 @@ def test_produces_layout_for_item(self): def test_produces_layout_for_item_without_filename(self): template = 'item/${collection}' strategy = TemplateLayoutStrategy(item_template=template) - collection = TestCases.test_case_4().get_child('acc') - item = next(collection.get_all_items()) + collection = self._get_collection() + item = next(iter(collection.get_all_items())) href = strategy.get_href(item, parent_dir='http://example.com') self.assertEqual(href, 'http://example.com/item/{}/{}.json'.format(item.collection_id, item.id)) @@ -315,8 +318,8 @@ def test_produces_fallback_layout_for_item(self): strategy = TemplateLayoutStrategy(catalog_template=self.TEST_CATALOG_TEMPLATE, collection_template=self.TEST_COLLECTION_TEMPLATE, fallback_strategy=fallback) - collection = TestCases.test_case_4().get_child('acc') - item = next(collection.get_all_items()) + collection = self._get_collection() + item = next(iter(collection.get_all_items())) href = strategy.get_href(item, parent_dir='http://example.com') expected = fallback.get_href(item, parent_dir='http://example.com') self.assertEqual(href, expected) @@ -327,12 +330,12 @@ def setUp(self): self.strategy = BestPracticesLayoutStrategy() def test_produces_layout_for_root_catalog(self): - cat = pystac.Catalog(id='test', description='test desc') + cat = ps.Catalog(id='test', description='test desc') href = self.strategy.get_href(cat, parent_dir='http://example.com', is_root=True) self.assertEqual(href, 'http://example.com/catalog.json') def test_produces_layout_for_child_catalog(self): - cat = pystac.Catalog(id='test', description='test desc') + cat = ps.Catalog(id='test', description='test desc') href = self.strategy.get_href(cat, parent_dir='http://example.com') self.assertEqual(href, 'http://example.com/test/catalog.json') @@ -348,7 +351,7 @@ def test_produces_layout_for_child_collection(self): def test_produces_layout_for_item(self): collection = TestCases.test_case_8() - item = next(collection.get_all_items()) + item = next(iter(collection.get_all_items())) href = self.strategy.get_href(item, parent_dir='http://example.com') expected = 'http://example.com/{}/{}.json'.format(item.id, item.id) self.assertEqual(href, expected) diff --git a/tests/test_writing.py b/tests/test_writing.py index d5ac1d311..1ebbc6e7f 100644 --- a/tests/test_writing.py +++ b/tests/test_writing.py @@ -1,9 +1,8 @@ import unittest from tempfile import TemporaryDirectory -import pystac -from pystac import (STAC_IO, STACObject, Collection, CatalogType, HIERARCHICAL_LINKS) -from pystac.serialization import (STACObjectType) +import pystac as ps +from pystac import (STAC_IO, Collection, CatalogType, HIERARCHICAL_LINKS) from pystac.utils import is_absolute_href, make_absolute_href, make_relative_href from pystac.validation import validate_dict @@ -14,7 +13,7 @@ class STACWritingTest(unittest.TestCase): """Tests writing STACs, using JSON Schema validation, and ensure that links are correctly set to relative or absolute. """ - def validate_catalog(self, catalog): + def validate_catalog(self, catalog: ps.Catalog): catalog.validate() validated_count = 1 @@ -27,12 +26,12 @@ def validate_catalog(self, catalog): return validated_count - def validate_file(self, path, object_type): + def validate_file(self, path: str, object_type: str): d = STAC_IO.read_json(path) return validate_dict(d, object_type) - def validate_link_types(self, root_href, catalog_type): - def validate_asset_href_type(item, item_href, link_type): + def validate_link_types(self, root_href: str, catalog_type: ps.CatalogType): + def validate_asset_href_type(item: ps.Item, item_href: str): for asset in item.assets.values(): if not is_absolute_href(asset.href): is_valid = not is_absolute_href(asset.href) @@ -44,36 +43,36 @@ def validate_asset_href_type(item, item_href, link_type): else: self.assertTrue(is_valid) - def validate_item_link_type(href, link_type, should_include_self): + def validate_item_link_type(href: str, link_type: str, should_include_self: bool): item_dict = STAC_IO.read_json(href) - item = STACObject.from_file(href) - rel_links = HIERARCHICAL_LINKS + pystac.STAC_EXTENSIONS.get_extended_object_links(item) + item = ps.Item.from_file(href) + rel_links = HIERARCHICAL_LINKS + ps.STAC_EXTENSIONS.get_extended_object_links(item) for link in item.get_links(): if not link.rel == 'self': if link_type == 'RELATIVE' and link.rel in rel_links: - self.assertFalse(is_absolute_href(link.get_href())) + self.assertFalse(is_absolute_href(link.href)) else: - self.assertTrue(is_absolute_href(link.get_href())) + self.assertTrue(is_absolute_href(link.href)) - validate_asset_href_type(item, href, link_type) + validate_asset_href_type(item, href) rels = set([link['rel'] for link in item_dict['links']]) self.assertEqual('self' in rels, should_include_self) - def validate_catalog_link_type(href, link_type, should_include_self): + def validate_catalog_link_type(href: str, link_type: str, should_include_self: bool): cat_dict = STAC_IO.read_json(href) - cat = STACObject.from_file(href) + cat = ps.Catalog.from_file(href) rels = set([link['rel'] for link in cat_dict['links']]) self.assertEqual('self' in rels, should_include_self) for child_link in cat.get_child_links(): - child_href = make_absolute_href(child_link.target, href) + child_href = make_absolute_href(child_link.href, href) validate_catalog_link_type(child_href, link_type, catalog_type == CatalogType.ABSOLUTE_PUBLISHED) for item_link in cat.get_item_links(): - item_href = make_absolute_href(item_link.target, href) + item_href = make_absolute_href(item_link.href, href) validate_item_link_type(item_href, link_type, catalog_type == CatalogType.ABSOLUTE_PUBLISHED) @@ -87,25 +86,25 @@ def validate_catalog_link_type(href, link_type, should_include_self): validate_catalog_link_type(root_href, link_type, root_should_include_href) - def do_test(self, catalog, catalog_type): + def do_test(self, catalog: ps.Catalog, catalog_type: ps.CatalogType): with TemporaryDirectory() as tmp_dir: catalog.normalize_hrefs(tmp_dir) self.validate_catalog(catalog) catalog.save(catalog_type=catalog_type) - root_href = catalog.get_self_href() + root_href = catalog.self_href self.validate_link_types(root_href, catalog_type) - for parent, children, items in catalog.walk(): + for parent, _, items in catalog.walk(): if issubclass(type(parent), Collection): - stac_object_type = STACObjectType.COLLECTION + stac_object_type = ps.STACObjectType.COLLECTION else: - stac_object_type = STACObjectType.CATALOG - self.validate_file(parent.get_self_href(), stac_object_type) + stac_object_type = ps.STACObjectType.CATALOG + self.validate_file(parent.self_href, stac_object_type) for item in items: - self.validate_file(item.get_self_href(), STACObjectType.ITEM) + self.validate_file(item.self_href, ps.STACObjectType.ITEM) def test_testcases(self): for catalog in TestCases.all_test_catalogs(): diff --git a/tests/utils/test_cases.py b/tests/utils/test_cases.py index b5b2591a8..3b3f8f2f3 100644 --- a/tests/utils/test_cases.py +++ b/tests/utils/test_cases.py @@ -1,10 +1,10 @@ import os from datetime import datetime import csv +from typing import Any, Dict, List -import pystac -from pystac import (Catalog, Item, Asset, Extent, TemporalExtent, SpatialExtent, MediaType, - Extensions) +from pystac import (Catalog, Collection, Item, Asset, Extent, TemporalExtent, SpatialExtent, + MediaType, Extensions) from pystac.extensions.label import (LabelOverview, LabelClasses, LabelCount) TEST_LABEL_CATALOG = { @@ -34,7 +34,7 @@ } } -RANDOM_GEOM = { +RANDOM_GEOM: Dict[str, Any] = { "type": "Polygon", "coordinates": [[[-2.5048828125, 3.8916575492899987], [-1.9610595703125, 3.8916575492899987], @@ -42,7 +42,7 @@ [-2.5048828125, 3.8916575492899987]]] } -RANDOM_BBOX = [ +RANDOM_BBOX: List[float] = [ RANDOM_GEOM['coordinates'][0][0][0], RANDOM_GEOM['coordinates'][0][0][1], RANDOM_GEOM['coordinates'][0][1][0], RANDOM_GEOM['coordinates'][0][1][1] ] @@ -53,11 +53,11 @@ class TestCases: @staticmethod - def get_path(rel_path): + def get_path(rel_path: str) -> str: return os.path.abspath(os.path.join(os.path.dirname(__file__), '..', rel_path)) @staticmethod - def get_examples_info(): + def get_examples_info() -> List[Dict[str, Any]]: examples = [] info_path = TestCases.get_path('data-files/examples/example-info.csv') @@ -90,7 +90,7 @@ def get_examples_info(): return examples @staticmethod - def all_test_catalogs(): + def all_test_catalogs() -> List[Catalog]: return [ TestCases.test_case_1(), TestCases.test_case_2(), @@ -102,15 +102,15 @@ def all_test_catalogs(): ] @staticmethod - def test_case_1(): + def test_case_1() -> Catalog: return Catalog.from_file(TestCases.get_path('data-files/catalogs/test-case-1/catalog.json')) @staticmethod - def test_case_2(): + def test_case_2() -> Catalog: return Catalog.from_file(TestCases.get_path('data-files/catalogs/test-case-2/catalog.json')) @staticmethod - def test_case_3(): + def test_case_3() -> Catalog: root_cat = Catalog(id='test3', description='test case 3 catalog', title='test case 3 title') image_item = Item(id='imagery-item', @@ -175,8 +175,8 @@ def test_case_7(): TestCases.get_path('data-files/catalogs/label_catalog_0_8_1/catalog.json')) @staticmethod - def test_case_8(): + def test_case_8() -> Collection: """Planet disaster data example catalog, 1.0.0-beta.2""" - return pystac.read_file( + return Collection.from_file( TestCases.get_path('data-files/catalogs/' 'planet-example-1.0.0-beta.2/collection.json')) diff --git a/tests/validation/test_validate.py b/tests/validation/test_validate.py index 8b11ae193..48a70bfaf 100644 --- a/tests/validation/test_validate.py +++ b/tests/validation/test_validate.py @@ -8,6 +8,8 @@ import jsonschema import pystac +import pystac.validation +from pystac.cache import CollectionCache from pystac.serialization.common_properties import merge_common_properties from pystac.validation import STACValidationError from tests.utils import TestCases @@ -31,35 +33,36 @@ def test_validate_current_version(self): def test_validate_examples(self): for example in TestCases.get_examples_info(): - stac_version = example['stac_version'] - path = example['path'] - valid = example['valid'] + with self.subTest(example['path']): + stac_version = example['stac_version'] + path = example['path'] + valid = example['valid'] - if stac_version < '0.8': - with open(path) as f: - stac_json = json.load(f) - - self.assertEqual(len(pystac.validation.validate_dict(stac_json)), 0) - else: - with self.subTest(path): + if stac_version < '0.8': with open(path) as f: stac_json = json.load(f) - # Check if common properties need to be merged - if stac_version < '1.0': - if example['object_type'] == pystac.STACObjectType.ITEM: - collection_cache = pystac.cache.CollectionCache() - merge_common_properties(stac_json, collection_cache, path) - - if valid: - pystac.validation.validate_dict(stac_json) - else: - with self.assertRaises(STACValidationError): - try: - pystac.validation.validate_dict(stac_json) - except STACValidationError as e: - self.assertIsInstance(e.source, jsonschema.ValidationError) - raise e + self.assertEqual(len(pystac.validation.validate_dict(stac_json)), 0) + else: + with self.subTest(path): + with open(path) as f: + stac_json = json.load(f) + + # Check if common properties need to be merged + if stac_version < '1.0': + if example['object_type'] == pystac.STACObjectType.ITEM: + collection_cache = CollectionCache() + merge_common_properties(stac_json, collection_cache, path) + + if valid: + pystac.validation.validate_dict(stac_json) + else: + with self.assertRaises(STACValidationError): + try: + pystac.validation.validate_dict(stac_json) + except STACValidationError as e: + self.assertIsInstance(e.source, jsonschema.ValidationError) + raise e def test_validate_error_contains_href(self): # Test that the exception message contains the HREF of the object if available. @@ -130,4 +133,5 @@ def test_validates_geojson_with_tuple_coordinates(self): datetime=datetime.utcnow(), properties={}) - self.assertIsNone(item.validate()) + # Should not raise. + item.validate()