From a05924520a9e025b20436aeef7456496684dfd51 Mon Sep 17 00:00:00 2001 From: ajakovljevicTT Date: Wed, 5 Feb 2025 07:16:09 +0000 Subject: [PATCH] Adding multichip tests --- tests/infra/mulitchip_workload.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/infra/mulitchip_workload.py diff --git a/tests/infra/mulitchip_workload.py b/tests/infra/mulitchip_workload.py new file mode 100644 index 00000000..94268e63 --- /dev/null +++ b/tests/infra/mulitchip_workload.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +import jax +from typing import Sequence + +from .workload import Workload + + +@dataclass +class MultichipWorkload(Workload): + """ + Convenience dataclass storing a callable and its positional and keyword arguments. + """ + mesh: jax.sharding.Mesh = None + in_specs: Sequence[jax.sharding.PartitionSpec] = None +