Skip to content

Commit

Permalink
Add additional indexing support
Browse files Browse the repository at this point in the history
  • Loading branch information
amirebrahimi committed Dec 21, 2024
1 parent 1186fa4 commit c07ba0f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 7 deletions.
48 changes: 41 additions & 7 deletions src/galois/_fields/_gf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def __new__(
) -> Self:
# axis_element_count is required, but by making it optional it allows us to catch uses of the class that are not
# supported (e.g. Random)
if isinstance(x, np.ndarray) and axis_element_count is not None:
if isinstance(x, (tuple, list, np.ndarray, FieldArray)) and axis_element_count is not None:
# NOTE: I'm not sure that we want to change the dtype specifically for the bit-packed version or how we verify
# dtype = cls._get_dtype(dtype)
# x = cls._verify_array_like_types_and_values(x)
Expand Down Expand Up @@ -450,7 +450,8 @@ def Identity(cls, size: int, dtype: DTypeLike | None = None) -> Self:
return np.packbits(array)

def get_unpacked_slice(self, index):
if isinstance(index, Sequence):
post_index = NotImplemented
if isinstance(index, (Sequence, np.ndarray)):
if len(index) == 2:
row_index, col_index = index
if isinstance(col_index, int):
Expand All @@ -464,16 +465,49 @@ def get_unpacked_slice(self, index):
post_index = (list(range(len(row_index))), col_index)
col_index = tuple(s // 8 for s in col_index)
index = (row_index, col_index)
elif col_index is None: # new axis
post_index = (slice(None), None)
index = (row_index,)
elif ((isinstance(index, np.ndarray) and index.ndim == 1) or
(isinstance(index, list) and all(isinstance(x, int) for x in index))):
post_index = index
index = list(range((len(index) // 8) + 1))
elif isinstance(index, tuple) and any(x is Ellipsis for x in index):
post_index = index[1:]
axis_adjustment = (slice(None),) if index[-1] is Ellipsis else (index[-1] // 8,)
index = index[:-1] + axis_adjustment
elif isinstance(index, slice):
# TODO
pass

packed = self[index]
if len(packed.shape) == 1:
if self.ndim > 1:
# Rows aren't packed, so we can index normally
post_index = slice(None)
if len(self.shape) == 1:
# Array is 1-D, so we need to adjust
post_index = index
index = slice(index.start // 8 if index.start is not None else index.start,
max(index.step // 8, 1) if index.step is not None else index.step,
max(index.stop // 8, 1) if index.stop is not None else index.stop)
elif isinstance(index, int):
post_index = index
index //= 8

if post_index is NotImplemented:
raise NotImplementedError(f"The following indexing scheme is not supported:\n{index}\n"
"If you believe this scheme should be supported, "
"please submit a GitHub issue at https://github.com/mhostetter/galois/issues.\n\n"
"If you'd like to perform this operation on the data, you should first call "
"`array = array.view(np.ndarray)` and then call the function."
)

packed = self.view(np.ndarray)[index]
if np.isscalar(packed):
packed = GF2BP([packed], self._axis_count).view(np.ndarray)
if packed.ndim == 1 and self.ndim > 1:
packed = packed[:, None]
unpacked = np.unpackbits(packed, axis=-1, count=self._axis_count)
return GF2._view(unpacked[post_index])

def __getitem__(self, item):
return self.get_unpacked_slice(item)

def set_unpacked_slice(self, index, value):
pass
Expand Down
65 changes: 65 additions & 0 deletions tests/fields/test_bitpacked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import galois

def test_galois_array_indexing():
# Define a Galois field array
GF = galois.GF(2)
arr = GF([1, 0, 1, 1])
arr = np.packbits(arr)

# 1. Basic Indexing
assert arr[0] == GF(1)
assert arr[2] == GF(1)

# 2. Negative Indexing
assert arr[-1] == GF(1)
assert arr[-2] == GF(1)

# 3. Slicing
assert np.array_equal(arr[1:3], GF([0, 1]))
assert np.array_equal(arr[:3], GF([1, 0, 1]))
assert np.array_equal(arr[::2], GF([1, 1]))
assert np.array_equal(arr[::-1], GF([1, 1, 0, 1]))

# 4. Multidimensional Indexing
arr_2d = GF([[1, 0], [0, 1]])
arr_2d = np.packbits(arr_2d)
assert arr_2d[0, 1] == GF(0)
assert np.array_equal(arr_2d[:, 1], GF([0, 1]))

# 5. Boolean Indexing
mask = np.array([True, False, True, False])
assert np.array_equal(arr[mask], GF([1, 1]))

# 6. Fancy Indexing
indices = [0, 2, 3]
assert np.array_equal(arr[indices], GF([1, 1, 1]))

# 7. Ellipsis
arr_3d = GF(np.random.randint(0, 2, (2, 3, 4)))
arr_3d = np.packbits(arr_3d)
shape_check = arr_3d[0, ..., 1].shape # (3,)
assert shape_check == (3,)

# 8. Indexing with slice objects
s = slice(1, 3)
assert np.array_equal(arr[s], GF([0, 1]))

# 9. Using np.newaxis
reshaped = arr[:, np.newaxis]
assert reshaped.shape == (4, 1)

# 10. Indexing with np.ix_
row_indices = np.array([0, 1])
col_indices = np.array([0, 1])
sub_matrix = arr_2d[np.ix_(row_indices, col_indices)]
assert np.array_equal(sub_matrix, GF([[1, 0], [0, 1]]))

# 11. Indexing with np.take
taken = np.take(arr, [0, 2])
assert np.array_equal(taken, GF([1, 1]))

print("All tests passed.")

if __name__ == "__main__":
test_galois_array_indexing()

0 comments on commit c07ba0f

Please sign in to comment.