Skip to content

Commit

Permalink
Merge pull request #454 from stac-utils/fix/rde/avoid-deepcopy
Browse files Browse the repository at this point in the history
Avoid calling deepcopy in from_dict methods when unnecessary
  • Loading branch information
Jon Duckworth authored Jun 17, 2021
2 parents 2739661 + 68df9bd commit 6937607
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 10 deletions.
4 changes: 3 additions & 1 deletion pystac/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,7 @@ def from_dict(
href: Optional[str] = None,
root: Optional["Catalog"] = None,
migrate: bool = False,
preserve_dict: bool = True,
) -> "Catalog":
if migrate:
info = identify_stac_object(d)
Expand All @@ -916,7 +917,8 @@ def from_dict(

catalog_type = CatalogType.determine_type(d)

d = deepcopy(d)
if preserve_dict:
d = deepcopy(d)

id = d.pop("id")
description = d.pop("description")
Expand Down
5 changes: 4 additions & 1 deletion pystac/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def from_dict(
href: Optional[str] = None,
root: Optional[Catalog] = None,
migrate: bool = False,
preserve_dict: bool = True,
) -> "Collection":
if migrate:
info = identify_stac_object(d)
Expand All @@ -597,7 +598,9 @@ def from_dict(

catalog_type = CatalogType.determine_type(d)

d = deepcopy(d)
if preserve_dict:
d = deepcopy(d)

id = d.pop("id")
description = d.pop("description")
license = d.pop("license")
Expand Down
5 changes: 4 additions & 1 deletion pystac/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def from_dict(
href: Optional[str] = None,
root: Optional[Catalog] = None,
migrate: bool = False,
preserve_dict: bool = True,
) -> "Item":
if migrate:
info = identify_stac_object(d)
Expand All @@ -925,7 +926,9 @@ def from_dict(
f"{d} does not represent a {cls.__name__} instance"
)

d = deepcopy(d)
if preserve_dict:
d = deepcopy(d)

id = d.pop("id")
geometry = d.pop("geometry")
properties = d.pop("properties")
Expand Down
15 changes: 11 additions & 4 deletions pystac/stac_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def stac_object_from_dict(
d: Dict[str, Any],
href: Optional[str] = None,
root: Optional["Catalog_Type"] = None,
preserve_dict: bool = True,
) -> "STACObject_Type":
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
collection_cache = None
Expand All @@ -114,15 +115,21 @@ def stac_object_from_dict(
d = migrate_to_latest(d, info)

if info.object_type == pystac.STACObjectType.CATALOG:
result = pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)
result = pystac.Catalog.from_dict(
d, href=href, root=root, migrate=False, preserve_dict=preserve_dict
)
result._stac_io = self
return result

if info.object_type == pystac.STACObjectType.COLLECTION:
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)
return pystac.Collection.from_dict(
d, href=href, root=root, migrate=False, preserve_dict=preserve_dict
)

if info.object_type == pystac.STACObjectType.ITEM:
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)
return pystac.Item.from_dict(
d, href=href, root=root, migrate=False, preserve_dict=preserve_dict
)

raise ValueError(f"Unknown STAC object type {info.object_type}")

