Skip to content

Commit

Permalink
f-s: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 12, 2023
1 parent 06c83b9 commit bb47f4e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
13 changes: 9 additions & 4 deletions src/aiida_pseudo/data/pseudo/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,28 @@
import typing

from aiida import orm, plugins
from aiida.orm.nodes.caching import NodeCaching
from aiida.common.constants import elements
from aiida.common.exceptions import StoringNotAllowed
from aiida.common.files import md5_from_filelike

__all__ = ('PseudoPotentialData',)

class PseudoPotentialDataCaching(NodeCaching):
"""Class to define caching behavior of ``PseudoPotentialData`` nodes."""

def _get_objects_to_hash(self) -> list:
"""Return a list of objects which should be included in the node hash."""
return [self._node.element, self._node.md5]

class PseudoPotentialData(plugins.DataFactory('core.singlefile')):
"""Base class for data types representing pseudo potentials."""

_key_element = 'element'
_key_md5 = 'md5'

_CLS_NODE_CACHING = PseudoPotentialDataCaching

@classmethod
def get_or_create(cls, source: typing.Union[str, pathlib.Path, typing.BinaryIO], filename: str = None):
"""Get pseudopotenial data node from database with matching md5 checksum or create a new one if not existent.
Expand Down Expand Up @@ -184,7 +193,3 @@ def md5(self, value: str):
"""
self.validate_md5(value)
self.base.attributes.set(self._key_md5, value)

def _get_objects_to_hash(self) -> list:
"""Return a list of objects which should be included in the node hash."""
return [self.element, self.md5]
3 changes: 0 additions & 3 deletions tests/data/pseudo/test_pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,5 @@ def test_get_or_create(get_pseudo_potential_data):
assert different_class.uuid != original.uuid

# check hash is from content and not from filename
print(original._get_objects_to_hash())
print([original.element, original.md5])
print(make_hash(original._get_objects_to_hash()))
assert original._get_objects_to_hash() == [original.element, original.md5]
assert original.get_hash() == make_hash([original.element, original.md5])

0 comments on commit bb47f4e

Please sign in to comment.