Skip to content

Commit

Permalink
Fix up broadcasting in add, sub, div
Browse files Browse the repository at this point in the history
  • Loading branch information
amirebrahimi committed Jan 11, 2025
1 parent b2f316c commit c535cd6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
33 changes: 27 additions & 6 deletions src/galois/_fields/_gf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,15 @@ class add_ufunc_bitpacked(add_ufunc):
"""

def __call__(self, ufunc, method, inputs, kwargs, meta):
output = super().__call__(ufunc, method, inputs, kwargs, meta)
output._axis_count = inputs[0]._axis_count
result_shape = np.broadcast_shapes(*(i.shape for i in inputs))
if any(i.shape != inputs[0].shape for i in inputs):
# We can't do simple bitwise addition when the shapes aren't the same due to broadcasting
inputs = [np.unpackbits(i) for i in inputs]
output = reduce(operator.add, inputs) # We need this to use GF2's addition
output = np.packbits(output)
else:
output = super().__call__(ufunc, method, inputs, kwargs, meta)
output._axis_count = result_shape[-1]
return output


Expand All @@ -155,8 +162,15 @@ class subtract_ufunc_bitpacked(subtract_ufunc):
"""

def __call__(self, ufunc, method, inputs, kwargs, meta):
output = super().__call__(ufunc, method, inputs, kwargs, meta)
output._axis_count = max(i._axis_count for i in inputs)
result_shape = np.broadcast_shapes(*(i.shape for i in inputs))
if any(i.shape != inputs[0].shape for i in inputs):
# We can't do simple bitwise subtraction when the shapes aren't the same due to broadcasting
inputs = [np.unpackbits(i) for i in inputs]
output = reduce(operator.sub, inputs) # We need this to use GF2's subtraction
output = np.packbits(output)
else:
output = super().__call__(ufunc, method, inputs, kwargs, meta)
output._axis_count = result_shape[-1]
return output


Expand Down Expand Up @@ -197,8 +211,15 @@ class divide_ufunc_bitpacked(divide):
"""

def __call__(self, ufunc, method, inputs, kwargs, meta):
output = super().__call__(ufunc, method, inputs, kwargs, meta)
output._axis_count = max(i._axis_count for i in inputs)
result_shape = np.broadcast_shapes(*(i.shape for i in inputs))
if any(i.shape != inputs[0].shape for i in inputs):
# We can't do simple bitwise division when the shapes aren't the same due to broadcasting
inputs = [np.unpackbits(i) for i in inputs]
output = reduce(operator.truediv, inputs) # We need this to use GF2's division
output = np.packbits(output)
else:
output = super().__call__(ufunc, method, inputs, kwargs, meta)
output._axis_count = result_shape[-1]
return output


Expand Down
35 changes: 19 additions & 16 deletions tests/fields/test_bitpacked.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import galois
from galois import GF2
import operator as ops


def test_galois_array_indexing():
Expand Down Expand Up @@ -176,30 +177,32 @@ def test_arithmetic():
assert np.array_equal(np.unpackbits(cm_gf2bp @ cm3_gf2bp), cm_GF2 @ cm3_GF2)

def test_broadcasting():
a = np.random.randint(0, 2, 10)
b = np.random.randint(0, 2, 10)
x = np.packbits(GF2(a))
y = np.packbits(GF2(b))
a = GF2(np.random.randint(0, 2, 10))
b = GF2(np.random.randint(0, 2, 10))
x = np.packbits(a)
y = np.packbits(b)

c = a * b
z = x * y
assert c.shape == z.shape == np.unpackbits(z).shape # (10,)
for op in [ops.add, ops.sub, ops.mul]:
c = op(a, b)
z = op(x, y)
assert c.shape == z.shape == np.unpackbits(z).shape # (10,)

c = np.multiply.outer(a, b)
z = np.multiply.outer(x, y)
assert np.array_equal(np.unpackbits(z), c)
assert c.shape == z.shape == np.unpackbits(z).shape # (10, 10)

def test_advanced_broadcasting():
a = np.random.randint(0, 2, (1, 2, 3))
b = np.random.randint(0, 2, (2, 2, 1))
x = np.packbits(GF2(a))
y = np.packbits(GF2(b))

c = a * b
z = x * y
assert np.array_equal(np.unpackbits(z), c)
assert c.shape == z.shape == np.unpackbits(z).shape # (2, 2, 3)
a = GF2(np.random.randint(0, 2, (1, 2, 3)))
b = GF2(np.random.randint(0, 2, (2, 2, 1)))
x = np.packbits(a)
y = np.packbits(b)

for op in [ops.add, ops.sub, ops.mul]:
c = op(a, b)
z = op(x, y)
assert np.array_equal(np.unpackbits(z), c)
assert c.shape == z.shape == np.unpackbits(z).shape # (2, 2, 3)

c = np.multiply.outer(a, b)
z = np.multiply.outer(x, y)
Expand Down

0 comments on commit c535cd6

Please sign in to comment.