diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index 47b06ba4..38b46f5c 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -2,19 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import inspect
+import jax
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from typing import Callable, Sequence
 
-import jax
 
 from .device_connector import DeviceType, device_connector
+from .multichip_workload import MultichipWorkload
 from .types import Tensor
 from .workload import Workload
-from jax.sharding import Mesh, PartitionSpec, NamedSharding
-
-import inspect
-
-
 
 class DeviceRunner:
     """
     Class providing methods to put and run workload on any supported device.
@@ -71,36 +69,26 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
         raise NotImplementedError("Support for GPUs not implemented")
 
     @staticmethod
-    def put_with_none_sharding(
-        workload: Workload,
-        mesh: jax.sharding.Mesh,
-        in_specs: Sequence[jax.sharding.PartitionSpec],
-    ) -> Tensor:
+    def put_with_none_sharding(multichip_workload: MultichipWorkload) -> MultichipWorkload:
         """Gives inputs shardings for multichip workloads"""
         args_on_device = []
         spec_index = 0
-        for arg in workload.args:
+        for arg in multichip_workload.args:
             if not isinstance(arg, Tensor):
                 args_on_device.append(arg)
             else:
-                args_on_device.append(
-                    DeviceRunner._put_tensor_none_sharding(
-                        arg, mesh, in_specs[spec_index]
-                    )
-                )
-                spec_index += 1
+                args_on_device.append(DeviceRunner._put_tensor_none_sharding(arg, multichip_workload.mesh, multichip_workload.in_specs[spec_index]))
+                spec_index+=1
 
         kwargs_on_device = {}
-        for key, value in workload.kwargs.items():
+        for key, value in multichip_workload.kwargs.items():
             if not isinstance(value, Tensor):
                 kwargs_on_device[key] = value
             else:
-                kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
-                    value, mesh, in_specs[spec_index]
-                )
-                spec_index += 1
+                kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(value, multichip_workload.mesh, multichip_workload.in_specs[spec_index])
+                spec_index+=1
 
-        return Workload(workload.executable, args_on_device, kwargs_on_device)
+        return MultichipWorkload(multichip_workload.executable, args_on_device, kwargs_on_device, mesh = multichip_workload.mesh, in_specs = multichip_workload.in_specs)
 
     @staticmethod
     def _run_manual(workload: Workload) -> Tensor:
@@ -119,9 +107,7 @@ def _run_on_device(
         return device_workload.execute()
 
     @staticmethod
-    def _put_tensor_none_sharding(
-        tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
-    ) -> Tensor:
+    def _put_tensor_none_sharding(tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec) -> Tensor:
         """Needed for multichip: Uses put_device to give inputs shardings."""
         none_tuple = (None,) * len(in_spec)
         none_spec = PartitionSpec(*none_tuple)
diff --git a/tests/infra/multichip_tester.py b/tests/infra/multichip_tester.py
index fe2775a2..dc546ded 100644
--- a/tests/infra/multichip_tester.py
+++ b/tests/infra/multichip_tester.py
@@ -5,13 +5,14 @@
 from __future__ import annotations
 
 import jax
+from jax.sharding import NamedSharding, PartitionSpec
 from jax.experimental.shard_map import shard_map
-from jax.sharding import NamedSharding
 from typing import Callable, Sequence
 
 from .base_tester import BaseTester
 from .comparison import ComparisonConfig
 from .device_runner import DeviceRunner
+from .multichip_workload import MultichipWorkload
 from .workload import Workload
 
 
@@ -19,11 +20,7 @@ class MultichipTester(BaseTester):
     """Specific tester for ops."""
 
     def __init__(
-        self,
-        mesh: jax.Mesh,
-        in_specs: tuple,
-        out_specs: jax.sharding.PartitionSpec,
-        comparison_config: ComparisonConfig = ComparisonConfig(),
+        self, mesh: jax.Mesh, in_specs: tuple, out_specs: jax.sharding.PartitionSpec, comparison_config: ComparisonConfig = ComparisonConfig()
     ) -> None:
         self.mesh = mesh
         self.in_specs = in_specs
@@ -41,14 +38,13 @@ def _compile(
     ) -> Callable:
         """Sets up `executable` for just-in-time compile."""
         module_sharded = shard_map(
-            executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
+            executable,
+            mesh=self.mesh,
+            in_specs=self.in_specs,
+            out_specs=self.out_specs
         )
         output_sharding = NamedSharding(self.mesh, self.out_specs)
-        return jax.jit(
-            module_sharded,
-            out_shardings=output_sharding,
-            static_argnames=static_argnames,
-        )
+        return jax.jit(module_sharded, out_shardings=output_sharding, static_argnames=static_argnames)
 
     def test(self, workload: Workload, cpu_workload: Workload) -> None:
         """
