Added multichip infra

ajakovljevicTT committed Feb 5, 2025
1 parent 83c3808 commit 0086d31
Showing 6 changed files with 273 additions and 3 deletions.
1 change: 1 addition & 0 deletions tests/infra/__init__.py
@@ -8,3 +8,4 @@
from .model_tester import ModelTester, RunMode
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor, supported_dtypes
from .multichip_tester import run_multichip_test_with_random_inputs
66 changes: 63 additions & 3 deletions tests/infra/device_runner.py
@@ -2,22 +2,28 @@
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Callable, Sequence

import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from .device_connector import DeviceType, device_connector
from .multichip_workload import MultichipWorkload
from .types import Tensor
from .workload import Workload


class DeviceRunner:
"""
Class providing methods to put and run workloads on any supported device.
"""

@staticmethod
def run_manual(workload: Workload) -> Tensor:
"""Runs `workload` on TT device."""
return DeviceRunner._run_manual(workload)

@staticmethod
def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
"""Runs `workload` on TT device."""
@@ -63,6 +69,51 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
"""Puts `tensors` on GPU."""
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def put_with_none_sharding(
multichip_workload: MultichipWorkload,
) -> MultichipWorkload:
"""Gives inputs shardings for multichip workloads"""
args_on_device = []
spec_index = 0
for arg in multichip_workload.args:
if not isinstance(arg, Tensor):
args_on_device.append(arg)
else:
args_on_device.append(
DeviceRunner._put_tensor_none_sharding(
arg,
multichip_workload.mesh,
multichip_workload.in_specs[spec_index],
)
)
spec_index += 1

kwargs_on_device = {}
for key, value in multichip_workload.kwargs.items():
if not isinstance(value, Tensor):
kwargs_on_device[key] = value
else:
kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
value,
multichip_workload.mesh,
multichip_workload.in_specs[spec_index],
)
spec_index += 1

return MultichipWorkload(
multichip_workload.executable,
args_on_device,
kwargs_on_device,
mesh=multichip_workload.mesh,
in_specs=multichip_workload.in_specs,
)

@staticmethod
def _run_manual(workload: Workload) -> Tensor:
"""Runs `workload` on a device."""
return workload.execute().block_until_ready()

@staticmethod
def _run_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
@@ -74,6 +125,15 @@ def _run_on_device(
with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_tensor_none_sharding(
tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
) -> Tensor:
"""Needed for multichip: Uses put_device to give inputs shardings."""
none_tuple = (None,) * len(in_spec)
none_spec = PartitionSpec(*none_tuple)
return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)

@staticmethod
def _put_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
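For context on `put_with_none_sharding` and `_put_tensor_none_sharding`: a `PartitionSpec` whose entries are all `None` partitions no axis, so `jax.device_put` with such a spec places a full copy of the tensor on every device in the mesh; the manually sharded executable built in multichip_tester.py then consumes these replicated inputs. A minimal standalone sketch of that behavior (the mesh and shapes below are illustrative, not taken from this commit):

# Sketch: an all-None PartitionSpec replicates a tensor across the whole mesh.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ("batch",))  # whatever devices are available
x = jnp.ones((8, 4))

none_spec = PartitionSpec(None, None)  # one None per tensor dimension
replicated = jax.device_put(x, NamedSharding(mesh, none_spec))
print(replicated.sharding.is_fully_replicated)  # True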
119 changes: 119 additions & 0 deletions tests/infra/multichip_tester.py
@@ -0,0 +1,119 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import jax
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.shard_map import shard_map
from typing import Callable, Sequence

from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .multichip_workload import MultichipWorkload
from .workload import Workload


class MultichipTester(BaseTester):
"""Specific tester for ops."""

def __init__(
self,
mesh: jax.Mesh,
in_specs: tuple,
out_specs: jax.sharding.PartitionSpec,
comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
self.mesh = mesh
self.in_specs = in_specs
self.out_specs = out_specs
super().__init__(comparison_config)

def _compile_cpu(
self, executable: Callable, static_argnames: Sequence[str] = None
) -> Callable:
"""Sets up `executable` for just-in-time compile - specifically for CPU."""
return jax.jit(executable, static_argnames=static_argnames)

def _compile(
self, executable: Callable, static_argnames: Sequence[str] = None
) -> Callable:
"""Sets up `executable` for just-in-time compile."""
module_sharded = shard_map(
executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
)
output_sharding = NamedSharding(self.mesh, self.out_specs)
return jax.jit(
module_sharded,
out_shardings=output_sharding,
static_argnames=static_argnames,
)

def test(self, workload: Workload, cpu_workload: Workload) -> None:
"""
Runs the test by executing `workload` on TT device and `cpu_workload` on CPU,
then comparing the results.
"""
compiled_executable = self._compile(workload.executable)
cpu_compiled_executable = self._compile_cpu(cpu_workload.executable)

cpu_compiled_workload = Workload(
cpu_compiled_executable, cpu_workload.args, cpu_workload.kwargs
)

compiled_workload = MultichipWorkload(
compiled_executable,
workload.args,
workload.kwargs,
mesh=self.mesh,
in_specs=self.in_specs,
)

non_sharded_workload = DeviceRunner.put_with_none_sharding(compiled_workload)

tt_res = DeviceRunner.run_manual(non_sharded_workload)
cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload)

self._compare(tt_res, cpu_res)

def test_with_random_inputs(
self,
f: Callable,
golden_f: Callable,
input_shapes: Sequence[tuple],
minval: float = 0.0,
maxval: float = 1.0,
) -> None:
"""
Tests `f` by running it with random inputs in range [`minval`, `maxval`) on
TT device and comparing the results against `golden_f` run on CPU.
"""
inputs = [
jax.random.uniform(
key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval
)
for shape in input_shapes
]
workload = Workload(f, inputs)
cpu_workload = Workload(golden_f, inputs)
self.test(workload, cpu_workload)


def run_multichip_test_with_random_inputs(
mesh_test: Callable,
golden_test: Callable,
input_shapes: Sequence[tuple],
mesh: jax.Mesh,
in_specs: Sequence[jax.sharding.PartitionSpec],
out_specs: jax.sharding.PartitionSpec,
minval: float = 0.0,
maxval: float = 1.0,
comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
"""
Tests `mesh_test` with random inputs in range [`minval`, `maxval`) by running it on
TT device and comparing the results against `golden_test` run on CPU, based on
`comparison_config`.
"""
tester = MultichipTester(mesh, in_specs, out_specs, comparison_config)
tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval)
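For reference, `_compile` composes `shard_map` (each device runs the wrapped function on its local block, as described by `in_specs`) with `jit`, whose `out_shardings` pins the layout of the result. A minimal standalone sketch of that composition; the mesh and function below are illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ("batch",))  # illustrative single-axis mesh
in_specs = (PartitionSpec("batch"),)
out_specs = PartitionSpec("batch")

def per_device_negate(block):
    # Runs once per device, on that device's shard of the input.
    return jnp.negative(block)

sharded = shard_map(per_device_negate, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
compiled = jax.jit(sharded, out_shardings=NamedSharding(mesh, out_specs))

x = jnp.arange(8.0 * len(jax.devices()))  # length divisible by the mesh size
print(compiled(x))  # element-wise negation of x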
19 changes: 19 additions & 0 deletions tests/infra/multichip_workload.py
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
import jax
from typing import Sequence

from .workload import Workload


@dataclass
class MultichipWorkload(Workload):
"""
Convenience dataclass extending `Workload` with the mesh and input partition specs needed for multichip execution.
"""

mesh: jax.sharding.Mesh = None
in_specs: Sequence[jax.sharding.PartitionSpec] = None
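As a usage sketch, a `MultichipWorkload` is constructed the same way `MultichipTester.test` does above: the executable, args and kwargs of the base `Workload` positionally, plus the mesh and input specs as keywords. The names below (`compiled_fn`, `mesh`, `x`) are placeholders, not identifiers from this commit:

workload = MultichipWorkload(
    compiled_fn,  # hypothetical jitted, shard_mapped executable
    [x],          # positional tensor inputs
    {},           # keyword inputs
    mesh=mesh,
    in_specs=(jax.sharding.PartitionSpec("batch"),),
)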
33 changes: 33 additions & 0 deletions tests/jax/multichip/manual/all_gather.py
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import jax
import jax.numpy as jnp
from jax import jit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec
from functools import partial
from infra import run_multichip_test_with_random_inputs
import pytest


@pytest.mark.parametrize("x_shape", [(8192, 784)])
@pytest.mark.skip(reason="Compilation fails")
def test_all_gather(x_shape: tuple):
def fwd(batch):
# Each device starts with a (4096, 784) shard of the (8192, 784) input; the
# tiled all-gather concatenates the shards so every device holds the full array.
act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True)
return act

def golden_fwd(batch):
# With out_specs=PartitionSpec("batch"), the per-device results are stitched
# along axis 0, which equals the input tiled twice along axis 0.
return jnp.tile(batch, (2, 1))

devices = jax.devices("tt")
mesh = jax.make_mesh((2,), ("batch",), devices=devices)

in_specs = (PartitionSpec("batch"),)
out_specs = PartitionSpec("batch")

run_multichip_test_with_random_inputs(
fwd, golden_fwd, [x_shape], mesh, in_specs, out_specs
)
38 changes: 38 additions & 0 deletions tests/jax/multichip/manual/unary_eltwise.py
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import jax
import jax.numpy as jnp
from jax import jit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec
from functools import partial
from infra import run_multichip_test_with_random_inputs
import pytest


@pytest.mark.parametrize("x_shape", [(256, 256)])
@pytest.mark.skip(reason="Multichip still in development")
def test_unary_eltwise(x_shape: tuple):
def fwd(a_block):
# Each device holds a (256, 128) column block of the (256, 256) input.
b_block = jnp.negative(a_block)
# psum over both mesh axes sums the two negated blocks elementwise, so every
# device ends up with the same (256, 128) result (out_specs is fully None).
stitched_result = jax.lax.psum(b_block, ("x", "y"))
return stitched_result

def fwd_single_device(a_block):
# Single-device reference: split into the same two column blocks, negate each,
# and add them elementwise.
a1, a2 = jnp.split(a_block, 2, axis=1)

b1, b2 = jnp.negative(a1), jnp.negative(a2)

stitched_result = b1 + b2
return stitched_result

devices = jax.devices("tt")
mesh = jax.make_mesh((1, 2), ("x", "y"), devices=devices)
in_specs = (PartitionSpec("x", "y"),)
out_specs = PartitionSpec(None, None)

run_multichip_test_with_random_inputs(
fwd, fwd_single_device, [x_shape], mesh, in_specs, out_specs
)
