From 3baadca2da893a075add1a3a00694e6204517ebc Mon Sep 17 00:00:00 2001
From: Andrej Jakovljevic
Date: Fri, 7 Feb 2025 10:49:16 +0000
Subject: [PATCH] Addressed comments

---
 tests/infra/__init__.py                     |  2 +-
 tests/infra/device_connector.py             |  1 +
 tests/infra/device_runner.py                | 76 ++++++++++++---------
 tests/infra/multichip_tester.py             | 45 ++++++------
 tests/infra/workload.py                     |  6 +-
 tests/jax/multichip/manual/all_gather.py    | 12 ++--
 tests/jax/multichip/manual/unary_eltwise.py | 10 +--
 7 files changed, 83 insertions(+), 69 deletions(-)

diff --git a/tests/infra/__init__.py b/tests/infra/__init__.py
index 4c9f31e2..156bc137 100644
--- a/tests/infra/__init__.py
+++ b/tests/infra/__init__.py
@@ -8,4 +8,4 @@
 from .model_tester import ModelTester, RunMode
 from .multichip_tester import run_multichip_test_with_random_inputs
 from .op_tester import run_op_test, run_op_test_with_random_inputs
-from .utils import random_tensor, supported_dtypes, make_partition_spec
+from .utils import random_tensor, make_partition_spec
diff --git a/tests/infra/device_connector.py b/tests/infra/device_connector.py
index 295936c9..3ad5c490 100644
--- a/tests/infra/device_connector.py
+++ b/tests/infra/device_connector.py
@@ -70,6 +70,7 @@ def is_initialized(self) -> bool:
         return False
 
     def get_tt_device_mesh(self, shape: tuple, axis_names: tuple) -> jax.sharding.Mesh:
+        """Returns a TT device mesh with the specified `shape` and `axis_names`."""
         tt_devices = jax.devices(DeviceType.TT.value)
         return jax.make_mesh(shape, axis_names, devices=tt_devices)
 
diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index 940fbfa7..8e0841b1 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -26,10 +26,10 @@ def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
     @staticmethod
     def run_on_multichip_device(multichip_workload: MultichipWorkload) -> Tensor:
         """Runs `workload` on a multichip device."""
