Added max_pool and avg_pool functionalities
Fadi Badine authored and committed Jan 26, 2025
1 parent bc2c6a3 commit 600b96c
Showing 1 changed file with 161 additions and 2 deletions.

keras/src/backend/mlx/nn.py
@@ -1,4 +1,7 @@
import builtins
import math
import operator
from itertools import accumulate

import mlx.core as mx
import mlx.nn as nn
@@ -122,16 +125,172 @@ def log_softmax(x, axis=-1):
    return x - mx.logsumexp(x, axis=axis, keepdims=True)


def _calculate_padding(input_shape, pool_size, strides):
    ndim = len(input_shape)

    padding = ()
    for d in range(ndim):
        pad = max(0, (pool_size[d] - 1) - ((input_shape[d] - 1) % strides[d]))
        padding = padding + (pad,)

    return [(p // 2, (p + 1) // 2) for p in padding]
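
# Worked example (illustrative, not part of the original commit): for a
# spatial size of 5 with pool_size 3 and stride 2,
# pad = max(0, 2 - ((5 - 1) % 2)) = 2, split as (1, 1); the padded
# length 7 then yields (7 - 3) // 2 + 1 = 3 = ceil(5 / 2) output
# positions, matching "same" padding semantics.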


def _non_overlapping_sliding_windows(x, shape, window_shape):
    # Compute the intermediate shape
    new_shape = [shape[0]]
    for s, w in zip(shape[1:], window_shape):
        new_shape.append(s // w)
        new_shape.append(w)
    new_shape.append(shape[-1])

    last_axis = len(new_shape) - 1
    axis_order = [
        0,
        *range(1, last_axis, 2),
        *range(2, last_axis, 2),
        last_axis,
    ]

    x = x.reshape(new_shape)
    x = x.transpose(axis_order)
    return x
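
# Shape walk-through (illustrative sketch, not part of the original
# commit): for x of shape (1, 4, 4, 1) and a (2, 2) window, new_shape is
# [1, 2, 2, 2, 2, 1] and axis_order is [0, 1, 3, 2, 4, 5], so the result
# has shape (1, 2, 2, 2, 2, 1) where axes 1-2 index the window grid and
# axes 3-4 index positions within each window.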


def _sliding_windows(x, window_shape, window_strides):
    if x.ndim < 3:
        raise ValueError(
            f"To extract sliding windows at least 1 spatial dimension "
            f"(3 total) is needed but the input only has {x.ndim} dimensions."
        )

    spatial_dims = x.shape[1:-1]
    if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
        raise ValueError(
            f"To extract sliding windows the window shapes and strides must "
            f"have the same number of spatial dimensions as the signal but "
            f"the signal has {len(spatial_dims)} dims and the window shape "
            f"has {len(window_shape)} and strides have {len(window_strides)}."
        )

    shape = x.shape
    if all(
        window == stride and size % window == 0
        for size, window, stride in zip(
            spatial_dims, window_shape, window_strides
        )
    ):
        return _non_overlapping_sliding_windows(x, shape, window_shape)

    strides = list(
        reversed(list(accumulate(reversed(shape + (1,)), operator.mul)))
    )[1:]

    # Compute the output shape
    final_shape = [shape[0]]
    final_shape += [
        (size - window) // stride + 1
        for size, window, stride in zip(
            spatial_dims, window_shape, window_strides
        )
    ]
    final_shape += window_shape
    final_shape += [shape[-1]]

    # Compute the output strides
    final_strides = strides[:1]
    final_strides += [
        og_stride * stride
        for og_stride, stride in zip(strides[1:-1], window_strides)
    ]
    final_strides += strides[1:-1]
    final_strides += strides[-1:]  # should always be [1]

    return mx.as_strided(x, final_shape, final_strides)
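
# Illustrative sketch (not part of the original commit): overlapping
# windows share memory through mx.as_strided instead of being copied.
# For example, assuming a zero-filled input:
#   x = mx.zeros((1, 5, 5, 1))
#   _sliding_windows(x, (3, 3), (1, 1)).shape  # (1, 3, 3, 3, 3, 1)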


def _pool(
    inputs, pool_size, strides, padding, padding_value, data_format, pooling_fn
):
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )

    if data_format == "channels_first":
        # mlx expects channels_last
        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

    if padding == "same":
        pads = _calculate_padding(inputs.shape[1:-1], pool_size, strides)

        if any(p[1] > 0 for p in pads):
            inputs = mx.pad(
                inputs,
                [(0, 0)] + pads + [(0, 0)],
                constant_values=padding_value,
            )

    inputs = _sliding_windows(inputs, pool_size, strides)

    axes = tuple(range(-len(pool_size) - 1, -1, 1))
    result = pooling_fn(inputs, axes)

    if data_format == "channels_first":
        result = result.transpose(0, -1, *range(1, result.ndim - 1))
    return result
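
# Illustrative trace (not part of the original commit): for 2D pooling,
# _sliding_windows returns (batch, out_h, out_w, win_h, win_w, channels)
# and axes == (-3, -2), so pooling_fn reduces over the two window axes:
#   x = mx.zeros((1, 4, 4, 3))
#   _pool(x, (2, 2), (2, 2), "valid", -mx.inf, "channels_last", mx.max)
#   # -> shape (1, 2, 2, 3)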


def max_pool(
    inputs, pool_size, strides=None, padding="valid", data_format=None
):
    inputs = convert_to_tensor(inputs)
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
    strides = pool_size if strides is None else strides
    strides = standardize_tuple(strides, num_spatial_dims, "strides")

    return _pool(
        inputs, pool_size, strides, padding, -mx.inf, data_format, mx.max
    )
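
# Hypothetical usage sketch (channels_last input, not part of the
# original commit): 2x2 max pooling over a 4x4 single-channel image.
#   x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)
#   max_pool(x, pool_size=2, strides=2, padding="valid")
#   # -> shape (1, 2, 2, 1) with values [[5, 7], [13, 15]]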


def average_pool(
    inputs, pool_size, strides=None, padding="valid", data_format=None
):
    inputs = convert_to_tensor(inputs)
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
    strides = pool_size if strides is None else strides
    strides = standardize_tuple(strides, num_spatial_dims, "strides")

    # Pool by summing the values in each window
    pooled = _pool(
        inputs, pool_size, strides, padding, 0.0, data_format, mx.sum
    )
    if padding == "valid":
        # Without padding every window is full, so dividing the summed
        # values by the window size gives the average
        return pooled / math.prod(pool_size)
    else:
        # Pool a tensor of ones of the same shape as the inputs, with
        # zero padding and sum as the pooling function. This yields a
        # tensor of the same shape as `pooled` that counts the valid
        # (non-padded) elements in each window. Dividing `pooled` by
        # `window_counts` gives the average while excluding the padded
        # values.
        window_counts = _pool(
            mx.ones(inputs.shape, inputs.dtype),
            pool_size,
            strides,
            padding,
            0.0,
            data_format,
            mx.sum,
        )
        return pooled / window_counts
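
# Hypothetical sanity check (not part of the original commit): with
# "same" padding an all-ones input should pool to exactly 1.0
# everywhere, since the zero padding is excluded from the counts.
#   x = mx.ones((1, 3, 3, 1))
#   average_pool(x, pool_size=2, strides=2, padding="same")  # all 1.0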


def conv(
