Skip to content

Commit

Permalink
Adding multichip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Feb 5, 2025
1 parent f29e6c1 commit a059245
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tests/infra/mulitchip_workload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
import jax
from typing import Sequence

from .workload import Workload


@dataclass
class MultichipWorkload(Workload):
"""
Convenience dataclass storing a callable and its positional and keyword arguments.
"""
mesh: jax.sharding.Mesh = None
in_specs: Sequence[jax.sharding.PartitionSpec] = None

0 comments on commit a059245

Please sign in to comment.