@@ -61,13 +57,11 @@ def test(self, workload: Workload, cpu_workload: Workload) -> None:
             cpu_compiled_executable, cpu_workload.args, cpu_workload.kwargs
         )
 
-        compiled_workload = Workload(
-            compiled_executable, workload.args, workload.kwargs
+        compiled_workload = MultichipWorkload(
+            compiled_executable, workload.args, workload.kwargs, mesh = self.mesh, in_specs=self.in_specs
         )
-        non_sharded_workload = DeviceRunner.put_with_none_sharding(
-            compiled_workload, self.mesh, in_specs=self.in_specs
-        )
+        non_sharded_workload = DeviceRunner.put_with_none_sharding(compiled_workload)
 
         tt_res = DeviceRunner.run_manual(non_sharded_workload)
         cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload)
 
@@ -87,10 +81,7 @@ def test_with_random_inputs(
         TT device and CPU and comparing the results.
         """
         inputs = [
-            jax.random.uniform(
-                key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval
-            )
-            for shape in input_shapes
+            jax.random.uniform(key = jax.random.key(0), shape = shape, minval=minval, maxval=maxval) for shape in input_shapes
         ]
         workload = Workload(f, inputs)
         cpu_workload = Workload(golden_f, inputs)
@@ -101,16 +92,16 @@ def run_multichip_test_with_random_inputs(
     mesh_test: Callable,
     golden_test: Callable,
     input_shapes: Sequence[tuple],
-    mesh: jax.Mesh,
-    in_specs: tuple,
+    mesh: jax.Mesh, 
+    in_specs: Sequence[jax.sharding.PartitionSpec],
     out_specs: jax.sharding.PartitionSpec,
     minval: float = 0.0,
     maxval: float = 1.0,
     comparison_config: ComparisonConfig = ComparisonConfig(),
 ) -> None:
     """
-    Tests `op` with random inputs in range [`minval`, `maxval`) by running it on
+    Tests `mesh_test` with random inputs in range [`minval`, `maxval`) by running it on
     TT device and CPU and comparing the results based on `comparison_config`.
""" tester = MultichipTester(mesh, in_specs, out_specs, comparison_config) - tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval) + tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval) \ No newline at end of file diff --git a/tests/infra/multichip_workload.py b/tests/infra/multichip_workload.py new file mode 100644 index 00000000..94268e63 --- /dev/null +++ b/tests/infra/multichip_workload.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +import jax +from typing import Sequence + +from .workload import Workload + + +@dataclass +class MultichipWorkload(Workload): + """ + Convenience dataclass storing a callable and its positional and keyword arguments. + """ + mesh: jax.sharding.Mesh = None + in_specs: Sequence[jax.sharding.PartitionSpec] = None + diff --git a/tests/jax/multichip/manual/all_gather.py b/tests/jax/multichip/manual/all_gather.py index bb2ccc41..ea937241 100644 --- a/tests/jax/multichip/manual/all_gather.py +++ b/tests/jax/multichip/manual/all_gather.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize("x_shape", [(8192, 784)]) @pytest.mark.skip(reason="Compilation fails") -def all_gather_test(x_shape: tuple): +def test_all_gather(x_shape: tuple): def fwd(batch): act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True) return act diff --git a/tests/jax/multichip/manual/unary_eltwise.py b/tests/jax/multichip/manual/unary_eltwise.py index f1903737..2d725485 100644 --- a/tests/jax/multichip/manual/unary_eltwise.py +++ b/tests/jax/multichip/manual/unary_eltwise.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize("x_shape", [(256, 256)]) @pytest.mark.skip(reason="Multichip still in development") -def unary_eltwise_test(x_shape: tuple): +def test_unary_eltwise(x_shape: tuple): def fwd(a_block): b_block = jnp.negative(a_block) stitched_result = jax.lax.psum(b_block, ("x", "y"))