Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add additional tests; Fix multiply/outer product broadcasting
Browse files Browse the repository at this point in the history
amirebrahimi committed Jan 11, 2025
1 parent 8d2ad26 commit b2f316c
Showing 2 changed files with 89 additions and 24 deletions.
28 changes: 13 additions & 15 deletions src/galois/_fields/_gf2.py
Original file line number Diff line number Diff line change
@@ -172,22 +172,20 @@ def __call__(self, ufunc, method, inputs, kwargs, meta):
else:
result_shape = np.broadcast_shapes(*(i.shape for i in inputs))

if is_outer_product and len(inputs) == 2:
a = np.unpackbits(inputs[0])
output = a[:, np.newaxis].view(np.ndarray) * inputs[1].view(np.ndarray)
if is_outer_product:
assert len(inputs) == 2
# Unpack the first argument and propagate the bitpacked second argument
inputs = [np.unpackbits(x).view(np.ndarray) if i == 0 else x.view(np.ndarray) for i, x in enumerate(inputs)]
output = np.multiply.outer(*inputs)
else:
output = super().__call__(ufunc, method, inputs, kwargs, meta)

assert len(output.view(np.ndarray).shape) == len(result_shape)
# output = output.view(np.ndarray)
# if output.shape != result_shape:
# for axis, shape in enumerate(zip(output.shape, result_shape)):
# if axis == len(result_shape) - 1:
# # The last axis remains packed
# break
#
# if shape[0] != shape[1]:
# output = np.unpackbits(output, axis=axis, count=shape[1])
if any(i.shape != inputs[0].shape for i in inputs):
# We can't do simple bitwise multiplication when the shapes aren't the same due to broadcasting
inputs = [np.unpackbits(i) for i in inputs]
output = reduce(operator.mul, inputs) # We need this to use GF2's multiply
output = np.packbits(output)
else:
output = super().__call__(ufunc, method, inputs, kwargs, meta)

output = self.field._view(output)
output._axis_count = result_shape[-1]
return output
85 changes: 76 additions & 9 deletions tests/fields/test_bitpacked.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import galois
from galois import GF2


def test_galois_array_indexing():
# Define a Galois field array
@@ -60,9 +62,6 @@ def test_galois_array_indexing():
# taken = np.take(arr, [0, 2])
# assert np.array_equal(taken, GF([1, 1]))

print("All tests passed.")


def test_galois_array_setting():
# Define a Galois field array
GF = galois.GF(2)
@@ -134,9 +133,77 @@ def test_galois_array_setting():
arr_2d[np.ix_(row_indices, col_indices)] = GF([[0, 0], [0, 0]])
assert np.array_equal(arr_2d, np.packbits(GF([[0, 0], [0, 0]])))

print("All set-indexing tests passed.")


if __name__ == "__main__":
test_galois_array_indexing()
test_galois_array_setting()
def test_inv():
N = 10
u = GF2.Random((N, N), seed=2)
p = np.packbits(u)
# print(x.get_unpacked_slice(1))
# index = np.index_exp[:,1:4:2]
# index = np.index_exp[[0,1], [0, 1]]
# print(a)
# print(a[index])
# print(x.get_unpacked_slice(index))
print(np.linalg.inv(u))
print(np.unpackbits(np.linalg.inv(p)))
assert np.array_equal(np.linalg.inv(u), np.unpackbits(np.linalg.inv(p)))

def test_arithmetic():
size = (20, 10)
cm = np.random.randint(2, size=size, dtype=np.uint8)
cm2 = np.random.randint(2, size=size, dtype=np.uint8)
vec = np.random.randint(2, size=size[1], dtype=np.uint8)

cm_GF2 = GF2(cm)
cm2_GF2 = GF2(cm2)
cm3_GF2 = GF2(cm2.T)
vec_GF2 = GF2(vec)

cm_gf2bp = np.packbits(cm_GF2)
cm2_gf2bp = np.packbits(cm2_GF2)
cm3_gf2bp = np.packbits(cm2_GF2.T)
vec_gf2bp = np.packbits(vec_GF2)

# Addition
assert np.array_equal(np.unpackbits(cm_gf2bp + cm2_gf2bp), cm_GF2 + cm2_GF2)

# Multiplication
assert np.array_equal(np.unpackbits(cm_gf2bp * cm2_gf2bp), cm_GF2 * cm2_GF2)

# Matrix-vector product
assert np.array_equal(np.unpackbits(cm_gf2bp @ vec_gf2bp), cm_GF2 @ vec_GF2)

# Matrix-matrix product
assert np.array_equal(np.unpackbits(cm_gf2bp @ cm3_gf2bp), cm_GF2 @ cm3_GF2)

def test_broadcasting():
a = np.random.randint(0, 2, 10)
b = np.random.randint(0, 2, 10)
x = np.packbits(GF2(a))
y = np.packbits(GF2(b))

c = a * b
z = x * y
assert c.shape == z.shape == np.unpackbits(z).shape # (10,)

c = np.multiply.outer(a, b)
z = np.multiply.outer(x, y)
assert np.array_equal(np.unpackbits(z), c)
assert c.shape == z.shape == np.unpackbits(z).shape # (10, 10)

def test_advanced_broadcasting():
a = np.random.randint(0, 2, (1, 2, 3))
b = np.random.randint(0, 2, (2, 2, 1))
x = np.packbits(GF2(a))
y = np.packbits(GF2(b))

c = a * b
z = x * y
assert np.array_equal(np.unpackbits(z), c)
assert c.shape == z.shape == np.unpackbits(z).shape # (2, 2, 3)

c = np.multiply.outer(a, b)
z = np.multiply.outer(x, y)
print(c.shape)
print(z.shape)
assert np.array_equal(np.unpackbits(z), c)
assert c.shape == z.shape == np.unpackbits(z).shape # (1, 2, 3, 2, 2, 1)

0 comments on commit b2f316c

Please sign in to comment.