diff --git a/src/galois/_fields/_gf2.py b/src/galois/_fields/_gf2.py index 04fc16d3e..388be622f 100644 --- a/src/galois/_fields/_gf2.py +++ b/src/galois/_fields/_gf2.py @@ -172,22 +172,20 @@ def __call__(self, ufunc, method, inputs, kwargs, meta): else: result_shape = np.broadcast_shapes(*(i.shape for i in inputs)) - if is_outer_product and len(inputs) == 2: - a = np.unpackbits(inputs[0]) - output = a[:, np.newaxis].view(np.ndarray) * inputs[1].view(np.ndarray) + if is_outer_product: + assert len(inputs) == 2 + # Unpack the first argument and propagate the bitpacked second argument + inputs = [np.unpackbits(x).view(np.ndarray) if i == 0 else x.view(np.ndarray) for i, x in enumerate(inputs)] + output = np.multiply.outer(*inputs) else: - output = super().__call__(ufunc, method, inputs, kwargs, meta) - - assert len(output.view(np.ndarray).shape) == len(result_shape) - # output = output.view(np.ndarray) - # if output.shape != result_shape: - # for axis, shape in enumerate(zip(output.shape, result_shape)): - # if axis == len(result_shape) - 1: - # # The last axis remains packed - # break - # - # if shape[0] != shape[1]: - # output = np.unpackbits(output, axis=axis, count=shape[1]) + if any(i.shape != inputs[0].shape for i in inputs): + # We can't do simple bitwise multiplication when the shapes aren't the same due to broadcasting + inputs = [np.unpackbits(i) for i in inputs] + output = reduce(operator.mul, inputs) # We need this to use GF2's multiply + output = np.packbits(output) + else: + output = super().__call__(ufunc, method, inputs, kwargs, meta) + output = self.field._view(output) output._axis_count = result_shape[-1] return output diff --git a/tests/fields/test_bitpacked.py b/tests/fields/test_bitpacked.py index d304bed11..57a66e68c 100644 --- a/tests/fields/test_bitpacked.py +++ b/tests/fields/test_bitpacked.py @@ -1,5 +1,7 @@ import numpy as np import galois +from galois import GF2 + def test_galois_array_indexing(): # Define a Galois field array @@ -60,9 +62,6 @@ def test_galois_array_indexing(): # taken = np.take(arr, [0, 2]) # assert np.array_equal(taken, GF([1, 1])) - print("All tests passed.") - - def test_galois_array_setting(): # Define a Galois field array GF = galois.GF(2) @@ -134,9 +133,77 @@ def test_galois_array_setting(): arr_2d[np.ix_(row_indices, col_indices)] = GF([[0, 0], [0, 0]]) assert np.array_equal(arr_2d, np.packbits(GF([[0, 0], [0, 0]]))) - print("All set-indexing tests passed.") - - -if __name__ == "__main__": - test_galois_array_indexing() - test_galois_array_setting() +def test_inv(): + N = 10 + u = GF2.Random((N, N), seed=2) + p = np.packbits(u) + # print(x.get_unpacked_slice(1)) + # index = np.index_exp[:,1:4:2] + # index = np.index_exp[[0,1], [0, 1]] + # print(a) + # print(a[index]) + # print(x.get_unpacked_slice(index)) + print(np.linalg.inv(u)) + print(np.unpackbits(np.linalg.inv(p))) + assert np.array_equal(np.linalg.inv(u), np.unpackbits(np.linalg.inv(p))) + +def test_arithmetic(): + size = (20, 10) + cm = np.random.randint(2, size=size, dtype=np.uint8) + cm2 = np.random.randint(2, size=size, dtype=np.uint8) + vec = np.random.randint(2, size=size[1], dtype=np.uint8) + + cm_GF2 = GF2(cm) + cm2_GF2 = GF2(cm2) + cm3_GF2 = GF2(cm2.T) + vec_GF2 = GF2(vec) + + cm_gf2bp = np.packbits(cm_GF2) + cm2_gf2bp = np.packbits(cm2_GF2) + cm3_gf2bp = np.packbits(cm2_GF2.T) + vec_gf2bp = np.packbits(vec_GF2) + + # Addition + assert np.array_equal(np.unpackbits(cm_gf2bp + cm2_gf2bp), cm_GF2 + cm2_GF2) + + # Multiplication + assert np.array_equal(np.unpackbits(cm_gf2bp * cm2_gf2bp), cm_GF2 * cm2_GF2) + + # Matrix-vector product + assert np.array_equal(np.unpackbits(cm_gf2bp @ vec_gf2bp), cm_GF2 @ vec_GF2) + + # Matrix-matrix product + assert np.array_equal(np.unpackbits(cm_gf2bp @ cm3_gf2bp), cm_GF2 @ cm3_GF2) + +def test_broadcasting(): + a = np.random.randint(0, 2, 10) + b = np.random.randint(0, 2, 10) + x = np.packbits(GF2(a)) + y = np.packbits(GF2(b)) + + c = a * b + z = x * y + assert c.shape == z.shape == np.unpackbits(z).shape # (10,) + + c = np.multiply.outer(a, b) + z = np.multiply.outer(x, y) + assert np.array_equal(np.unpackbits(z), c) + assert c.shape == z.shape == np.unpackbits(z).shape # (10, 10) + +def test_advanced_broadcasting(): + a = np.random.randint(0, 2, (1, 2, 3)) + b = np.random.randint(0, 2, (2, 2, 1)) + x = np.packbits(GF2(a)) + y = np.packbits(GF2(b)) + + c = a * b + z = x * y + assert np.array_equal(np.unpackbits(z), c) + assert c.shape == z.shape == np.unpackbits(z).shape # (2, 2, 3) + + c = np.multiply.outer(a, b) + z = np.multiply.outer(x, y) + print(c.shape) + print(z.shape) + assert np.array_equal(np.unpackbits(z), c) + assert c.shape == z.shape == np.unpackbits(z).shape # (1, 2, 3, 2, 2, 1)