diff --git a/cmd/nvidia-validator/main.go b/cmd/nvidia-validator/main.go index 3dd1f5224..e154679e7 100644 --- a/cmd/nvidia-validator/main.go +++ b/cmd/nvidia-validator/main.go @@ -1734,6 +1734,22 @@ func (v *VGPUDevices) validate() error { } func (v *VGPUDevices) runValidation() error { + nvpci := nvpci.New() + GPUDevices, err := nvpci.GetGPUs() + if err != nil { + return fmt.Errorf("error checking for GPU devices on the host: %w", err) + } + + for _, device := range GPUDevices { + creatableTypesFile := filepath.Join(device.Path, "virtfn0", "nvidia", "creatable_vgpu_types") + + _, statErr := os.Stat(creatableTypesFile) + if statErr == nil { + log.Infof("Found creatable_vgpu_types file for device: %s", device.Address) + return nil + } + } + nvmdev := nvmdev.New() vGPUDevices, err := nvmdev.GetAllDevices() if err != nil { @@ -1746,14 +1762,14 @@ func (v *VGPUDevices) runValidation() error { return fmt.Errorf("no vGPU devices found") } - log.Infof("Found %d vGPU devices", numDevices) + log.Infof("Found %d MDEV vGPU devices", numDevices) return nil } for { numDevices := len(vGPUDevices) if numDevices > 0 { - log.Infof("Found %d vGPU devices", numDevices) + log.Infof("Found %d MDEV vGPU devices", numDevices) return nil } log.Infof("No vGPU devices found, retrying after %d seconds", sleepIntervalSecondsFlag)