Addressed comments
ajakovljevicTT committed Feb 25, 2025
1 parent cd3f367 commit 3baadca
Showing 7 changed files with 83 additions and 69 deletions.
2 changes: 1 addition & 1 deletion tests/infra/__init__.py
@@ -8,4 +8,4 @@
from .model_tester import ModelTester, RunMode
from .multichip_tester import run_multichip_test_with_random_inputs
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor, supported_dtypes, make_partition_spec
from .utils import random_tensor, make_partition_spec
1 change: 1 addition & 0 deletions tests/infra/device_connector.py
@@ -70,6 +70,7 @@ def is_initialized(self) -> bool:
return False

def get_tt_device_mesh(self, shape: tuple, axis_names: tuple) -> jax.sharding.Mesh:
"""Returns TTDevice mesh with specified `shape` and `axis_names`."""
tt_devices = jax.devices(DeviceType.TT.value)
return jax.make_mesh(shape, axis_names, devices=tt_devices)

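
For reference, a minimal runnable sketch of what such a mesh looks like, using the CPU backend and an assumed axis name "batch" as stand-ins for the TT devices handled by `get_tt_device_mesh`:

```python
import jax

# Illustration only: CPU devices stand in for jax.devices(DeviceType.TT.value),
# and the single axis name "batch" is an assumption, not taken from this commit.
devices = jax.devices("cpu")
mesh = jax.make_mesh((len(devices),), ("batch",), devices=devices)

print(mesh.devices.shape)  # e.g. (1,) on a single-CPU host
print(mesh.axis_names)     # ('batch',)
```
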
76 changes: 43 additions & 33 deletions tests/infra/device_runner.py
@@ -26,10 +26,10 @@ def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
@staticmethod
def run_on_multichip_device(multichip_workload: MultichipWorkload) -> Tensor:
"""Runs `workload` on a multichip device."""
sharded_workload = DeviceRunner._safely_put_sharding_on_workload(
sharded_workload = DeviceRunner._put_multichip_workload_on_device(
multichip_workload
)
return sharded_workload.execute().block_until_ready()
return sharded_workload.execute()

@staticmethod
def run_on_cpu(workload: Workload) -> Tensor:
@@ -72,18 +72,54 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def _safely_put_sharding_on_workload(
def _run_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
device = device_connector.connect_device(device_type, device_num)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_sharded_tensor_on_multichip_device(
tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
) -> Tensor:
"""
Needed for multichip: uses `jax.device_put` to attach shardings to the inputs.
We attach a dummy sharding of all `None`s, since JAX needs some notion of sharding
when running a graph with these buffers as inputs.
TODO: This can be omitted once we find a way to get sharding information from the StableHLO
code back to JAX through a protobuf (issue #227).
"""
none_tuple = (None,) * len(in_spec)
none_spec = PartitionSpec(*none_tuple)
return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)

@staticmethod
def _put_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Workload:
"""Puts `workload` on device and returns it."""
device = device_connector.connect_device(device_type, device_num)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
def _put_multichip_workload_on_device(
multichip_workload: MultichipWorkload,
) -> MultichipWorkload:
"""Gives the workload inputs shardings, necessary for multichip workloads"""
args_on_device = []
spec_index = 0
# TODO: It might be necessary to put a try-except block here, but we are holding that
# off until we come across a case where it's needed.
for arg in multichip_workload.args:
if not isinstance(arg, Tensor):
args_on_device.append(arg)
else:
args_on_device.append(
DeviceRunner._put_tensor_none_sharding(
DeviceRunner._put_sharded_tensor_on_multichip_device(
arg,
multichip_workload.mesh,
multichip_workload.in_specs[spec_index],
@@ -96,7 +132,9 @@ def _safely_put_sharding_on_workload(
if not isinstance(value, Tensor):
kwargs_on_device[key] = value
else:
kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
kwargs_on_device[
key
] = DeviceRunner._put_sharded_tensor_on_multichip_device(
value,
multichip_workload.mesh,
multichip_workload.in_specs[spec_index],
@@ -111,34 +149,6 @@ def _run_on_device(
in_specs=multichip_workload.in_specs,
)

@staticmethod
def _run_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
device = device_connector.connect_device(device_type, device_num)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_tensor_none_sharding(
tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
) -> Tensor:
"""Needed for multichip: Uses put_device to give inputs shardings."""
none_tuple = (None,) * len(in_spec)
none_spec = PartitionSpec(*none_tuple)
return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)

@staticmethod
def _put_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Workload:
"""Puts `workload` on device and returns it."""
device = device_connector.connect_device(device_type, device_num)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
def _put_tensors_on_device(
device_type: DeviceType, tensors: Sequence[Tensor]
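
For context, a minimal sketch of the all-`None` sharding trick that `_put_sharded_tensor_on_multichip_device` applies to each input tensor, shown on a CPU mesh; the mesh shape, axis name, and tensor shape below are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# CPU devices stand in for a TT mesh here.
devices = jax.devices("cpu")
mesh = jax.make_mesh((len(devices),), ("batch",), devices=devices)

# The real in_spec comes from the workload; this one is assumed.
in_spec = PartitionSpec("batch")
# Replace every entry of the spec with None, i.e. no partitioning at all.
none_spec = PartitionSpec(*((None,) * len(in_spec)))

x = jnp.ones((8, 4))
# The tensor now carries a (trivial) NamedSharding tied to the mesh, which gives
# JAX the notion of sharding it needs when the buffer feeds a sharded graph.
x_on_mesh = jax.device_put(x, NamedSharding(mesh, none_spec))
print(x_on_mesh.sharding)
```
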
45 changes: 24 additions & 21 deletions tests/infra/multichip_tester.py
@@ -25,22 +25,23 @@ class MultichipTester(BaseTester):
and output sharding specifications.
Attributes:
mesh (jax.Mesh): The device mesh over which the computation is distributed.
device_mesh (jax.Mesh): The device mesh over which the computation is distributed.
in_specs (tuple): The sharding specifications for the input tensors.
out_specs (jax.sharding.PartitionSpec): The sharding specification for the output tensor.
"""

def __init__(
self,
mesh: jax.Mesh,
in_specs: tuple,
in_specs: tuple[jax.sharding.PartitionSpec],
out_specs: jax.sharding.PartitionSpec,
mesh_shape: tuple,
axis_names: tuple,
comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
self.mesh = mesh
super().__init__(comparison_config)
self.in_specs = in_specs
self.out_specs = out_specs
super().__init__(comparison_config)
self.device_mesh = device_connector.get_tt_device_mesh(mesh_shape, axis_names)

def _compile_for_cpu(
self, executable: Callable, static_argnames: Sequence[str] = None
@@ -53,9 +54,12 @@ def _compile_for_device(
) -> Callable:
"""Sets up executable for just-in-time compile and execution on multichip device."""
module_sharded = shard_map(
executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
executable,
mesh=self.device_mesh,
in_specs=self.in_specs,
out_specs=self.out_specs,
)
output_sharding = NamedSharding(self.mesh, self.out_specs)
output_sharding = NamedSharding(self.device_mesh, self.out_specs)
return jax.jit(
module_sharded,
out_shardings=output_sharding,
@@ -66,28 +70,26 @@ def test(
self, multichip_workload: MultichipWorkload, cpu_workload: Workload
) -> None:
"""
Runs test by running `workload` on TT device and CPU and comparing the results.
Runs the test by running `multichip_workload` on a TT device and `cpu_workload` on the CPU, then comparing the results.
"""
multichip_compiled_workload = MultichipWorkload(
compiled_device_workload = MultichipWorkload(
self._compile_for_device(multichip_workload.executable),
multichip_workload.args,
multichip_workload.kwargs,
mesh=self.mesh,
device_mesh=self.device_mesh,
in_specs=self.in_specs,
)

cpu_compiled_workload = Workload(
compiled_cpu_workload = Workload(
self._compile_for_cpu(cpu_workload.executable),
cpu_workload.args,
cpu_workload.kwargs,
)

tt_multichip_res = DeviceRunner.run_on_multichip_device(
multichip_compiled_workload
)
cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload)
device_res = DeviceRunner.run_on_multichip_device(compiled_device_workload)
cpu_res = DeviceRunner.run_on_cpu(compiled_cpu_workload)

self._compare(tt_multichip_res, cpu_res)
self._compare(device_res, cpu_res)

def test_with_random_inputs(
self,
@@ -107,11 +109,11 @@ def test_with_random_inputs(
)
for shape in input_shapes
]
multichip_workload = MultichipWorkload(
device_executable, inputs, mesh=self.mesh, in_specs=self.in_specs
device_workload = MultichipWorkload(
device_executable, inputs, mesh=self.device_mesh, in_specs=self.in_specs
)
cpu_workload = Workload(cpu_executable, inputs)
self.test(multichip_workload, cpu_workload)
self.test(device_workload, cpu_workload)


def run_multichip_test_with_random_inputs(
@@ -130,8 +132,9 @@ def run_multichip_test_with_random_inputs(
Tests an input executable with random inputs in range [`minval`, `maxval`) by running it on a mesh of
TT devices and comparing its output to that of the CPU executable run with the same inputs.
"""
mesh = device_connector.get_tt_device_mesh(mesh_shape, axis_names)
tester = MultichipTester(mesh, in_specs, out_specs, comparison_config)
tester = MultichipTester(
in_specs, out_specs, mesh_shape, axis_names, comparison_config
)
tester.test_with_random_inputs(
device_executable, cpu_executable, input_shapes, minval, maxval
)
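
For context, a CPU-only sketch of the `shard_map` plus `jax.jit(out_shardings=...)` pattern that `_compile_for_device` wraps around the executable; the mesh, axis name, partition specs, and negation body below are placeholders rather than code from this commit:

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding, PartitionSpec as P

devices = jax.devices("cpu")
mesh = jax.make_mesh((len(devices),), ("batch",), devices=devices)
in_specs = (P("batch", None),)
out_specs = P("batch", None)

def executable(block):
    # Placeholder per-shard computation.
    return jnp.negative(block)

# Shard the function over the mesh, then JIT with an explicit output sharding.
module_sharded = shard_map(
    executable, mesh=mesh, in_specs=in_specs, out_specs=out_specs
)
compiled = jax.jit(module_sharded, out_shardings=NamedSharding(mesh, out_specs))

result = compiled(jnp.ones((8, 4)))  # split along "batch", negated shard by shard
```
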
6 changes: 3 additions & 3 deletions tests/infra/workload.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
import jax
from jax.sharding import Mesh, PartitionSpec
from typing import Any, Callable, Mapping, Optional, Sequence


@@ -36,5 +36,5 @@ class MultichipWorkload(Workload):
An extension of the Workload dataclass that includes a mesh and partition specs, necessary for multichip sharding.
"""

mesh: jax.sharding.Mesh = None
in_specs: Sequence[jax.sharding.PartitionSpec] = None
device_mesh: Mesh = None
in_specs: Sequence[PartitionSpec] = None
12 changes: 6 additions & 6 deletions tests/jax/multichip/manual/all_gather.py
@@ -10,19 +10,19 @@
from tests.utils import make_partition_spec


@pytest.mark.parametrize("x_shape", [(8192, 784)])
@pytest.mark.parametrize(("x_shape", "axis_names"), [((8192, 784), ("batch",))])
@pytest.mark.skip(reason=compile_fail("Multichip still in development"))
def test_all_gather(x_shape: tuple):
def test_all_gather(x_shape: tuple, axis_names: tuple):
def fwd(batch):
act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True)
act = jax.lax.all_gather(batch, axis_names, axis=0, tiled=True)
return act

def golden_fwd(batch):
return jnp.tile(batch, (2, 1))

in_specs = (make_partition_spec(("batch")),)
out_specs = make_partition_spec(("batch"))
in_specs = (make_partition_spec(axis_names),)
out_specs = make_partition_spec(axis_names)

run_multichip_test_with_random_inputs(
fwd, golden_fwd, [x_shape], (2,), ("batch"), in_specs, out_specs
fwd, golden_fwd, [x_shape], (2,), axis_names, in_specs, out_specs
)
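
To see what the test is checking, here is a hedged CPU-only sketch of `jax.lax.all_gather` with `tiled=True` under `shard_map`; the mesh, shapes, and XLA flag are assumptions (multi-device CPU needs e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=2`):

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

devices = jax.devices("cpu")
mesh = jax.make_mesh((len(devices),), ("batch",), devices=devices)

def fwd(block):
    # tiled=True concatenates the gathered shards along axis 0 instead of
    # stacking them under a new leading device axis.
    return jax.lax.all_gather(block, "batch", axis=0, tiled=True)

x = jnp.arange(8.0).reshape(4, 2)
out = shard_map(
    fwd, mesh=mesh, in_specs=(P("batch", None),), out_specs=P("batch", None)
)(x)
# Each device returns the full (4, 2) array; re-tiling along "batch" yields a
# (4 * n_devices, 2) result, matching jnp.tile(x, (n_devices, 1)).
print(out.shape)
```
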
10 changes: 5 additions & 5 deletions tests/jax/multichip/manual/unary_eltwise.py
@@ -10,12 +10,12 @@
from tests.utils import make_partition_spec


@pytest.mark.parametrize("x_shape", [(256, 256)])
@pytest.mark.parametrize(("x_shape", "axis_names"), [((256, 256), ("x", "y"))])
@pytest.mark.skip(reason=compile_fail("Multichip still in development"))
def test_unary_eltwise(x_shape: tuple):
def test_unary_eltwise(x_shape: tuple, axis_names: tuple):
def fwd(a_block):
b_block = jnp.negative(a_block)
stitched_result = jax.lax.psum(b_block, ("x", "y"))
stitched_result = jax.lax.psum(b_block, axis_names)
return stitched_result

def fwd_single_device(a_block):
@@ -26,9 +26,9 @@ def fwd_single_device(a_block):
stitched_result = b1 + b2
return stitched_result

in_specs = (make_partition_spec((("x", "y"))),)
in_specs = (make_partition_spec(axis_names),)
out_specs = make_partition_spec((None, None))

run_multichip_test_with_random_inputs(
fwd, fwd_single_device, [x_shape], (1, 2), ("x", "y"), in_specs, out_specs
fwd, fwd_single_device, [x_shape], (1, 2), axis_names, in_specs, out_specs
)
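
Similarly, a hedged sketch of the `psum` pattern above on a (1, 2) CPU mesh, showing why the single-device golden splits the input, negates, and adds; the shapes, the split axis in the reference, and the XLA flag are assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

# Needs two host devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=2
devices = jax.devices("cpu")[:2]
mesh = jax.make_mesh((1, 2), ("x", "y"), devices=devices)

def fwd(a_block):
    # Each device holds a (256, 128) block; psum over both mesh axes adds the
    # two negated blocks elementwise and replicates the (256, 128) result.
    return jax.lax.psum(jnp.negative(a_block), ("x", "y"))

x = jax.random.uniform(jax.random.PRNGKey(0), (256, 256))
out = shard_map(fwd, mesh=mesh, in_specs=(P("x", "y"),), out_specs=P(None, None))(x)

# Single-device equivalent: split along the "y"-sharded axis, negate, and add.
ref = jnp.negative(x[:, :128]) + jnp.negative(x[:, 128:])
assert jnp.allclose(out, ref)
```
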