Expand Down Expand Up @@ -164,7 +171,7 @@ def read_stac_object(
"""
d = self.read_json(source)
href = source if isinstance(source, str) else source.get_absolute_href()
return self.stac_object_from_dict(d, href=href, root=root)
return self.stac_object_from_dict(d, href=href, root=root, preserve_dict=False)

def save_json(
self, dest: Union[str, "Link_Type"], json_dict: Dict[str, Any]
Expand Down
8 changes: 7 additions & 1 deletion pystac/stac_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def from_file(
href = make_absolute_href(href)

d = stac_io.read_json(href)
o = cls.from_dict(d, href=href, migrate=True)
o = cls.from_dict(d, href=href, migrate=True, preserve_dict=False)

# Set the self HREF, if it's not already set to something else.
if o.get_self_href() is None:
Expand All @@ -495,6 +495,7 @@ def from_dict(
href: Optional[str] = None,
root: Optional["Catalog_Type"] = None,
migrate: bool = False,
preserve_dict: bool = True,
) -> "STACObject":
"""Parses this STACObject from the passed in dictionary.
Expand All @@ -507,6 +508,11 @@ def from_dict(
previously resolved instances of the STAC object.
migrate: Use True if this dict represents JSON from an older STAC object,
so that migrations are run against it.
preserve_dict: If False, the dict parameter ``d`` may be modified
during this method call. Otherwise the dict is not mutated.
Defaults to True, which results in a deepcopy of the
parameter. Set to False when possible to avoid the performance
hit of a deepcopy.
Returns:
STACObject: The STACObject parsed from this dict.
Expand Down
14 changes: 14 additions & 0 deletions tests/test_catalog.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import os
import json
import tempfile
Expand Down Expand Up @@ -85,6 +86,19 @@ def test_create_and_read(self) -> None:

self.assertEqual(len(list(items)), 8)

def test_from_dict_preserves_dict(self) -> None:
    """Check how ``Catalog.from_dict`` treats the dict it is given.

    With default arguments the input dict must compare equal to its
    pristine copy afterwards; with ``preserve_dict=False`` the input is
    expected to differ from that copy once the call returns.
    """
    expected = TestCases.test_case_1().to_dict()
    working = deepcopy(expected)

    # Default call: the argument must come back untouched.
    Catalog.from_dict(working)
    self.assertEqual(working, expected)

    # Opting out of preservation: the argument is expected to be altered.
    Catalog.from_dict(working, preserve_dict=False)
    self.assertNotEqual(working, expected)

def test_read_remote(self) -> None:
# TODO: Move this URL to the main stac-spec repo once the example JSON is fixed.
catalog_url = (
Expand Down
16 changes: 16 additions & 0 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import unittest
import os
import json
Expand Down Expand Up @@ -182,6 +183,21 @@ def test_assets(self) -> None:
collection = pystac.Collection.from_dict(data)
collection.validate()

def test_to_dict_preserves_dict(self) -> None:
    """Check how ``Collection.from_dict`` treats the dict it is given.

    The collection JSON is loaded from the ``with-assets.json`` test file.
    With default arguments the input dict must compare equal to its
    pristine copy afterwards; with ``preserve_dict=False`` the input is
    expected to differ from that copy once the call returns.
    """
    source_path = TestCases.get_path("data-files/collections/with-assets.json")
    with open(source_path) as fh:
        expected = json.load(fh)
    working = deepcopy(expected)

    # Default call: the argument must come back untouched.
    Collection.from_dict(working)
    self.assertEqual(working, expected)

    # Opting out of preservation: the argument is expected to be altered.
    Collection.from_dict(working, preserve_dict=False)
    self.assertNotEqual(working, expected)

def test_schema_summary(self) -> None:
collection = pystac.Collection.from_file(
TestCases.get_path(
Expand Down
14 changes: 12 additions & 2 deletions tests/test_item.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import os
from datetime import datetime
import json
Expand Down Expand Up @@ -25,9 +26,10 @@ def test_to_from_dict(self) -> None:
self.maxDiff = None

item_dict = self.get_example_item_dict()
param_dict = deepcopy(item_dict)

assert_to_from_dict(self, Item, item_dict)
item = Item.from_dict(item_dict)
assert_to_from_dict(self, Item, param_dict)
item = Item.from_dict(param_dict)
self.assertEqual(item.id, "CS3-20160503_132131_05")

# test asset creation additional field(s)
Expand All @@ -37,6 +39,14 @@ def test_to_from_dict(self) -> None:
)
self.assertEqual(len(item.assets["thumbnail"].properties), 0)

# test that the parameter is preserved
self.assertEqual(param_dict, item_dict)

# assert that the parameter is not preserved with
# non-default parameter
_ = Item.from_dict(param_dict, preserve_dict=False)
self.assertNotEqual(param_dict, item_dict)

def test_set_self_href_does_not_break_asset_hrefs(self) -> None:
cat = TestCases.test_case_2()
for item in cat.get_all_items():
Expand Down

0 comments on commit 6937607

Please sign in to comment.