Adding workload
ajakovljevicTT committed Feb 5, 2025
1 parent d2dcb1e commit f29e6c1
Showing 5 changed files with 50 additions and 54 deletions.
40 changes: 13 additions & 27 deletions tests/infra/device_runner.py
@@ -2,19 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0

import inspect
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from typing import Callable, Sequence

import jax

from .device_connector import DeviceType, device_connector
from .multichip_workload import MultichipWorkload
from .types import Tensor
from .workload import Workload

from jax.sharding import Mesh, PartitionSpec, NamedSharding

import inspect


class DeviceRunner:
"""
Class providing methods to put and run workloads on any supported device.
@@ -71,36 +69,26 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def put_with_none_sharding(
workload: Workload,
mesh: jax.sharding.Mesh,
in_specs: Sequence[jax.sharding.PartitionSpec],
) -> Tensor:
def put_with_none_sharding(multichip_workload: MultichipWorkload) -> MultichipWorkload:
"""Gives inputs shardings for multichip workloads"""
args_on_device = []
spec_index = 0
for arg in workload.args:
for arg in multichip_workload.args:
if not isinstance(arg, Tensor):
args_on_device.append(arg)
else:
args_on_device.append(
DeviceRunner._put_tensor_none_sharding(
arg, mesh, in_specs[spec_index]
)
)
spec_index += 1
args_on_device.append(
    DeviceRunner._put_tensor_none_sharding(
        arg, multichip_workload.mesh, multichip_workload.in_specs[spec_index]
    )
)
spec_index += 1

kwargs_on_device = {}
for key, value in workload.kwargs.items():
for key, value in multichip_workload.kwargs.items():
if not isinstance(value, Tensor):
kwargs_on_device[key] = value
else:
kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
value, mesh, in_specs[spec_index]
)
spec_index += 1
kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
    value, multichip_workload.mesh, multichip_workload.in_specs[spec_index]
)
spec_index += 1

return Workload(workload.executable, args_on_device, kwargs_on_device)
return MultichipWorkload(
    multichip_workload.executable,
    args_on_device,
    kwargs_on_device,
    mesh=multichip_workload.mesh,
    in_specs=multichip_workload.in_specs,
)

@staticmethod
def _run_manual(workload: Workload) -> Tensor:
@@ -119,9 +107,7 @@ def _run_on_device(
return device_workload.execute()

@staticmethod
def _put_tensor_none_sharding(
tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
) -> Tensor:
def _put_tensor_none_sharding(tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec) -> Tensor:
"""Needed for multichip: Uses put_device to give inputs shardings."""
none_tuple = (None,) * len(in_spec)
none_spec = PartitionSpec(*none_tuple)
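
The sketch below is not part of the commit. It illustrates the "none sharding" idea that _put_tensor_none_sharding implements: build a PartitionSpec made entirely of None entries (one per entry of the input spec) so the tensor is replicated across the mesh rather than split. The jax.device_put call and the single-axis mesh are assumptions for illustration, since the method body is collapsed in this diff.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def put_tensor_none_sharding_sketch(tensor, mesh, in_spec):
    # Replace every axis entry of the spec with None -> fully replicated sharding.
    none_spec = PartitionSpec(*((None,) * len(in_spec)))
    return jax.device_put(tensor, NamedSharding(mesh, none_spec))


mesh = Mesh(np.array(jax.devices()), ("batch",))
x = jnp.ones((8, 4))
replicated = put_tensor_none_sharding_sketch(x, mesh, PartitionSpec("batch", None))
print(replicated.sharding)  # NamedSharding whose spec is PartitionSpec(None, None)
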
41 changes: 16 additions & 25 deletions tests/infra/multichip_tester.py
@@ -5,25 +5,22 @@
from __future__ import annotations

import jax
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from typing import Callable, Sequence

from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .multichip_workload import MultichipWorkload
from .workload import Workload


class MultichipTester(BaseTester):
"""Specific tester for ops."""

def __init__(
self,
mesh: jax.Mesh,
in_specs: tuple,
out_specs: jax.sharding.PartitionSpec,
comparison_config: ComparisonConfig = ComparisonConfig(),
self, mesh: jax.Mesh, in_specs: tuple, out_specs: jax.sharding.PartitionSpec, comparison_config: ComparisonConfig = ComparisonConfig()
) -> None:
self.mesh = mesh
self.in_specs = in_specs
@@ -41,14 +38,13 @@ def _compile(
) -> Callable:
"""Sets up `executable` for just-in-time compile."""
module_sharded = shard_map(
executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
executable,
mesh=self.mesh,
in_specs=self.in_specs,
out_specs=self.out_specs
)
output_sharding = NamedSharding(self.mesh, self.out_specs)
return jax.jit(
module_sharded,
out_shardings=output_sharding,
static_argnames=static_argnames,
)
return jax.jit(module_sharded, out_shardings=output_sharding, static_argnames=static_argnames)

def test(self, workload: Workload, cpu_workload: Workload) -> None:
"""
@@ -61,13 +57,11 @@ def test(self, workload: Workload, cpu_workload: Workload) -> None:
cpu_compiled_executable, cpu_workload.args, cpu_workload.kwargs
)

compiled_workload = Workload(
compiled_executable, workload.args, workload.kwargs
compiled_workload = MultichipWorkload(
compiled_executable, workload.args, workload.kwargs, mesh=self.mesh, in_specs=self.in_specs
)

non_sharded_workload = DeviceRunner.put_with_none_sharding(
compiled_workload, self.mesh, in_specs=self.in_specs
)
non_sharded_workload = DeviceRunner.put_with_none_sharding(compiled_workload)

tt_res = DeviceRunner.run_manual(non_sharded_workload)
cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload)
@@ -87,10 +81,7 @@ def test_with_random_inputs(
TT device and CPU and comparing the results.
"""
inputs = [
jax.random.uniform(
key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval
)
for shape in input_shapes
jax.random.uniform(key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval) for shape in input_shapes
]
workload = Workload(f, inputs)
cpu_workload = Workload(golden_f, inputs)
@@ -101,16 +92,16 @@ def run_multichip_test_with_random_inputs(
mesh_test: Callable,
golden_test: Callable,
input_shapes: Sequence[tuple],
mesh: jax.Mesh,
in_specs: tuple,
mesh: jax.Mesh,
in_specs: Sequence[jax.sharding.PartitionSpec],
out_specs: jax.sharding.PartitionSpec,
minval: float = 0.0,
maxval: float = 1.0,
comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
"""
Tests `mesh_test` with random inputs in range [`minval`, `maxval`) by running it on
TT device and CPU and comparing the results based on `comparison_config`.
"""
tester = MultichipTester(mesh, in_specs, out_specs, comparison_config)
tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval)
tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval)
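
For illustration only (not from this commit), here is a self-contained version of the compile step MultichipTester._compile performs above: wrap the executable with shard_map over the mesh, then jit it with an output NamedSharding built from out_specs. The mesh layout and the negate op are assumed for the example.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("batch",))
in_specs = (PartitionSpec("batch", None),)
out_specs = PartitionSpec("batch", None)


def negate(block):
    # Runs once per device, on that device's shard of the input.
    return jnp.negative(block)


module_sharded = shard_map(negate, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
compiled = jax.jit(module_sharded, out_shardings=NamedSharding(mesh, out_specs))

n = len(jax.devices())
x = jnp.arange(2.0 * n * 4).reshape(2 * n, 4)
print(compiled(x))  # elementwise negation, output laid out according to out_specs
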
19 changes: 19 additions & 0 deletions tests/infra/multichip_workload.py
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
import jax
from typing import Sequence

from .workload import Workload


@dataclass
class MultichipWorkload(Workload):
"""
Convenience dataclass extending Workload with the device mesh and input shardings needed for multichip runs.
"""
mesh: jax.sharding.Mesh = None
in_specs: Sequence[jax.sharding.PartitionSpec] = None
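
A hypothetical usage sketch, not from the commit, mirroring how the tester constructs this dataclass in multichip_tester.py (executable, args, kwargs, then the multichip extras). It assumes Workload itself is a dataclass exposing executable, args and kwargs fields, as the rest of this diff suggests.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("x",))
workload = MultichipWorkload(
    jnp.negative,                          # executable
    [jnp.ones((8, 8))],                    # positional args
    {},                                    # keyword args
    mesh=mesh,
    in_specs=(PartitionSpec("x", None),),
)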

2 changes: 1 addition & 1 deletion tests/jax/multichip/manual/all_gather.py
@@ -14,7 +14,7 @@

@pytest.mark.parametrize("x_shape", [(8192, 784)])
@pytest.mark.skip(reason="Compilation fails")
def all_gather_test(x_shape: tuple):
def test_all_gather(x_shape: tuple):
def fwd(batch):
act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True)
return act
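
A small aside, not part of the commit, sketching the collective used in fwd above: with tiled=True, jax.lax.all_gather concatenates the per-device shards along axis 0 instead of stacking them on a new leading axis, so every device ends up holding the full array. The single-axis mesh is an assumption; the real test's mesh fixture is collapsed in this diff.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("batch",))
n = len(jax.devices())

gather_all = shard_map(
    lambda b: jax.lax.all_gather(b, "batch", axis=0, tiled=True),
    mesh=mesh,
    in_specs=PartitionSpec("batch", None),
    out_specs=PartitionSpec(None, None),  # result is replicated on every device
)

x = jnp.arange(4.0 * n).reshape(2 * n, 2)
assert gather_all(x).shape == x.shape  # gathered back to the full global shape
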
2 changes: 1 addition & 1 deletion tests/jax/multichip/manual/unary_eltwise.py
@@ -14,7 +14,7 @@

@pytest.mark.parametrize("x_shape", [(256, 256)])
@pytest.mark.skip(reason="Multichip still in development")
def unary_eltwise_test(x_shape: tuple):
def test_unary_eltwise(x_shape: tuple):
def fwd(a_block):
b_block = jnp.negative(a_block)
stitched_result = jax.lax.psum(b_block, ("x", "y"))
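
Another illustrative aside, not from the commit: jax.lax.psum over both mesh axes ("x", "y") sums the negated per-device blocks elementwise across every device of the 2-D mesh. The (1, N) mesh shape below is an assumption so the sketch runs on however many devices are available.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("x", "y"))
n = devices.size


def fwd(a_block):
    b_block = jnp.negative(a_block)
    return jax.lax.psum(b_block, ("x", "y"))


run = shard_map(
    fwd, mesh=mesh, in_specs=PartitionSpec("x", "y"), out_specs=PartitionSpec(None, None)
)
a = jnp.ones((4, 4 * n))
print(run(a))  # a 4x4 block where every entry equals -n (one -1 contributed per device)
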