-        sharded_workload = DeviceRunner._safely_put_sharding_on_workload(
+        sharded_workload = DeviceRunner._put_multichip_workload_on_device(
             multichip_workload
         )
-        return sharded_workload.execute().block_until_ready()
+        return sharded_workload.execute()
 
     @staticmethod
     def run_on_cpu(workload: Workload) -> Tensor:
@@ -72,18 +72,54 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
-    def _safely_put_sharding_on_workload(
+    def _run_on_device(
+        workload: Workload, device_type: DeviceType, device_num: int = 0
+    ) -> Tensor:
+        """Runs `workload` on device identified by `device_type`."""
+        device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
+        device = device_connector.connect_device(device_type, device_num)
+
+        with jax.default_device(device):
+            return device_workload.execute()
+
+    @staticmethod
+    def _put_sharded_tensor_on_multichip_device(
+        tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
+    ) -> Tensor:
+        """
+        Needed for multichip: uses `jax.device_put` to give the input tensor a sharding.
+        We just put a dummy sharding equal to None, since jax needs to have some notion of sharding
+        when running a graph with these buffers as inputs.
+        TODO: This can be omitted when we find a way to get sharding information from the StableHLO
+        code back to jax through a protobuf (issue #227).
+        """
+        none_tuple = (None,) * len(in_spec)
+        none_spec = PartitionSpec(*none_tuple)
+        return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)
+
+    @staticmethod
+    def _put_on_device(
+        workload: Workload, device_type: DeviceType, device_num: int = 0
+    ) -> Workload:
+        """Puts `workload` on device and returns it."""
+        device = device_connector.connect_device(device_type, device_num)
+        return DeviceRunner._safely_put_workload_on_device(workload, device)
+
+    @staticmethod
+    def _put_multichip_workload_on_device(
         multichip_workload: MultichipWorkload,
     ) -> MultichipWorkload:
         """Gives the workload inputs shardings, necessary for multichip workloads"""
         args_on_device = []
         spec_index = 0
+        # TODO: It might be necessary to put a try-except block here, but we are holding that off
+        # until we come across a case where it's needed.
         for arg in multichip_workload.args:
             if not isinstance(arg, Tensor):
                 args_on_device.append(arg)
             else:
                 args_on_device.append(
-                    DeviceRunner._put_tensor_none_sharding(
+                    DeviceRunner._put_sharded_tensor_on_multichip_device(
                         arg,
                         multichip_workload.mesh,
                         multichip_workload.in_specs[spec_index],
@@ -96,7 +132,9 @@ def _safely_put_sharding_on_workload(
             if not isinstance(value, Tensor):
                 kwargs_on_device[key] = value
             else:
-                kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
+                kwargs_on_device[
+                    key
+                ] = DeviceRunner._put_sharded_tensor_on_multichip_device(
                     value,
                     multichip_workload.mesh,
                     multichip_workload.in_specs[spec_index],
@@ -111,34 +149,6 @@ def _safely_put_sharding_on_workload(
             in_specs=multichip_workload.in_specs,
         )
 
-    @staticmethod
-    def _run_on_device(
-        workload: Workload, device_type: DeviceType, device_num: int = 0
-    ) -> Tensor:
-        """Runs `workload` on device identified by `device_type`."""
-        device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
-        device = device_connector.connect_device(device_type, device_num)
-
-        with jax.default_device(device):
-            return device_workload.execute()
-
-    @staticmethod
-    def _put_tensor_none_sharding(
-        tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
-    ) -> Tensor:
-        """Needed for multichip: Uses put_device to give inputs shardings."""
-        none_tuple = (None,) * len(in_spec)
-        none_spec = PartitionSpec(*none_tuple)
-        return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)
-
-    @staticmethod
-    def _put_on_device(
-        workload: Workload, device_type: DeviceType, device_num: int = 0
-    ) -> Workload:
-        """Puts `workload` on device and returns it."""
-        device = device_connector.connect_device(device_type, device_num)
-        return DeviceRunner._safely_put_workload_on_device(workload, device)
-
     @staticmethod
     def _put_tensors_on_device(
         device_type: DeviceType, tensors: Sequence[Tensor]
diff --git a/tests/infra/multichip_tester.py b/tests/infra/multichip_tester.py
index b8a9bf1a..cc4d5b29 100644
--- a/tests/infra/multichip_tester.py
+++ b/tests/infra/multichip_tester.py
@@ -25,22 +25,23 @@ class MultichipTester(BaseTester):
     and output sharding specifications.
 
     Attributes:
-        mesh (jax.Mesh): The device mesh over which the computation is distributed.
+        device_mesh (jax.Mesh): The device mesh over which the computation is distributed.
         in_specs (tuple): The sharding specifications for the input tensors.
        out_specs (jax.sharding.PartitionSpec): The sharding specification for the output tensor.
""" def __init__( self, - mesh: jax.Mesh, - in_specs: tuple, + in_specs: tuple[jax.sharding.PartitionSpec], out_specs: jax.sharding.PartitionSpec, + mesh_shape: tuple, + axis_names: tuple, comparison_config: ComparisonConfig = ComparisonConfig(), ) -> None: - self.mesh = mesh + super().__init__(comparison_config) self.in_specs = in_specs self.out_specs = out_specs - super().__init__(comparison_config) + self.device_mesh = device_connector.get_tt_device_mesh(mesh_shape, axis_names) def _compile_for_cpu( self, executable: Callable, static_argnames: Sequence[str] = None @@ -53,9 +54,12 @@ def _compile_for_device( ) -> Callable: """Sets up executable for just-in-time compile and execution on multichip device.""" module_sharded = shard_map( - executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs + executable, + mesh=self.device_mesh, + in_specs=self.in_specs, + out_specs=self.out_specs, ) - output_sharding = NamedSharding(self.mesh, self.out_specs) + output_sharding = NamedSharding(self.device_mesh, self.out_specs) return jax.jit( module_sharded, out_shardings=output_sharding, @@ -66,28 +70,26 @@ def test( self, multichip_workload: MultichipWorkload, cpu_workload: Workload ) -> None: """ - Runs test by running `workload` on TT device and CPU and comparing the results. + Runs test by running `workload` on TT device and 'cpu_workload' on the CPU and comparing the results. """ - multichip_compiled_workload = MultichipWorkload( + compiled_device_workload = MultichipWorkload( self._compile_for_device(multichip_workload.executable), multichip_workload.args, multichip_workload.kwargs, - mesh=self.mesh, + device_mesh=self.device_mesh, in_specs=self.in_specs, ) - cpu_compiled_workload = Workload( + compiled_cpu_workload = Workload( self._compile_for_cpu(cpu_workload.executable), cpu_workload.args, cpu_workload.kwargs, ) - tt_multichip_res = DeviceRunner.run_on_multichip_device( - multichip_compiled_workload - ) - cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload) + device_res = DeviceRunner.run_on_multichip_device(compiled_device_workload) + cpu_res = DeviceRunner.run_on_cpu(compiled_cpu_workload) - self._compare(tt_multichip_res, cpu_res) + self._compare(device_res, cpu_res) def test_with_random_inputs( self, @@ -107,11 +109,11 @@ def test_with_random_inputs( ) for shape in input_shapes ] - multichip_workload = MultichipWorkload( - device_executable, inputs, mesh=self.mesh, in_specs=self.in_specs + device_workload = MultichipWorkload( + device_executable, inputs, mesh=self.device_mesh, in_specs=self.in_specs ) cpu_workload = Workload(cpu_executable, inputs) - self.test(multichip_workload, cpu_workload) + self.test(device_workload, cpu_workload) def run_multichip_test_with_random_inputs( @@ -130,8 +132,9 @@ def run_multichip_test_with_random_inputs( Tests an input executable with random inputs in range [`minval`, `maxval`) by running it on a mesh of TT devices and comparing it to output of the cpu executable ran with the same input. 
""" - mesh = device_connector.get_tt_device_mesh(mesh_shape, axis_names) - tester = MultichipTester(mesh, in_specs, out_specs, comparison_config) + tester = MultichipTester( + in_specs, out_specs, mesh_shape, axis_names, comparison_config + ) tester.test_with_random_inputs( device_executable, cpu_executable, input_shapes, minval, maxval ) diff --git a/tests/infra/workload.py b/tests/infra/workload.py index a71d7b15..38257d7d 100644 --- a/tests/infra/workload.py +++ b/tests/infra/workload.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -import jax +from jax.sharding import Mesh, PartitionSpec from typing import Any, Callable, Mapping, Optional, Sequence @@ -36,5 +36,5 @@ class MultichipWorkload(Workload): An extension of the Workload dataclass that includes a mesh and partition specs, necessary for multichip sharding. """ - mesh: jax.sharding.Mesh = None - in_specs: Sequence[jax.sharding.PartitionSpec] = None + device_mesh: Mesh = None + in_specs: Sequence[PartitionSpec] = None diff --git a/tests/jax/multichip/manual/all_gather.py b/tests/jax/multichip/manual/all_gather.py index ad63d2d4..55eb5b2f 100644 --- a/tests/jax/multichip/manual/all_gather.py +++ b/tests/jax/multichip/manual/all_gather.py @@ -10,19 +10,19 @@ from tests.utils import make_partition_spec -@pytest.mark.parametrize("x_shape", [(8192, 784)]) +@pytest.mark.parametrize(("x_shape", "axis_names"), [((8192, 784), ("batch",))]) @pytest.mark.skip(reason=compile_fail("Multichip still in development")) -def test_all_gather(x_shape: tuple): +def test_all_gather(x_shape: tuple, axis_names: tuple): def fwd(batch): - act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True) + act = jax.lax.all_gather(batch, axis_names, axis=0, tiled=True) return act def golden_fwd(batch): return jnp.tile(batch, (2, 1)) - in_specs = (make_partition_spec(("batch")),) - out_specs = make_partition_spec(("batch")) + in_specs = (make_partition_spec(axis_names),) + out_specs = make_partition_spec(axis_names) run_multichip_test_with_random_inputs( - fwd, golden_fwd, [x_shape], (2,), ("batch"), in_specs, out_specs + fwd, golden_fwd, [x_shape], (2,), axis_names, in_specs, out_specs ) diff --git a/tests/jax/multichip/manual/unary_eltwise.py b/tests/jax/multichip/manual/unary_eltwise.py index 64995a34..5da81970 100644 --- a/tests/jax/multichip/manual/unary_eltwise.py +++ b/tests/jax/multichip/manual/unary_eltwise.py @@ -10,12 +10,12 @@ from tests.utils import make_partition_spec -@pytest.mark.parametrize("x_shape", [(256, 256)]) +@pytest.mark.parametrize(("x_shape", "axis_names"), [((256, 256), ("x", "y"))]) @pytest.mark.skip(reason=compile_fail("Multichip still in development")) -def test_unary_eltwise(x_shape: tuple): +def test_unary_eltwise(x_shape: tuple, axis_names: tuple): def fwd(a_block): b_block = jnp.negative(a_block) - stitched_result = jax.lax.psum(b_block, ("x", "y")) + stitched_result = jax.lax.psum(b_block, axis_names) return stitched_result def fwd_single_device(a_block): @@ -26,9 +26,9 @@ def fwd_single_device(a_block): stitched_result = b1 + b2 return stitched_result - in_specs = (make_partition_spec((("x", "y"))),) + in_specs = (make_partition_spec(axis_names),) out_specs = make_partition_spec((None, None)) run_multichip_test_with_random_inputs( - fwd, fwd_single_device, [x_shape], (1, 2), ("x", "y"), in_specs, out_specs + fwd, fwd_single_device, [x_shape], (1, 2), axis_names, in_specs, out_specs )