Skip to content

Commit

Permalink
Update code and tests to pass with type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
lossyrob committed Apr 25, 2021
1 parent d531cc2 commit 4bb702a
Show file tree
Hide file tree
Showing 34 changed files with 905 additions and 683 deletions.
16 changes: 14 additions & 2 deletions pystac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ class STACError(Exception):
pass


class STACTypeError(Exception):
"""A STACTypeError is raised when encountering a representation of
a STAC entity that is not correct for the context; for example, if
a Catalog JSON was read in as an Item.
"""
pass


from typing import Any, Dict, Optional
from pystac.version import (__version__, get_stac_version, set_stac_version) # type:ignore
from pystac.stac_io import STAC_IO
Expand All @@ -21,8 +29,12 @@ class STACError(Exception):
from pystac.media_type import MediaType # type:ignore
from pystac.link import (Link, HIERARCHICAL_LINKS) # type:ignore
from pystac.catalog import (Catalog, CatalogType) # type:ignore
from pystac.collection import (Collection, Extent, SpatialExtent, TemporalExtent, # type:ignore
Provider) # type:ignore
from pystac.collection import (
Collection, # type:ignore
Extent, # type:ignore
SpatialExtent, # type:ignore
TemporalExtent, # type:ignore
Provider) # type:ignore
from pystac.item import (Item, Asset, CommonMetadata) # type:ignore

from pystac.serialization import stac_object_from_dict
Expand Down
60 changes: 32 additions & 28 deletions pystac/cache.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from collections import ChainMap
from copy import copy
from pystac.collection import Collection
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from pystac.stac_object import STACObject
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast

import pystac
import pystac as ps

if TYPE_CHECKING:
from pystac.stac_object import STACObject
from pystac.collection import Collection

def get_cache_key(stac_object: STACObject) -> Tuple[str, bool]:

def get_cache_key(stac_object: "STACObject") -> Tuple[str, bool]:
"""Produce a cache key for the given STAC object.
If a self href is set, use that as the cache key.
Expand Down Expand Up @@ -56,16 +58,16 @@ class ResolvedObjectCache:
ids_to_collections (Dict[str, Collection]): Map of collection IDs to collections.
"""
def __init__(self,
id_keys_to_objects: Optional[Dict[str, STACObject]] = None,
hrefs_to_objects: Optional[Dict[str, STACObject]] = None,
ids_to_collections: Dict[str, Collection] = None):
id_keys_to_objects: Optional[Dict[str, "STACObject"]] = None,
hrefs_to_objects: Optional[Dict[str, "STACObject"]] = None,
ids_to_collections: Dict[str, "Collection"] = None):
self.id_keys_to_objects = id_keys_to_objects or {}
self.hrefs_to_objects = hrefs_to_objects or {}
self.ids_to_collections = ids_to_collections or {}

self._collection_cache = None

def get_or_cache(self, obj: STACObject) -> STACObject:
def get_or_cache(self, obj: "STACObject") -> "STACObject":
"""Gets the STACObject that is the cached version of the given STACObject; or, if
none exists, sets the cached object to the given object.
Expand All @@ -91,7 +93,7 @@ def get_or_cache(self, obj: STACObject) -> STACObject:
self.cache(obj)
return obj

def get(self, obj: STACObject) -> Optional[STACObject]:
def get(self, obj: "STACObject") -> Optional["STACObject"]:
"""Get the cached object that has the same cache key as the given object.
Args:
Expand All @@ -107,7 +109,7 @@ def get(self, obj: STACObject) -> Optional[STACObject]:
else:
return self.id_keys_to_objects.get(key)

def get_by_href(self, href: str) -> Optional[STACObject]:
def get_by_href(self, href: str) -> Optional["STACObject"]:
"""Gets the cached object at href.
Args:
Expand All @@ -118,7 +120,7 @@ def get_by_href(self, href: str) -> Optional[STACObject]:
"""
return self.hrefs_to_objects.get(href)

