#0: DRAFT Add MeshProgram class #13701

Draft · wants to merge 2 commits into main

Conversation

tt-asaigal (Contributor):

This is a draft to get additional feedback.

  • Includes APIs to set MeshProgram configuration across the entire MeshDevice or per device in the Mesh. APIs are analogous to the Program config APIs
  • Basic getter APIs to return individual programs and state across the MeshProgram
  • Relies on the distributed_impl_ and distribute_to_mesh_device_impl_ functions
    • distributed_impl_ serves as the MeshProgram entry point on the Controller or Executor to process attributes of this data structure on host(s)
    • This function is currently implemented as a simple loop, but it can be swapped out for a set of RPC calls on the Controller and asynchronous calls on the Executor
    • distribute_to_mesh_device_impl_ serves as the interface between the MeshProgram on host and the MeshDevice. Currently used in EnqueueMeshProgram and implemented using a simple loop. Can be used to interface with TT-Fabric, once the infra is available
  • Design aspects to consider as we go along:
    • Does a MeshProgram span Controllers, or is it limited to a Controller connected to a single cluster? APIs and hierarchy may be easier with a MeshProgram <--> Controller <--> MeshDevice mapping
    • For programs we want to broadcast, the host does not need to perform repeated work across the entire device mesh (it currently does). It likely makes sense to have a bcast setting in the MeshProgram class, to ensure that the host sets configuration data once, with fast dispatch/fabric performing the bcast
    • Potential hierarchy:
      • The Controller populates the MeshProgram with kernels, sems, CBs and RTAs (individual programs, single-program bcast or multi-program bcast). Population can be done with a reimplementation of distributed_impl_ on the Controller
      • The MeshProgram is sent to Executors through a virtual CQ (RPC call + cq_id). This is done through a specialized distribute_to_mesh_device_impl_ on the Controller
      • Executors get a MeshProgram, which can be scattered or broadcast through the specified CQ using Fast Dispatch and Fabric. Assembling FD/Fabric commands can be host intensive, so it makes sense to distribute these steps across Executors. Executors get the MeshProgram to the MeshDevice through a specialized distribute_to_mesh_device_impl_
  • For generic entry points to mutate Mesh data structures and send them to the MeshDevice, we need generic distribute* functions that can accept any data type and perform generic processing (assemble FD commands, make RPC calls, send programs or data to the MeshDevice, etc.); a rough sketch follows below
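
A minimal sketch of what such a generic distribute entry point could look like, assuming a plain container of per-device items. distribute and MeshItem are illustrative names, not part of this PR; the point is only that call sites depend on the signature, so the loop body can later become RPC or async fan-out without touching callers:

#include <cstdio>
#include <vector>

// Stand-in for any per-device Mesh data structure (MeshProgram entry, buffer, ...).
struct MeshItem { int id; };

// Generic distribute helper: applies fn to every per-device item. Today this is
// a plain host-side loop; the body could later issue RPC calls on the Controller
// or asynchronous calls on Executors behind the same interface.
template <typename Item, typename Fn>
void distribute(std::vector<Item>& items, Fn&& fn) {
    for (auto& item : items) {
        fn(item);
    }
}

int main() {
    std::vector<MeshItem> per_device = {{0}, {1}, {2}, {3}};
    distribute(per_device, [](MeshItem& m) {
        std::printf("processing item for device %d\n", m.id);
    });
}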

Comment on lines +19 to +73
template<typename T>
T distributed_impl_(const std::function<T(Program&)>& callable) {
    if constexpr (std::is_same<T, void>::value) {
        // Void callables: apply to every program in the mesh.
        for (std::size_t program_idx = 0; program_idx < this->programs.size(); program_idx++) {
            callable(*this->programs.at(program_idx));
        }
    } else {
        // Value-returning callables: apply to all but the last program, then
        // return the result of the final call.
        for (std::size_t program_idx = 0; program_idx < this->programs.size() - 1; program_idx++) {
            callable(*this->programs.at(program_idx));
        }
        return callable(*this->programs.at(this->programs.size() - 1));
    }
}

template<typename T>
std::vector<T> distributed_impl_(const std::variant<std::function<T(Program&)>, std::function<T(Program&, Device*)>>& callable, std::shared_ptr<MeshDevice> mesh_device = nullptr) const {
    std::vector<T> rval = {};
    std::vector<Device*> devices = {};
    if (mesh_device != nullptr) {
        devices = mesh_device->get_devices();
        TT_ASSERT(devices.size() == this->programs.size(),
            "MeshProgram created for {} devices cannot be mapped to a MeshDevice with {} devices",
            this->programs.size(), devices.size());
        TT_ASSERT(std::holds_alternative<std::function<T(Program&, Device*)>>(callable));
        auto f = std::get<std::function<T(Program&, Device*)>>(callable);
        // Pair each program with its device and collect the results.
        for (std::size_t program_idx = 0; program_idx < devices.size(); program_idx++) {
            rval.push_back(f(*this->programs.at(program_idx), devices.at(program_idx)));
        }
    } else {
        TT_ASSERT(std::holds_alternative<std::function<T(Program&)>>(callable));
        auto f = std::get<std::function<T(Program&)>>(callable);
        // Collect a result from every program in the mesh.
        for (std::size_t program_idx = 0; program_idx < this->programs.size(); program_idx++) {
            rval.push_back(f(*this->programs.at(program_idx)));
        }
    }
    return rval;
}

template<typename T>
T distribute_to_mesh_device_impl_(const std::function<T(Program&, Device*)>& callable, std::shared_ptr<MeshDevice>& mesh_device) {
    auto devices = mesh_device->get_devices();
    TT_ASSERT(devices.size() == this->programs.size(),
        "MeshProgram created for {} devices cannot be mapped to a MeshDevice with {} devices",
        this->programs.size(), devices.size());
    if constexpr (std::is_same<T, void>::value) {
        for (std::size_t program_idx = 0; program_idx < devices.size(); program_idx++) {
            callable(*this->programs.at(program_idx), devices.at(program_idx));
        }
    } else {
        for (std::size_t program_idx = 0; program_idx < devices.size() - 1; program_idx++) {
            callable(*this->programs.at(program_idx), devices.at(program_idx));
        }
        return callable(*this->programs.at(devices.size() - 1), devices.at(devices.size() - 1));
    }
}
Contributor:
this should definitely not be part of any interface
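
One way to satisfy this, sketched under assumptions: keep the distribution machinery behind a pimpl so it never appears in the public header. Program below is a stand-in type and the Impl layout is hypothetical, not the PR's actual code:

#include <functional>
#include <memory>
#include <vector>

struct Program {};  // stand-in for the real tt-metal Program

class MeshProgram {
public:
    MeshProgram() : impl_(std::make_unique<Impl>()) {}
    // Public configuration APIs would live here; no distribute helpers exposed.
private:
    struct Impl {
        std::vector<std::unique_ptr<Program>> programs;
        // distributed_impl_ lives here, invisible to consumers of the header.
        void distributed_impl_(const std::function<void(Program&)>& callable) {
            for (auto& program : programs) callable(*program);
        }
    };
    std::unique_ptr<Impl> impl_;
};

int main() { MeshProgram mesh_program; }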

Comment on lines +79 to +133
uint32_t CreateSemaphore(
MeshProgram& mesh_program,
const std::variant<CoreRange, CoreRangeSet> &core_spec,
uint32_t initial_value,
CoreType core_type = CoreType::WORKER);

uint32_t CreateSemaphore(
MeshProgram& mesh_program,
const std::variant<CoreRange, CoreRangeSet> &core_spec,
uint32_t initial_value,
CoreType core_type,
chip_id_t device_id);

CBHandle CreateCircularBuffer(
MeshProgram& mesh_program,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
const CircularBufferConfig &config);

CBHandle CreateCircularBuffer(
MeshProgram& mesh_program,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
const CircularBufferConfig &config,
chip_id_t device_id);

void SetRuntimeArgs(
MeshProgram& mesh_program,
KernelHandle kernel,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
const std::vector<uint32_t> &runtime_args);

void SetRuntimeArgs(
MeshProgram& mesh_program,
KernelHandle kernel,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
const std::vector<uint32_t> &runtime_args,
chip_id_t device_id);

KernelHandle CreateKernel(
MeshProgram& mesh_program,
const std::string &file_name,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> &config);

KernelHandle CreateKernel(
MeshProgram& mesh_program,
const std::string &file_name,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> &config,
chip_id_t device_id);

void EnqueueMeshProgram(
uint8_t cq_id, MeshProgram& mesh_program, std::shared_ptr<MeshDevice> mesh_device, bool blocking);

void Finish(std::shared_ptr<MeshDevice> mesh_device, uint8_t cq_id);
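
A hedged usage sketch of the declarations above, end to end, assuming a MeshDevice handle mesh_device is already open. MeshProgram construction is not part of this diff, so the constructor call, num_devices() accessor, kernel path, core coordinates, and device id below are all illustrative assumptions:

// Hypothetical: obtain a MeshProgram sized to the mesh (constructor not in this diff).
MeshProgram mesh_program(mesh_device->num_devices());

// Broadcast-style configuration: no device_id, so it applies across the whole mesh.
KernelHandle kernel = CreateKernel(
    mesh_program,
    "kernels/eltwise_add.cpp",                      // placeholder path
    CoreRange(CoreCoord{0, 0}, CoreCoord{3, 3}),
    DataMovementConfig{});

uint32_t sem = CreateSemaphore(
    mesh_program, CoreRange(CoreCoord{0, 0}, CoreCoord{3, 3}), /*initial_value=*/0);

// Per-device override: the device_id overload targets a single chip in the mesh.
SetRuntimeArgs(mesh_program, kernel, CoreCoord{0, 0}, {64, 128}, /*device_id=*/2);

EnqueueMeshProgram(/*cq_id=*/0, mesh_program, mesh_device, /*blocking=*/false);
Finish(mesh_device, /*cq_id=*/0);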

Contributor:

We should try to stay as true to the existing APIs as possible for now, since there will be some refactor work.

KernelHandle CreateKernel(
MeshProgram& mesh_program,
const std::string &file_name,
const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
Contributor:

It feels like we shouldn't be programming at both the MeshProgram level and the Program level, or we'll end up feeding two sets of inputs at different abstraction levels.

}

template<typename T>
std::vector<T> distributed_impl_(const std::variant<std::function<T(Program&)>, std::function<T(Program&, Device*)>>& callable, std::shared_ptr<MeshDevice> mesh_device = nullptr) const {
Contributor:

> Includes APIs to set MeshProgram configuration across entire MeshDevice or per device in the Mesh. APIs are analogous to Program config APIs
> Basic getter APIs to return individual programs and state across MeshProgram
> Relies on distribute_impl_ and distribute_to_mesh_device_ functions
> distribute_impl_ serves as the MeshProgram entry point on the Controller or Executor to process attributes of this data structure on host(s)
> This function is currently implemented as a simple loop, but it can be swapped out for a set of RPC calls on the Controller and asynchronous calls on the executor
> distribute_to_mesh_device_impl_ serves as the interface between the MeshProgram on host and the MeshDevice.

The header interface should remain the same as we swap out the changes in the impl
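
A small sketch of that separation with stand-in types: the declaration is the stable header contract, and the body can move from a host-side loop to RPC fan-out without touching callers. All names here are illustrative, not the PR's code:

#include <functional>
#include <vector>

struct Program {};  // stand-in for the real type

// Header-visible declaration: this is the contract that stays fixed.
void distribute(std::vector<Program>& programs, const std::function<void(Program&)>& callable);

// Implementation today: a simple host-side loop.
void distribute(std::vector<Program>& programs, const std::function<void(Program&)>& callable) {
    for (auto& program : programs) {
        callable(program);
    }
}

// A later implementation could keep the exact same signature but fan the work
// out to Executors over RPC and wait on the replies; callers never change.

int main() {
    std::vector<Program> programs(4);
    distribute(programs, [](Program&) { /* configure one per-device program */ });
}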
