Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: equal() to API standard #152

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
cd /tmp
git clone https://github.com/kokkos/pykokkos-base.git
cd pykokkos-base
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5
- name: Install pykokkos
run: |
python -m pip install .
Expand All @@ -49,4 +49,4 @@ jobs:
# for hypothesis-driven test case generation
pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s
# only run a subset of the conformance tests to get started
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_equal
2 changes: 1 addition & 1 deletion .github/workflows/main_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
cd /tmp
git clone https://github.com/kokkos/pykokkos-base.git
cd pykokkos-base
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5
- name: Install pykokkos
run: |
python -m pip install .
Expand Down
1 change: 1 addition & 0 deletions pykokkos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from pykokkos.lib.manipulate import reshape, ravel, expand_dims
from pykokkos.lib.util import all, any, sum, find_max, searchsorted, col, linspace, logspace
from pykokkos.lib.constants import e, pi, inf, nan
from pykokkos.interface.views import astype

__array_api_version__ = "2021.12"

Expand Down
48 changes: 42 additions & 6 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,43 @@ def _get_type(self, dtype: Union[DataType, type]) -> Optional[DataType]:


def __eq__(self, other):
if not isinstance(other, pk.View) and self.rank() > 0:
return [i == other for i in self]

if self.array == other:
return True
# avoid circular import with scoped import
from pykokkos.lib.ufuncs import equal
if isinstance(other, float):
new_other = pk.View((), dtype=pk.double)
new_other[:] = other
elif isinstance(other, bool):
new_other = pk.View((), dtype=pk.bool)
new_other[:] = other
elif isinstance(other, int):
if self.ndim == 0:
ret = pk.View((), dtype=pk.bool)
ret[:] = int(self) == other
return ret
if 0 <= other <= 255:
other_dtype = pk.uint8
elif 0 <= other <= 65535:
other_dtype = pk.uint16
elif 0 <= other <= 4294967295:
other_dtype = pk.uint32
elif 0 <= other <= 18446744073709551615:
other_dtype = pk.uint64
elif -128 <= other <= 127:
other_dtype = pk.int8
elif -32768 <= other <= 32767:
other_dtype = pk.int16
elif -2147483648 <= other <= 2147483647:
other_dtype = pk.int32
elif -9223372036854775808 <= other <= 9223372036854775807:
other_dtype = pk.int64
new_other = pk.View((), dtype=other_dtype)
new_other[:] = other
elif isinstance(other, pk.View):
new_other = other
else:
return False
raise ValueError("unexpected types!")
return equal(self, new_other)



def __hash__(self):
Expand Down Expand Up @@ -785,3 +815,9 @@ class ScratchView7D(ScratchView, Generic[T]):

class ScratchView8D(ScratchView, Generic[T]):
pass


def astype(view, dtype):
new_view = pk.View([*view.shape], dtype=dtype)
new_view[:] = view
return new_view
Loading