Skip to content

Commit

Permalink
Add logic to autogenerate a DeviceClass from a MultiNodeEnvironment
Browse files Browse the repository at this point in the history
  • Loading branch information
klueska committed Jan 13, 2025
1 parent f30a967 commit 36c5d68
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 6 deletions.
146 changes: 140 additions & 6 deletions cmd/nvidia-dra-controller/mnenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (

const (
multiNodeEnvironmentFinalizer = "gpu.nvidia.com/finalizer.multiNodeEnvironment"
imexDeviceClass = "imex.nvidia.com"
)

type MultiNodeEnvironmentManager struct {
Expand All @@ -50,6 +49,7 @@ type MultiNodeEnvironmentManager struct {
multiNodeEnvironmentInformer cache.SharedIndexInformer
multiNodeEnvironmentLister nvlisters.MultiNodeEnvironmentLister
resourceClaimLister resourcelisters.ResourceClaimLister
deviceClassLister resourcelisters.DeviceClassLister
}

// StartManager starts a MultiNodeEnvironmentManager.
Expand All @@ -65,11 +65,15 @@ func StartMultiNodeEnvironmentManager(ctx context.Context, config *Config) (*Mul
rcInformer := coreInformerFactory.Resource().V1beta1().ResourceClaims().Informer()
rcLister := resourcelisters.NewResourceClaimLister(rcInformer.GetIndexer())

dcInformer := coreInformerFactory.Resource().V1beta1().DeviceClasses().Informer()
dcLister := resourcelisters.NewDeviceClassLister(dcInformer.GetIndexer())

m := &MultiNodeEnvironmentManager{
clientsets: config.clientsets,
multiNodeEnvironmentInformer: mneInformer,
multiNodeEnvironmentLister: mneLister,
resourceClaimLister: rcLister,
deviceClassLister: dcLister,
}

var err error
Expand Down Expand Up @@ -101,6 +105,14 @@ func StartMultiNodeEnvironmentManager(ctx context.Context, config *Config) (*Mul
return nil, fmt.Errorf("error adding event handlers for ResourceClaim informer: %w", err)
}

_, err = dcInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: func(obj any) { queue.Enqueue(obj, m.onDeviceClassAddOrUpdate) },
UpdateFunc: func(objOld, objNew any) { queue.Enqueue(objNew, m.onDeviceClassAddOrUpdate) },
})
if err != nil {
return nil, fmt.Errorf("error adding event handlers for DeviceClass informer: %w", err)
}

m.waitGroup.Add(3)
go func() {
defer m.waitGroup.Done()
Expand All @@ -115,10 +127,10 @@ func StartMultiNodeEnvironmentManager(ctx context.Context, config *Config) (*Mul
queue.Run(ctx.Done())
}()

if !cache.WaitForCacheSync(ctx.Done(), mneInformer.HasSynced, rcInformer.HasSynced) {
if !cache.WaitForCacheSync(ctx.Done(), mneInformer.HasSynced, rcInformer.HasSynced, dcInformer.HasSynced) {
klog.Warning("Cache sync failed; retrying in 5 seconds")
time.Sleep(5 * time.Second)
if !cache.WaitForCacheSync(ctx.Done(), mneInformer.HasSynced, rcInformer.HasSynced) {
if !cache.WaitForCacheSync(ctx.Done(), mneInformer.HasSynced, rcInformer.HasSynced, dcInformer.HasSynced) {
return nil, fmt.Errorf("informer cache sync failed twice")
}
}
Expand Down Expand Up @@ -155,13 +167,57 @@ func (m *MultiNodeEnvironmentManager) onMultiNodeEnvironmentAdd(obj any) error {
Controller: ptr.To(true),
}

if _, err := m.createResourceClaim(mne.Namespace, mne.Spec.ResourceClaimName, ownerReference); err != nil {
dc, err := m.createDeviceClass("", ownerReference)
if err != nil {
return fmt.Errorf("error creating DeviceClass '%s': %w", "<generated-name>", err)
}

if _, err := m.createResourceClaim(mne.Namespace, mne.Spec.ResourceClaimName, dc.Name, ownerReference); err != nil {
return fmt.Errorf("error creating ResourceClaim '%s/%s': %w", mne.Namespace, mne.Spec.ResourceClaimName, err)
}

return nil
}

func (m *MultiNodeEnvironmentManager) onDeviceClassAddOrUpdate(obj any) error {
dc, ok := obj.(*resourceapi.DeviceClass)
if !ok {
return fmt.Errorf("failed to cast to DeviceClass")
}

klog.Infof("Processing added or updated DeviceClass: %s", dc.Name)

if len(dc.OwnerReferences) != 1 {
return nil
}

if dc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind {
return nil
}

if !cache.WaitForCacheSync(context.Background().Done(), m.multiNodeEnvironmentInformer.HasSynced) {
return fmt.Errorf("cache sync failed for MultiNodeEnvironment")
}

mnes, err := m.multiNodeEnvironmentInformer.GetIndexer().ByIndex("uid", string(dc.OwnerReferences[0].UID))
if err != nil {
return fmt.Errorf("error retrieving MultiNodeInformer OwnerReference by UID from indexer: %w", err)
}
if len(mnes) != 0 {
return nil
}

if err := m.removeDeviceClassFinalizer(dc.Name); err != nil {
return fmt.Errorf("error removing finalizer on DeviceClass '%s': %w", dc.Name, err)
}

if err := m.deleteDeviceClass(dc.Name); err != nil {
return fmt.Errorf("error deleting DeviceClass '%s': %w", dc.Name, err)
}

return nil
}

func (m *MultiNodeEnvironmentManager) onResourceClaimAddOrUpdate(obj any) error {
rc, ok := obj.(*resourceapi.ResourceClaim)
if !ok {
Expand Down Expand Up @@ -201,7 +257,51 @@ func (m *MultiNodeEnvironmentManager) onResourceClaimAddOrUpdate(obj any) error
return nil
}

func (m *MultiNodeEnvironmentManager) createResourceClaim(namespace, name string, ownerReference metav1.OwnerReference) (*resourceapi.ResourceClaim, error) {
func (m *MultiNodeEnvironmentManager) createDeviceClass(name string, ownerReference metav1.OwnerReference) (*resourceapi.DeviceClass, error) {
if name != "" {
dc, err := m.deviceClassLister.Get(name)
if err == nil {
if len(dc.OwnerReferences) != 1 && dc.OwnerReferences[0] != ownerReference {
return nil, fmt.Errorf("DeviceClass '%s' exists without expected OwnerReference: %v", name, ownerReference)
}
return dc, nil
}
if !errors.IsNotFound(err) {
return nil, fmt.Errorf("error retrieving DeviceClass: %w", err)
}
}

deviceClass := &resourceapi.DeviceClass{
ObjectMeta: metav1.ObjectMeta{
OwnerReferences: []metav1.OwnerReference{ownerReference},
Finalizers: []string{multiNodeEnvironmentFinalizer},
},
Spec: resourceapi.DeviceClassSpec{
Selectors: []resourceapi.DeviceSelector{
{
CEL: &resourceapi.CELDeviceSelector{
Expression: "device.driver == 'gpu.nvidia.com' && device.attributes['gpu.nvidia.com'].type == 'imex-channel'",
},
},
},
},
}

if name == "" {
deviceClass.GenerateName = ownerReference.Name
} else {
deviceClass.Name = name
}

dc, err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Create(context.Background(), deviceClass, metav1.CreateOptions{})
if err != nil {
return nil, fmt.Errorf("error creating DeviceClass: %w", err)
}

return dc, nil
}

func (m *MultiNodeEnvironmentManager) createResourceClaim(namespace, name, deviceClassName string, ownerReference metav1.OwnerReference) (*resourceapi.ResourceClaim, error) {
rc, err := m.resourceClaimLister.ResourceClaims(namespace).Get(name)
if err == nil {
if len(rc.OwnerReferences) != 1 && rc.OwnerReferences[0] != ownerReference {
Expand All @@ -223,7 +323,7 @@ func (m *MultiNodeEnvironmentManager) createResourceClaim(namespace, name string
Spec: resourceapi.ResourceClaimSpec{
Devices: resourceapi.DeviceClaim{
Requests: []resourceapi.DeviceRequest{{
Name: "imex", DeviceClassName: imexDeviceClass,
Name: "device", DeviceClassName: deviceClassName,
}},
},
},
Expand All @@ -237,6 +337,32 @@ func (m *MultiNodeEnvironmentManager) createResourceClaim(namespace, name string
return rc, nil
}

func (m *MultiNodeEnvironmentManager) removeDeviceClassFinalizer(name string) error {
dc, err := m.deviceClassLister.Get(name)
if err != nil && errors.IsNotFound(err) {
return nil
}
if err != nil {
return fmt.Errorf("error retrieving DeviceClass: %w", err)
}

newDC := dc.DeepCopy()

newDC.Finalizers = []string{}
for _, f := range dc.Finalizers {
if f != multiNodeEnvironmentFinalizer {
newDC.Finalizers = append(newDC.Finalizers, f)
}
}

_, err = m.clientsets.Core.ResourceV1beta1().DeviceClasses().Update(context.Background(), newDC, metav1.UpdateOptions{})
if err != nil {
return fmt.Errorf("error updating DeviceClass: %w", err)
}

return nil
}

func (m *MultiNodeEnvironmentManager) removeResourceClaimFinalizer(namespace, name string) error {
rc, err := m.resourceClaimLister.ResourceClaims(namespace).Get(name)
if err != nil && errors.IsNotFound(err) {
Expand All @@ -263,6 +389,14 @@ func (m *MultiNodeEnvironmentManager) removeResourceClaimFinalizer(namespace, na
return nil
}

func (m *MultiNodeEnvironmentManager) deleteDeviceClass(name string) error {
err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Delete(context.Background(), name, metav1.DeleteOptions{})
if err != nil && !errors.IsNotFound(err) {
return fmt.Errorf("erroring deleting DeviceClass: %w", err)
}
return nil
}

func (m *MultiNodeEnvironmentManager) deleteResourceClaim(namespace, name string) error {
err := m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Delete(context.Background(), name, metav1.DeleteOptions{})
if err != nil && !errors.IsNotFound(err) {
Expand Down
3 changes: 3 additions & 0 deletions deployments/helm/k8s-dra-driver/templates/clusterrole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ rules:
- apiGroups: ["resource.k8s.io"]
resources: ["resourceclaims"]
verbs: ["get", "list", "watch", "create", "update", "delete"]
- apiGroups: ["resource.k8s.io"]
resources: ["deviceclasses"]
verbs: ["get", "list", "watch", "create", "update", "delete"]
- apiGroups: ["resource.k8s.io"]
resources: ["resourceclaims/status"]
verbs: ["update"]
Expand Down

0 comments on commit 36c5d68

Please sign in to comment.