def get_collection_by_id(self, id: str) -> Optional[Collection]:
def get_collection_by_id(self, id: str) -> Optional["Collection"]:
"""Retrieved a cached Collection by its ID.
Args:
Expand All @@ -130,7 +132,7 @@ def get_collection_by_id(self, id: str) -> Optional[Collection]:
"""
return self.ids_to_collections.get(id)

def cache(self, obj: STACObject) -> None:
def cache(self, obj: "STACObject") -> None:
"""Set the given object into the cache.
Args:
Expand All @@ -142,10 +144,10 @@ def cache(self, obj: STACObject) -> None:
else:
self.id_keys_to_objects[key] = obj

if isinstance(obj, Collection):
if isinstance(obj, ps.Collection):
self.ids_to_collections[obj.id] = obj

def remove(self, obj: STACObject) -> None:
def remove(self, obj: "STACObject") -> None:
"""Removes any cached object that matches the given object's cache key.
Args:
Expand All @@ -158,10 +160,10 @@ def remove(self, obj: STACObject) -> None:
else:
self.id_keys_to_objects.pop(key, None)

if obj.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION:
if obj.STAC_OBJECT_TYPE == ps.STACObjectType.COLLECTION:
self.id_keys_to_objects.pop(obj.id, None)

def __contains__(self, obj: STACObject) -> bool:
def __contains__(self, obj: "STACObject") -> bool:
key, is_href = get_cache_key(obj)
return key in self.hrefs_to_objects if is_href else key in self.id_keys_to_objects

Expand Down Expand Up @@ -213,23 +215,25 @@ class CollectionCache:
and will set Collection JSON that it reads in order to merge in common properties.
"""
def __init__(self,
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None,
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None):
cached_ids: Dict[str, Union["Collection", Dict[str, Any]]] = None,
cached_hrefs: Dict[str, Union["Collection", Dict[str, Any]]] = None):
self.cached_ids = cached_ids or {}
self.cached_hrefs = cached_hrefs or {}

def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]:
def get_by_id(self, collection_id: str) -> Optional[Union["Collection", Dict[str, Any]]]:
return self.cached_ids.get(collection_id)

def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
def get_by_href(self, href: str) -> Optional[Union["Collection", Dict[str, Any]]]:
return self.cached_hrefs.get(href)

def contains_id(self, collection_id: str) -> bool:
return collection_id in self.cached_ids

def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[str] = None) -> None:
def cache(self,
collection: Union["Collection", Dict[str, Any]],
href: Optional[str] = None) -> None:
"""Caches a collection JSON."""
if isinstance(collection, Collection):
if isinstance(collection, ps.Collection):
self.cached_ids[collection.id] = collection
else:
self.cached_ids[collection['id']] = collection
Expand All @@ -241,24 +245,24 @@ def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[st
class ResolvedObjectCollectionCache(CollectionCache):
def __init__(self,
resolved_object_cache: ResolvedObjectCache,
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None,
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None):
cached_ids: Dict[str, Union["Collection", Dict[str, Any]]] = None,
cached_hrefs: Dict[str, Union["Collection", Dict[str, Any]]] = None):
super().__init__(cached_ids, cached_hrefs)
self.resolved_object_cache = resolved_object_cache

def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]:
def get_by_id(self, collection_id: str) -> Optional[Union["Collection", Dict[str, Any]]]:
result = self.resolved_object_cache.get_collection_by_id(collection_id)
if result is None:
return super().get_by_id(collection_id)
else:
return result

def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
def get_by_href(self, href: str) -> Optional[Union["Collection", Dict[str, Any]]]:
result = self.resolved_object_cache.get_by_href(href)
if result is None:
return super().get_by_href(href)
else:
return cast(Collection, result)
return cast(ps.Collection, result)

def contains_id(self, collection_id: str) -> bool:
return (self.resolved_object_cache.contains_collection_id(collection_id)
Expand Down
Loading

0 comments on commit 4bb702a

Please sign in to comment.