Addressed comments
ajakovljevicTT committed Feb 7, 2025
1 parent cd3f367 commit 1863d25
Showing 4 changed files with 54 additions and 45 deletions.
2 changes: 1 addition & 1 deletion tests/infra/__init__.py
@@ -8,4 +8,4 @@
from .model_tester import ModelTester, RunMode
from .multichip_tester import run_multichip_test_with_random_inputs
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor, supported_dtypes, make_partition_spec
from .utils import random_tensor, make_partition_spec
70 changes: 38 additions & 32 deletions tests/infra/device_runner.py
@@ -26,7 +26,7 @@ def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
@staticmethod
def run_on_multichip_device(multichip_workload: MultichipWorkload) -> Tensor:
"""Runs `workload` on a multichip device."""
sharded_workload = DeviceRunner._safely_put_sharding_on_workload(
sharded_workload = DeviceRunner._put_multichip_workload_on_device(
multichip_workload
)
return sharded_workload.execute().block_until_ready()
@@ -72,18 +72,50 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
raise NotImplementedError("Support for GPUs not implemented")

@staticmethod
def _safely_put_sharding_on_workload(
def _run_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
device = device_connector.connect_device(device_type, device_num)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_sharded_tensors_on_multichip_device(
tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
) -> Tensor:
"""
Needed for multichip: Uses put_device to give inputs shardings.
We just put dummy sharding equal to none, since jax needs to have some notion of sharding when running graph with these buffers as input.
"""
none_tuple = (None,) * len(in_spec)
none_spec = PartitionSpec(*none_tuple)
return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)

@staticmethod
def _put_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Workload:
"""Puts `workload` on device and returns it."""
device = device_connector.connect_device(device_type, device_num)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
def _put_multichip_workload_on_device(
multichip_workload: MultichipWorkload,
) -> MultichipWorkload:
"""Gives the workload inputs shardings, necessary for multichip workloads"""
args_on_device = []
spec_index = 0
# TODO: It might be necessary to put a try-except block here, but we are holding off until we come across a case where it's needed.
for arg in multichip_workload.args:
if not isinstance(arg, Tensor):
args_on_device.append(arg)
else:
args_on_device.append(
DeviceRunner._put_tensor_none_sharding(
DeviceRunner._put_sharded_tensors_on_multichip_device(
arg,
multichip_workload.mesh,
multichip_workload.in_specs[spec_index],
@@ -96,7 +128,9 @@ def _safely_put_sharding_on_workload(
if not isinstance(value, Tensor):
kwargs_on_device[key] = value
else:
kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
kwargs_on_device[
key
] = DeviceRunner._put_sharded_tensors_on_multichip_device(
value,
multichip_workload.mesh,
multichip_workload.in_specs[spec_index],
@@ -111,34 +145,6 @@ def _safely_put_sharding_on_workload(
in_specs=multichip_workload.in_specs,
)

@staticmethod
def _run_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Tensor:
"""Runs `workload` on device identified by `device_type`."""
device_workload = DeviceRunner._put_on_device(workload, device_type, device_num)
device = device_connector.connect_device(device_type, device_num)

with jax.default_device(device):
return device_workload.execute()

@staticmethod
def _put_tensor_none_sharding(
tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
) -> Tensor:
"""Needed for multichip: Uses put_device to give inputs shardings."""
none_tuple = (None,) * len(in_spec)
none_spec = PartitionSpec(*none_tuple)
return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)

@staticmethod
def _put_on_device(
workload: Workload, device_type: DeviceType, device_num: int = 0
) -> Workload:
"""Puts `workload` on device and returns it."""
device = device_connector.connect_device(device_type, device_num)
return DeviceRunner._safely_put_workload_on_device(workload, device)

@staticmethod
def _put_tensors_on_device(
device_type: DeviceType, tensors: Sequence[Tensor]
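
For reference, the "dummy None sharding" that _put_sharded_tensors_on_multichip_device attaches can be reproduced in isolation. The sketch below is illustrative only: the two-device mesh, the axis name "x", and the tensor shape are assumptions made for the example, not values taken from this repository.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed setup: a 1D mesh over the first two available devices.
mesh = Mesh(jax.devices()[:2], axis_names=("x",))
in_spec = PartitionSpec("x")  # the spec the sharded graph will eventually use

# One None per entry of in_spec, i.e. a fully replicated placement.
none_spec = PartitionSpec(*((None,) * len(in_spec)))

x = jnp.ones((8, 8))
# device_put with the all-None spec attaches a sharding to the buffer without
# actually partitioning it, which is all JAX needs before running the graph.
x_on_device = jax.device_put(x, NamedSharding(mesh, none_spec))
print(x_on_device.sharding)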
25 changes: 14 additions & 11 deletions tests/infra/multichip_tester.py
@@ -25,22 +25,22 @@ class MultichipTester(BaseTester):
and output sharding specifications.
Attributes:
mesh (jax.Mesh): The device mesh over which the computation is distributed.
device_mesh (jax.Mesh): The device mesh over which the computation is distributed.
in_specs (tuple): The sharding specifications for the input tensors.
out_specs (jax.sharding.PartitionSpec): The sharding specification for the output tensor.
"""

def __init__(
self,
mesh: jax.Mesh,
device_mesh: jax.Mesh,
in_specs: tuple,
out_specs: jax.sharding.PartitionSpec,
comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
self.mesh = mesh
super().__init__(comparison_config)
self.device_mesh = device_mesh
self.in_specs = in_specs
self.out_specs = out_specs
super().__init__(comparison_config)

def _compile_for_cpu(
self, executable: Callable, static_argnames: Sequence[str] = None
@@ -53,9 +53,12 @@ def _compile_for_device(
) -> Callable:
"""Sets up executable for just-in-time compile and execution on multichip device."""
module_sharded = shard_map(
executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
executable,
mesh=self.device_mesh,
in_specs=self.in_specs,
out_specs=self.out_specs,
)
output_sharding = NamedSharding(self.mesh, self.out_specs)
output_sharding = NamedSharding(self.device_mesh, self.out_specs)
return jax.jit(
module_sharded,
out_shardings=output_sharding,
@@ -66,13 +69,13 @@ def test(
self, multichip_workload: MultichipWorkload, cpu_workload: Workload
) -> None:
"""
Runs test by running `workload` on TT device and CPU and comparing the results.
Runs the test by running `multichip_workload` on a TT device and `cpu_workload` on the CPU, then comparing the results.
"""
multichip_compiled_workload = MultichipWorkload(
self._compile_for_device(multichip_workload.executable),
multichip_workload.args,
multichip_workload.kwargs,
mesh=self.mesh,
device_mesh=self.device_mesh,
in_specs=self.in_specs,
)

@@ -107,11 +110,11 @@ def test_with_random_inputs(
)
for shape in input_shapes
]
multichip_workload = MultichipWorkload(
device_executable, inputs, mesh=self.mesh, in_specs=self.in_specs
device_workload = MultichipWorkload(
device_executable, inputs, mesh=self.device_mesh, in_specs=self.in_specs
)
cpu_workload = Workload(cpu_executable, inputs)
self.test(multichip_workload, cpu_workload)
self.test(device_workload, cpu_workload)


def run_multichip_test_with_random_inputs(
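
The shard_map + jit pairing that _compile_for_device builds can be exercised standalone. The sketch below is a hedged approximation: the two-device mesh, the axis name "batch", and the toy executable are assumptions made for illustration, not code from this repository.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec

device_mesh = Mesh(jax.devices()[:2], axis_names=("batch",))
in_specs = (PartitionSpec("batch"),)
out_specs = PartitionSpec("batch")

def executable(x):
    # Toy stand-in for the function under test.
    return x * 2.0

# Mirrors _compile_for_device: wrap the executable with shard_map, then jit it
# with an explicit output sharding built from the same mesh and out_specs.
module_sharded = shard_map(
    executable, mesh=device_mesh, in_specs=in_specs, out_specs=out_specs
)
compiled = jax.jit(
    module_sharded, out_shardings=NamedSharding(device_mesh, out_specs)
)

x = jnp.arange(8.0)
print(compiled(x))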
2 changes: 1 addition & 1 deletion tests/infra/workload.py
@@ -36,5 +36,5 @@ class MultichipWorkload(Workload):
An extension of the Workload dataclass that includes a mesh and partition specs, necessary for multichip sharding.
"""

mesh: jax.sharding.Mesh = None
device_mesh: jax.sharding.Mesh = None
in_specs: Sequence[jax.sharding.PartitionSpec] = None
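
With the field renamed, constructing a MultichipWorkload looks roughly like the call in MultichipTester.test above. This is a hedged sketch that reuses the compiled, device_mesh, and in_specs names from the previous sketch; some_input is a placeholder, not an identifier from the repository.

workload = MultichipWorkload(
    compiled,                 # executable to run
    [some_input],             # positional args
    {},                       # keyword args
    device_mesh=device_mesh,  # jax.sharding.Mesh, renamed from `mesh` in this commit
    in_specs=in_specs,
)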
