Skip to content

Commit

Permalink
Implement MPI Plugin for Kubeflow Trainer
Browse files Browse the repository at this point in the history
Signed-off-by: Andrey Velichkevich <[email protected]>
  • Loading branch information
andreyvelich committed Feb 12, 2025
1 parent e47d8f7 commit 58beb0c
Show file tree
Hide file tree
Showing 24 changed files with 514 additions and 93 deletions.
2 changes: 1 addition & 1 deletion api.v2/openapi-spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
"type": "boolean"
},
"sshAuthMountPath": {
"description": "Directory where SSH keys are mounted.",
"description": "Directory where SSH keys are mounted. Defaults to /root/.ssh.",
"type": "string"
}
}
Expand Down
9 changes: 5 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.9.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.31.0
k8s.io/api v0.31.3
k8s.io/apimachinery v0.31.3
k8s.io/client-go v0.31.3
Expand Down Expand Up @@ -70,10 +71,10 @@ require (
golang.org/x/mod v0.20.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/term v0.25.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.6.0 // indirect
golang.org/x/tools v0.24.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
Expand Down
18 changes: 10 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
Expand All @@ -135,20 +137,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ spec:
Defaults to false.
type: boolean
sshAuthMountPath:
description: Directory where SSH keys are mounted.
description: |-
Directory where SSH keys are mounted.
Defaults to /root/.ssh.
type: string
type: object
numNodes:
Expand Down
4 changes: 3 additions & 1 deletion manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ spec:
Defaults to false.
type: boolean
sshAuthMountPath:
description: Directory where SSH keys are mounted.
description: |-
Directory where SSH keys are mounted.
Defaults to /root/.ssh.
type: string
type: object
numNodes:
Expand Down
2 changes: 2 additions & 0 deletions manifests/v2/base/manager/manager.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ spec:
spec:
containers:
- name: manager
args:
- "--zap-log-level=5"
image: kubeflow/training-operator-v2
env:
- name: MY_POD_NAMESPACE
Expand Down
2 changes: 2 additions & 0 deletions manifests/v2/base/rbac/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ rules:
- apiGroups:
- ""
resources:
- configmaps
- secrets
verbs:
- create
- get
- list
- update
Expand Down
1 change: 1 addition & 0 deletions manifests/v2/base/runtimes/pre-training/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
resources:
- torch-distributed.yaml
- mpi-distributed.yaml
45 changes: 45 additions & 0 deletions manifests/v2/base/runtimes/pre-training/mpi-distributed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# TODO (andreyvelich): Change this to DeepSpeed or MLX runtime.
apiVersion: kubeflow.org/v2alpha1
kind: ClusterTrainingRuntime
metadata:
name: mpi-distributed
labels:
training.kubeflow.org/phase: pre-training
spec:
mlPolicy:
numNodes: 1
mpi:
numProcPerNode: 1
mpiImplementation: OpenMPI
sshAuthMountPath: /root/.ssh
template:
spec:
# TODO (andreyvelich): Use dependsOn when it is released.
startupPolicy:
startupPolicyOrder: InOrder
replicatedJobs:
- name: launcher
template:
spec:
template:
spec:
# TODO (andreyvelich): Change the command with mpirun.
containers:
- name: launcher
image: busybox
command:
- /bin/sh
- -c
- "echo 'launcher runs for 10 seconds' && sleep 100"
- name: trainer-node
template:
spec:
template:
spec:
containers:
- name: trainer
image: busybox
command:
- /bin/sh
- -c
- "echo 'launcher runs for 10 seconds' && sleep 100"
4 changes: 3 additions & 1 deletion manifests/v2/overlays/standalone/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ resources:
- https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml
images:
- name: kubeflow/training-operator-v2
newTag: latest
# TODO (andreyvelich): Change it back.
newName: training-operator-local
newTag: v1
secretGenerator:
- name: training-operator-v2-webhook-cert
namespace: kubeflow-system
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,11 @@ type MPIMLPolicySource struct {

// Implementation name for the MPI to create the appropriate hostfile.
// Defaults to OpenMPI.
MPIImplementation *MPIImplementation `json:"mpiImplementation,omitempty"`
MPIImplementation MPIImplementation `json:"mpiImplementation,omitempty"`

// Directory where SSH keys are mounted.
SSHAuthMountPath *string `json:"sshAuthMountPath,omitempty"`
// Defaults to /root/.ssh.
SSHAuthMountPath string `json:"sshAuthMountPath,omitempty"`

// Whether to run training process on the launcher Job.
// Defaults to false.
Expand Down
10 changes: 0 additions & 10 deletions pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 55 additions & 15 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ const (
// PodGroupKind is the Kind name for the PodGroup.
PodGroupKind string = "PodGroup"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"

// Distributed envs for torchrun.
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
// TorchEnvNumNodes is the env name for the number of training nodes.
Expand All @@ -52,25 +72,45 @@ const (
// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"
// JobLauncher is the Job name for the launcher.
JobLauncher string = "launcher"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"
// ContainerLauncher is the container name for the launcher.
ContainerLauncher string = "launcher"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"
// MPISSHAuthSecretSuffix is the name suffix for Secret with MPI SSH keys.
MPISSHAuthSecretSuffix string = "-mpi-ssh-auth"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"
// MPISSHAuthVolumeName is the volume name for Secret with MPI SSH keys.
MPISSHAuthVolumeName string = "mpi-ssh-auth"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"
// MPISSHPrivateKeyFile is the file name for the private key.
MPISSHPrivateKeyFile string = "id_rsa"

// MPISSHPublicKey is the value in Secret data for the public key.
MPISSHPublicKey string = "ssh-publickey"

// MPISSHPublicKeyFile is the file name for the public key.
MPISSHPublicKeyFile string = MPISSHPrivateKeyFile + ".pub"

// MPISSHAuthorizedKeys is the file name for authorized keys.
MPISSHAuthorizedKeys string = "authorized_keys"

// MPIHostfilePath is the directory for the MPI hostfile.
MPIHostfileDir string = "/etc/mpi"

// MPIHostfileName is the file name for the MPI hostfile.
MPIHostfileName string = "hostfile"

// MPIHostfileConfigMapSuffix is the name suffix for ConfigMap with MPI hostfile.
MPIHostfileConfigMapSuffix string = "-mpi-hostfile"

// MPIHostfileVolumeName is the volume name for ConfigMap with MPI hostfile.
MPIHostfileVolumeName string = "mpi-hostfile"

// Distributed envs for mpirun.
// Values for OpenMPI implementation.
OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile"
)

var (
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTe
return nil, err
}
if obj != nil {
objs = append(objs, obj)
objs = append(objs, obj...)
}
}
return objs, nil
Expand Down
12 changes: 6 additions & 6 deletions pkg/runtime.v2/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ type Plugin interface {
Name() string
}

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

type WatchExtensionPlugin interface {
Plugin
ReconcilerBuilders() []runtime.ReconcilerBuilder
Expand All @@ -47,14 +52,9 @@ type EnforceMLPolicyPlugin interface {
EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJob) error
}

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

type ComponentBuilderPlugin interface {
Plugin
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error)
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *kubeflowv2.TrainJob) ([]client.Object, error)
}

type TerminalConditionPlugin interface {
Expand Down
Loading

0 comments on commit 58beb0c

Please sign in to comment.