Skip to content

Commit 782fbe7

Browse files
authored
feat: manage multiple GPU bindings of an application (#5)
* feat: add an API to switch GPUBinding for app * fix: add periodic GC for GPU bindings
1 parent 90a3c66 commit 782fbe7

File tree

3 files changed

+389
-0
lines changed

3 files changed

+389
-0
lines changed

cmd/scheduler/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ func start() error {
120120

121121
// start monitor metrics
122122
go sher.RegisterFromNodeAnnotations()
123+
go sher.CleanupGPUBindingsLoop()
123124
go initMetrics(config.MetricsBindAddress)
124125

125126
// start http server
@@ -130,6 +131,7 @@ func start() error {
130131
router.GET("/healthz", routes.HealthzRoute())
131132

132133
router.GET("/gpus", routes.ListGPUDetails(sher))
134+
router.PUT("/gpus/assignments/bulk", routes.BulkManageAssignments(sher))
133135
router.POST("/gpus/:id/mode", routes.SwitchGPUMode(sher))
134136
router.POST("/gpus/:id/assign", routes.AssignGPUToApp(sher))
135137
router.POST("/gpus/:id/unassign", routes.UnassignGPUFromApp(sher))

pkg/scheduler/routes/gpu_manage.go

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ type UnassignGPURequest struct {
5252
AppName string `json:"appName"`
5353
}
5454

55+
type SwitchAssignItem struct {
56+
ID string `json:"id"`
57+
Memory *resource.Quantity `json:"memory,omitempty"`
58+
}
59+
60+
type SwitchAssignRequest struct {
61+
AppName string `json:"appName"`
62+
Unassign []SwitchAssignItem `json:"unassign,omitempty"`
63+
Assign []SwitchAssignItem `json:"assign,omitempty"`
64+
}
65+
5566
func ListGPUInfos(s *scheduler.Scheduler) httprouter.Handle {
5667
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
5768
klog.Infoln("Listing all GPUs")
@@ -545,3 +556,306 @@ func UnassignGPUFromApp(s *scheduler.Scheduler) httprouter.Handle {
545556
w.WriteHeader(http.StatusOK)
546557
}
547558
}
559+
560+
// SwitchAssign performs an atomic switch of GPU assignments for an app:
561+
// - Unassigns specified GPU IDs if currently bound to the app (ignores non-existent bindings)
562+
// - Assigns specified GPU IDs (with optional memory for mem-slicing mode) to the app
563+
// - Enforces single-node binding policy across the app's final bindings
564+
// - For exclusive GPUs, evicts existing app bindings and restarts their pods
565+
// - Restarts the target app's pods only if its binding relationship changes
566+
func BulkManageAssignments(s *scheduler.Scheduler) httprouter.Handle {
567+
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
568+
var req SwitchAssignRequest
569+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
570+
http.Error(w, fmt.Sprintf("failed to decode request: %v", err), http.StatusBadRequest)
571+
return
572+
}
573+
if req.AppName == "" {
574+
http.Error(w, "AppName is required", http.StatusBadRequest)
575+
return
576+
}
577+
578+
klog.Infof("SwitchAssign request for app %s: unassign=%v assign=%v", req.AppName, req.Unassign, req.Assign)
579+
util.GPUManageLock.Lock()
580+
defer util.GPUManageLock.Unlock()
581+
582+
nodes, err := s.ListNodes()
583+
if err != nil {
584+
klog.Errorln(err)
585+
http.Error(w, fmt.Sprintf("failed to list nodes: %v", err), http.StatusInternalServerError)
586+
return
587+
}
588+
589+
// Build device maps
590+
uuidToDevice := make(map[string]util.DeviceInfo)
591+
uuidToNodeName := make(map[string]string)
592+
for _, node := range nodes {
593+
for _, device := range node.Devices {
594+
uuidToDevice[device.ID] = device
595+
uuidToNodeName[device.ID] = node.Node.Name
596+
}
597+
}
598+
599+
bindings, err := s.ListGPUBindings()
600+
if err != nil {
601+
klog.Errorln(err)
602+
http.Error(w, err.Error(), http.StatusInternalServerError)
603+
return
604+
}
605+
606+
// Group existing bindings
607+
currentBindingsByUUID := make(map[string]*v1alpha1.GPUBinding)
608+
bindingsByUUID := make(map[string][]*v1alpha1.GPUBinding)
609+
for _, b := range bindings {
610+
bindingsByUUID[b.Spec.UUID] = append(bindingsByUUID[b.Spec.UUID], b)
611+
if b.Spec.AppName == req.AppName {
612+
currentBindingsByUUID[b.Spec.UUID] = b
613+
}
614+
}
615+
616+
// Build a quick lookup for GPUs that will be assigned, so unassign won't remove them
617+
assignIDs := make(map[string]struct{})
618+
for _, it := range req.Assign {
619+
if it.ID != "" {
620+
assignIDs[it.ID] = struct{}{}
621+
}
622+
}
623+
624+
// Plan unassignments (skip any UUID that is also in the assign list)
625+
toUnassignNames := make([]string, 0)
626+
unassignSet := make(map[string]struct{})
627+
for _, it := range req.Unassign {
628+
if it.ID == "" {
629+
continue
630+
}
631+
if _, willAssign := assignIDs[it.ID]; willAssign {
632+
continue
633+
}
634+
if b := currentBindingsByUUID[it.ID]; b != nil {
635+
toUnassignNames = append(toUnassignNames, b.Name)
636+
unassignSet[it.ID] = struct{}{}
637+
}
638+
}
639+
640+
// Plan assignments (patches for mem changes, creates for new bindings, evictions for exclusive)
641+
type patchItem struct {
642+
old *v1alpha1.GPUBinding
643+
mem *resource.Quantity
644+
}
645+
patches := make([]patchItem, 0)
646+
type createItem struct {
647+
binding *v1alpha1.GPUBinding
648+
}
649+
creates := make([]createItem, 0)
650+
evictUUIDs := make(map[string]struct{})
651+
deleteOtherBindingNames := make([]string, 0)
652+
653+
seenAssign := make(map[string]struct{})
654+
for _, it := range req.Assign {
655+
if it.ID == "" {
656+
continue
657+
}
658+
if _, duplicated := seenAssign[it.ID]; duplicated {
659+
continue
660+
}
661+
seenAssign[it.ID] = struct{}{}
662+
663+
dev, ok := uuidToDevice[it.ID]
664+
if !ok {
665+
http.Error(w, fmt.Sprintf("GPU %s not found", it.ID), http.StatusNotFound)
666+
return
667+
}
668+
669+
if existing := currentBindingsByUUID[it.ID]; existing != nil {
670+
// Already bound to this app
671+
if dev.ShareMode != util.ShareModeMemSlicing {
672+
// In exclusive/time-slicing, reassign to same GPU is a no-op
673+
continue
674+
}
675+
// mem-slicing: treat missing/zero or unchanged memory as no-op
676+
if it.Memory == nil || it.Memory.Value() == 0 ||
677+
(existing.Spec.Memory != nil && it.Memory.Value() == existing.Spec.Memory.Value()) {
678+
continue
679+
}
680+
// validate memory availability excluding this app's current allocation
681+
totalUsed := int64(0)
682+
for _, b := range bindingsByUUID[it.ID] {
683+
if b.Spec.AppName == req.AppName {
684+
continue
685+
}
686+
if b.Spec.Memory != nil {
687+
totalUsed += b.Spec.Memory.Value()
688+
}
689+
}
690+
if totalUsed+it.Memory.Value() > int64(dev.Devmem) {
691+
err = fmt.Errorf("not enough memory on GPU %s, available: %d, request: %d for app %s",
692+
it.ID, int64(dev.Devmem)-totalUsed, it.Memory.Value(), req.AppName)
693+
klog.Warningln(err)
694+
http.Error(w, err.Error(), http.StatusConflict)
695+
return
696+
}
697+
patches = append(patches, patchItem{old: existing, mem: it.Memory})
698+
continue
699+
}
700+
701+
// Not currently bound to this app
702+
if dev.ShareMode == util.ShareModeMemSlicing {
703+
if it.Memory == nil || it.Memory.Value() == 0 {
704+
err = fmt.Errorf("memory allocation is required for GPU %s in memory slicing mode, refuse assigning to app %s", it.ID, req.AppName)
705+
klog.Warningln(err)
706+
http.Error(w, err.Error(), http.StatusBadRequest)
707+
return
708+
}
709+
totalUsed := int64(0)
710+
for _, b := range bindingsByUUID[it.ID] {
711+
if b.Spec.Memory != nil {
712+
totalUsed += b.Spec.Memory.Value()
713+
}
714+
}
715+
if totalUsed+it.Memory.Value() > int64(dev.Devmem) {
716+
err = fmt.Errorf("not enough memory available on GPU %s, available: %d, request: %d, refuse assigning to app %s",
717+
it.ID, int64(dev.Devmem)-totalUsed, it.Memory.Value(), req.AppName)
718+
klog.Warningln(err)
719+
http.Error(w, err.Error(), http.StatusConflict)
720+
return
721+
}
722+
} else if dev.ShareMode == util.ShareModeExclusive {
723+
// Plan eviction of other app(s) holding this GPU
724+
for _, b := range bindingsByUUID[it.ID] {
725+
if b.Spec.AppName != req.AppName {
726+
deleteOtherBindingNames = append(deleteOtherBindingNames, b.Name)
727+
evictUUIDs[it.ID] = struct{}{}
728+
}
729+
}
730+
}
731+
732+
// Prepare new binding
733+
mem := it.Memory
734+
if dev.ShareMode != util.ShareModeMemSlicing {
735+
mem = nil
736+
}
737+
newBinding := &v1alpha1.GPUBinding{
738+
ObjectMeta: metav1.ObjectMeta{
739+
Name: strings.ToLower(fmt.Sprintf("%s-%s-%d", req.AppName, it.ID, time.Now().Unix())),
740+
},
741+
Spec: v1alpha1.GPUBindingSpec{
742+
UUID: it.ID,
743+
AppName: req.AppName,
744+
PodSelector: &metav1.LabelSelector{
745+
MatchLabels: map[string]string{
746+
util.AppNameLabelKey: req.AppName,
747+
},
748+
},
749+
Memory: mem,
750+
},
751+
}
752+
creates = append(creates, createItem{binding: newBinding})
753+
}
754+
755+
// Determine final set of UUIDs for the app after changes (for node policy check)
756+
finalUUIDSet := make(map[string]struct{})
757+
for uuid := range currentBindingsByUUID {
758+
if _, toUn := unassignSet[uuid]; !toUn {
759+
finalUUIDSet[uuid] = struct{}{}
760+
}
761+
}
762+
for _, c := range creates {
763+
finalUUIDSet[c.binding.Spec.UUID] = struct{}{}
764+
}
765+
766+
// If no effective change to app's own bindings, return OK without restarting its pods
767+
if len(toUnassignNames) == 0 && len(patches) == 0 && len(creates) == 0 {
768+
w.WriteHeader(http.StatusOK)
769+
return
770+
}
771+
772+
// Enforce single-node binding policy
773+
nodeSet := make(map[string]struct{})
774+
for uuid := range finalUUIDSet {
775+
nodeName := uuidToNodeName[uuid]
776+
if nodeName == "" {
777+
http.Error(w, fmt.Sprintf("GPU %s not found", uuid), http.StatusNotFound)
778+
return
779+
}
780+
nodeSet[nodeName] = struct{}{}
781+
}
782+
if len(nodeSet) > 1 {
783+
err = fmt.Errorf("app %s binding spans multiple nodes which is not allowed", req.AppName)
784+
klog.Warningln(err)
785+
http.Error(w, err.Error(), http.StatusConflict)
786+
return
787+
}
788+
789+
// Execute plan
790+
// 1) Evict other apps for exclusive GPUs
791+
if len(evictUUIDs) > 0 {
792+
pods := s.ListPodsInfo()
793+
for _, pod := range pods {
794+
for _, pdev := range pod.Devices {
795+
for _, cdevs := range pdev {
796+
for _, cdev := range cdevs {
797+
if _, needEvict := evictUUIDs[cdev.UUID]; needEvict {
798+
klog.Infof("Evicting pod %s/%s occupying exclusive GPU %s", pod.Namespace, pod.Name, cdev.UUID)
799+
if err := ctrlclient.IgnoreNotFound(client.GetClient().CoreV1().Pods(pod.Namespace).Delete(r.Context(), pod.Name, metav1.DeleteOptions{})); err != nil {
800+
err = fmt.Errorf("failed to delete existing pod occupying GPU %s/%s: %v", pod.Namespace, pod.Name, err)
801+
klog.Errorln(err)
802+
http.Error(w, err.Error(), http.StatusInternalServerError)
803+
return
804+
}
805+
}
806+
}
807+
}
808+
}
809+
}
810+
for _, name := range deleteOtherBindingNames {
811+
if err := ctrlclient.IgnoreNotFound(util.DeleteGPUBinding(r.Context(), name)); err != nil {
812+
err = fmt.Errorf("failed to delete existing GPUBinding %s: %v", name, err)
813+
klog.Errorln(err)
814+
http.Error(w, err.Error(), http.StatusInternalServerError)
815+
return
816+
}
817+
}
818+
}
819+
820+
// 2) Restart this app's pods due to binding changes
821+
if err := ctrlclient.IgnoreNotFound(util.DeletePodsBelongToApp(r.Context(), req.AppName)); err != nil {
822+
err = fmt.Errorf("failed to delete existing pods of app %s: %v", req.AppName, err)
823+
klog.Errorln(err)
824+
http.Error(w, err.Error(), http.StatusInternalServerError)
825+
return
826+
}
827+
828+
// 3) Apply unassignments
829+
for _, name := range toUnassignNames {
830+
if err := ctrlclient.IgnoreNotFound(util.DeleteGPUBinding(r.Context(), name)); err != nil {
831+
err = fmt.Errorf("failed to delete GPUBinding %s: %v", name, err)
832+
klog.Errorln(err)
833+
http.Error(w, err.Error(), http.StatusInternalServerError)
834+
return
835+
}
836+
}
837+
838+
// 4) Apply memory patches
839+
for _, p := range patches {
840+
newBinding := p.old.DeepCopy()
841+
newBinding.Spec.Memory = p.mem
842+
if err := client.GPUClient.Patch(r.Context(), newBinding, ctrlclient.MergeFrom(p.old)); err != nil {
843+
err = fmt.Errorf("failed to patch GPUBinding %s: %v", p.old.Name, err)
844+
klog.Errorln(err)
845+
http.Error(w, err.Error(), http.StatusInternalServerError)
846+
return
847+
}
848+
}
849+
850+
// 5) Create new bindings
851+
for _, c := range creates {
852+
if err := s.CreateGPUBinding(r.Context(), c.binding); err != nil {
853+
klog.Errorln(err)
854+
http.Error(w, err.Error(), http.StatusInternalServerError)
855+
return
856+
}
857+
}
858+
859+
w.WriteHeader(http.StatusOK)
860+
}
861+
}

0 commit comments

Comments
 (0)