diff --git a/api/config/v1/runtime.go b/api/config/v1/runtime.go index 5df04e90f..946f75a23 100644 --- a/api/config/v1/runtime.go +++ b/api/config/v1/runtime.go @@ -53,6 +53,9 @@ type jitCDIModeConfig struct { type csvModeConfig struct { MountSpecPath string `toml:"mount-spec-path"` + // CompatContainerRoot specifies the compat root used when the the standard + // CUDA compat libraries should not be used. + CompatContainerRoot string `toml:"compat-container-root,omitempty"` } type legacyModeConfig struct { diff --git a/cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go b/cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go index e805283db..7e6577485 100644 --- a/cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go +++ b/cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go @@ -25,11 +25,14 @@ import ( "fmt" "os" "slices" + "strings" + + "golang.org/x/mod/semver" ) type compatElfHeader struct { Format int - CUDAVersion string `json:"CUDA Version"` + CUDAVersion cudaVersion `json:"CUDA Version"` Driver []int Device []int } @@ -114,10 +117,35 @@ func getCUDAFwdCompatibilitySection(lib *elf.File) *elf.Section { return nil } -func (h *compatElfHeader) UseCompat(driverMajor int) bool { +// UseCompat checks whether the CUDA compat libraries with the specified elf +// header should be used given the specified host versions. +// If the hostDriverVersion is specified and the ELF header includes a list of +// driver verions, this is checked, otherwise the CUDA version specified in the +// ELF section is checked. +func (h *compatElfHeader) UseCompat(hostDriverMajor int, hostCUDAVersion string) bool { if h == nil { return false } - return slices.Contains(h.Driver, driverMajor) + if hostDriverMajor != 0 && len(h.Driver) > 0 { + return slices.Contains(h.Driver, hostDriverMajor) + } + + return h.CUDAVersion.UseCompat(hostCUDAVersion) +} + +type cudaVersion string + +// UseCompat is true if the container CUDA version is strictly greater than the +// host CUDA version. +func (containerVersion cudaVersion) UseCompat(hostVersion string) bool { + if containerVersion == "" || hostVersion == "" { + return false + } + + return semver.Compare(normalizeVersion(containerVersion), normalizeVersion(hostVersion)) > 0 +} + +func normalizeVersion[T string | cudaVersion](v T) string { + return "v" + strings.TrimPrefix(string(v), "v") } diff --git a/cmd/nvidia-cdi-hook/cudacompat/cudacompat.go b/cmd/nvidia-cdi-hook/cudacompat/cudacompat.go index 717a7b41a..2ef7218ae 100644 --- a/cmd/nvidia-cdi-hook/cudacompat/cudacompat.go +++ b/cmd/nvidia-cdi-hook/cudacompat/cudacompat.go @@ -46,6 +46,7 @@ type command struct { type options struct { cudaCompatContainerRoot string hostDriverVersion string + hostCudaVersion string // containerSpec allows the path to the container spec to be specified for // testing. containerSpec string @@ -79,6 +80,11 @@ func (m command) build() *cli.Command { Usage: "Specify the host driver version. If the CUDA compat libraries detected in the container do not have a higher MAJOR version, the hook is a no-op.", Destination: &options.hostDriverVersion, }, + &cli.StringFlag{ + Name: "host-cuda-version", + Usage: "Specify the CUDA version supported by the host driver.", + Destination: &options.hostCudaVersion, + }, &cli.StringFlag{ Name: "cuda-compat-container-root", Usage: "Specify the folder in which CUDA compat libraries are located in the container", @@ -103,7 +109,9 @@ func (m command) validateFlags(_ *cli.Command, _ *options) error { } func (m command) run(_ *cli.Command, o *options) error { - if o.hostDriverVersion == "" { + // If neither the host driver version nor the host cuda version is specified + // the hook is a no-op. + if o.hostDriverVersion == "" && o.hostCudaVersion == "" { return nil } @@ -129,8 +137,8 @@ func (m command) run(_ *cli.Command, o *options) error { } func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *options) (string, error) { - if o.hostDriverVersion == "" { - m.logger.Debugf("Host driver version not specified") + if o.hostDriverVersion == "" && o.hostCudaVersion == "" { + m.logger.Debugf("Neither a host driver version nor a host CUDA version was specified") return "", nil } if !containerRoot.hasPath(o.cudaCompatContainerRoot) { @@ -156,7 +164,7 @@ func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *op libCudaCompatPath := libs[0] - useCompatLibs, err := m.useCompatLibraries(libCudaCompatPath, o.hostDriverVersion) + useCompatLibs, err := m.useCompatLibraries(libCudaCompatPath, o.hostDriverVersion, o.hostCudaVersion) if err != nil { return "", err } @@ -168,16 +176,24 @@ func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *op return resolvedCompatDir, nil } -func (m command) useCompatLibraries(libcudaCompatPath string, hostDriverVersion string) (bool, error) { +func (m command) useCompatLibraries(libcudaCompatPath string, hostDriverVersion string, hostCUDAVersion string) (bool, error) { driverMajor, err := extractMajorVersion(hostDriverVersion) if err != nil { return false, fmt.Errorf("failed to extract major version from %q: %v", hostDriverVersion, err) } - // First check the elf header. + // First check the ELF header. If this is present, we use the ELF header to + // determine whether the CUDA compat libraries in the container should be + // used. cudaCompatHeader, _ := GetCUDACompatElfHeader(libcudaCompatPath) if cudaCompatHeader != nil { - return cudaCompatHeader.UseCompat(driverMajor), nil + return cudaCompatHeader.UseCompat(driverMajor, hostCUDAVersion), nil + } + + // If no CUDA Compat ELF header is available, and NO host driver version + // was specified, we don't use the CUDA compat libraries in the container. + if hostDriverVersion == "" { + return false, nil } compatDriverVersion := strings.TrimPrefix(filepath.Base(libcudaCompatPath), "libcuda.so.") @@ -241,6 +257,9 @@ func (m command) createLdsoconfdFile(in containerRoot, pattern string, dirs ...s // extractMajorVersion parses a version string and returns the major version as an int. func extractMajorVersion(version string) (int, error) { + if version == "" { + return 0, nil + } majorString := strings.SplitN(version, ".", 2)[0] return strconv.Atoi(majorString) } diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 1517e7584..b07220b90 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -70,8 +70,9 @@ type options struct { featureFlags []string csv struct { - files []string - ignorePatterns []string + files []string + ignorePatterns []string + CompatContainerRoot string } noAllDevice bool @@ -212,6 +213,12 @@ func (m command) build() *cli.Command { Destination: &opts.csv.ignorePatterns, Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_CSV_IGNORE_PATTERNS"), }, + &cli.StringFlag{ + Name: "csv.compat-container-root", + Usage: "specify the container folder to use for CUDA Forward Compatibility in non-standard containers", + Destination: &opts.csv.CompatContainerRoot, + Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_CSV_CONTAINER_COMPAT_ROOT"), + }, &cli.StringSliceFlag{ Name: "disable-hook", Aliases: []string{"disable-hooks"}, @@ -384,6 +391,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) { nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths), nvcdi.WithCSVFiles(opts.csv.files), nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns), + nvcdi.WithCSVCompatContainerRoot(opts.csv.CompatContainerRoot), nvcdi.WithDisabledHooks(opts.disabledHooks...), nvcdi.WithEnabledHooks(opts.enabledHooks...), nvcdi.WithFeatureFlags(opts.featureFlags...), diff --git a/internal/discover/compat_libs.go b/internal/discover/compat_libs.go index 977fdf189..cc65385e5 100644 --- a/internal/discover/compat_libs.go +++ b/internal/discover/compat_libs.go @@ -6,16 +6,33 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) +// EnableCUDACompatHookOptions defines the options that can be specified +// when creating the enable-cuda-compat hook. +type EnableCUDACompatHookOptions struct { + HostDriverVersion string + HostCUDAVersion string + CUDACompatContainerRoot string +} + // NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook. // This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version. -func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, version string, cudaCompatContainerRoot string) Discover { +func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, o *EnableCUDACompatHookOptions) Discover { + return hookCreator.Create(EnableCudaCompatHook, o.args()...) +} + +func (o *EnableCUDACompatHookOptions) args() []string { + if o == nil { + return nil + } var args []string - if version != "" && !strings.Contains(version, "*") { - args = append(args, "--host-driver-version="+version) + if o.HostDriverVersion != "" && !strings.Contains(o.HostDriverVersion, "*") { + args = append(args, "--host-driver-version="+o.HostDriverVersion) } - if cudaCompatContainerRoot != "" { - args = append(args, "--cuda-compat-container-root="+cudaCompatContainerRoot) + if o.HostCUDAVersion != "" { + args = append(args, "--host-cuda-version="+o.HostCUDAVersion) } - - return hookCreator.Create("enable-cuda-compat", args...) + if o.CUDACompatContainerRoot != "" { + args = append(args, "--cuda-compat-container-root="+o.CUDACompatContainerRoot) + } + return args } diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index 2d3c372ff..bc835f654 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -59,6 +59,7 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), nvcdi.WithMode(nvcdi.ModeCSV), nvcdi.WithCSVFiles(csvFiles), + nvcdi.WithCSVCompatContainerRoot(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot), ) if err != nil { return nil, fmt.Errorf("failed to construct CDI library: %v", err) diff --git a/internal/modifier/gated.go b/internal/modifier/gated.go index 369a22085..ef6ae5586 100644 --- a/internal/modifier/gated.go +++ b/internal/modifier/gated.go @@ -107,7 +107,7 @@ func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, dr return nil, fmt.Errorf("failed to get driver version: %w", err) } - compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, version, "") + compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version}) // For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook. if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" { return compatLibHookDiscoverer, nil diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index c48ab929b..5f6474e36 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -101,7 +101,7 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string, libcudaSoParentDir ) discoverers = append(discoverers, driverDotSoSymlinksDiscoverer) - cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, version, "") + cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version}) discoverers = append(discoverers, cudaCompatLibHookDiscoverer) updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath) diff --git a/pkg/nvcdi/lib-csv.go b/pkg/nvcdi/lib-csv.go index 94543e0a3..dc0536478 100644 --- a/pkg/nvcdi/lib-csv.go +++ b/pkg/nvcdi/lib-csv.go @@ -33,14 +33,36 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra" + "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" ) -type csvlib nvcdilib +const ( + defaultOrinCompatContainerRoot = "/usr/local/cuda/compat-orin" +) + +type csvOptions struct { + Files []string + IgnorePatterns []string + CompatContainerRoot string +} +type csvlib nvcdilib type mixedcsvlib nvcdilib var _ deviceSpecGeneratorFactory = (*csvlib)(nil) +// asCSVLib sets any CSV-specific defaults and casts the nvcdilib instance as a +// *csvlib. +func (l *nvcdilib) asCSVLib() *csvlib { + if len(l.csv.Files) == 0 { + l.csv.Files = csv.DefaultFileList() + } + if l.csv.CompatContainerRoot == "" { + l.csv.CompatContainerRoot = defaultOrinCompatContainerRoot + } + return (*csvlib)(l) +} + // DeviceSpecGenerators creates a set of generators for the specified set of // devices. // If NVML is not available or the disable-multiple-csv-devices feature flag is @@ -171,7 +193,7 @@ func (l *csvDeviceGenerator) deviceNodeDiscoverer() (discover.Discover, error) { func (l *csvDeviceGenerator) deviceNodeMountSpecs() tegra.MountSpecPathsByTyper { mountSpecs := tegra.Transform( - tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...), + tegra.MountSpecsFromCSVFiles(l.logger, l.csv.Files...), // We remove non-device nodes. tegra.OnlyDeviceNodes(), ) @@ -388,10 +410,10 @@ func isIntegratedGPU(d nvml.Device) (bool, error) { func (l *csvlib) driverDiscoverer() (discover.Discover, error) { mountSpecs := tegra.Transform( tegra.Transform( - tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...), + tegra.MountSpecsFromCSVFiles(l.logger, l.csv.Files...), tegra.WithoutDeviceNodes(), ), - tegra.IgnoreSymlinkMountSpecsByPattern(l.csvIgnorePatterns...), + tegra.IgnoreSymlinkMountSpecsByPattern(l.csv.IgnorePatterns...), ) driverDiscoverer, err := tegra.New( tegra.WithLogger(l.logger), @@ -428,26 +450,55 @@ func (l *csvlib) driverDiscoverer() (discover.Discover, error) { // version to be passed to the hook. // On Orin-based systems, the compat library root in the container is also set. func (l *csvlib) cudaCompatDiscoverer() discover.Discover { + c, err := l.getEnableCUDACompatHookOptions() + if err != nil { + l.logger.Warningf("Skipping CUDA Forward Compat hook creation: %v", err) + } + if c == nil { + return nil + } + + return discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, c) +} + +func (l *csvlib) getEnableCUDACompatHookOptions() (*discover.EnableCUDACompatHookOptions, error) { hasNvml, _ := l.infolib.HasNvml() if !hasNvml { - return nil + return nil, nil } ret := l.nvmllib.Init() if ret != nvml.SUCCESS { - l.logger.Warningf("Failed to initialize NVML: %v", ret) - return nil + return nil, fmt.Errorf("failed to initialize NVML: %v", ret) } defer func() { _ = l.nvmllib.Shutdown() }() - version, ret := l.nvmllib.SystemGetDriverVersion() - if ret != nvml.SUCCESS { - l.logger.Warningf("Failed to get driver version: %v", ret) - return nil + if !l.hasOrinDevices() { + hostDriverVersion, ret := l.nvmllib.SystemGetDriverVersion() + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get driver version: %v", ret) + } + f := &discover.EnableCUDACompatHookOptions{ + HostDriverVersion: hostDriverVersion, + } + return f, nil + } + + hostCUDAVersion, err := l.getCUDAVersionString() + if err != nil { + return nil, fmt.Errorf("failed to get host CUDA version: %v", ret) } + f := &discover.EnableCUDACompatHookOptions{ + HostCUDAVersion: hostCUDAVersion, + CUDACompatContainerRoot: l.csv.CompatContainerRoot, + } + return f, nil +} + +func (l *csvlib) hasOrinDevices() bool { var names []string err := l.devicelib.VisitDevices(func(i int, d device.Device) error { name, ret := d.GetName() @@ -458,19 +509,26 @@ func (l *csvlib) cudaCompatDiscoverer() discover.Discover { return nil }) if err != nil { - l.logger.Warningf("Failed to get device names: %v", err) - return nil + l.logger.Warningf("Failed to get device names: %v; assuming non-orin devices", err) + return false } - var cudaCompatContainerRoot string for _, name := range names { - // TODO: Should this be overridable through a feature flag / config option? if strings.Contains(name, "Orin (nvgpu)") { - // TODO: This should probably be a constant or configurable. - cudaCompatContainerRoot = "/usr/local/cuda/compat-orin" - break + return true } } - return discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, version, cudaCompatContainerRoot) + return false +} + +func (l *csvlib) getCUDAVersionString() (string, error) { + v, ret := l.nvmllib.SystemGetCudaDriverVersion() + if ret != nvml.SUCCESS { + return "", ret + } + major := v / 1000 + minor := v % 1000 / 10 + + return fmt.Sprintf("%d.%d", major, minor), nil } diff --git a/pkg/nvcdi/lib-csv_test.go b/pkg/nvcdi/lib-csv_test.go index 6810b4c3f..55eca016d 100644 --- a/pkg/nvcdi/lib-csv_test.go +++ b/pkg/nvcdi/lib-csv_test.go @@ -63,39 +63,50 @@ func TestDeviceSpecGenerators(t *testing.T) { infolib: &infoInterfaceMock{ HasNvmlFunc: func() (bool, string) { return true, "forced" }, }, - // TODO: Replace this with a system-specific implementation once available. - nvmllib: &mock.Interface{ - InitFunc: func() nvml.Return { - return nvml.SUCCESS - }, - ShutdownFunc: func() nvml.Return { - return nvml.SUCCESS - }, - SystemGetDriverVersionFunc: func() (string, nvml.Return) { - return "540.3.0", nvml.SUCCESS - }, - DeviceGetCountFunc: func() (int, nvml.Return) { - return 1, nvml.SUCCESS + nvmllib: mockOrinServer(), + }, + expectedDeviceSpecs: []specs.Device{ + { + Name: "0", + ContainerEdits: specs.ContainerEdits{ + DeviceNodes: []*specs.DeviceNode{ + {Path: "/dev/nvidia0", HostPath: "/dev/nvidia0"}, + }, }, - DeviceGetHandleByIndexFunc: func(n int) (nvml.Device, nvml.Return) { - if n != 0 { - return nil, nvml.ERROR_INVALID_ARGUMENT - } - device := &mock.Device{ - GetUUIDFunc: func() (string, nvml.Return) { - return "GPU-orin", nvml.SUCCESS - }, - GetNameFunc: func() (string, nvml.Return) { - return "Orin (nvgpu)", nvml.SUCCESS - }, - GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) { - return nvml.PciInfo{}, nvml.ERROR_NOT_SUPPORTED - }, - } - return device, nvml.SUCCESS + }, + }, + expectedCommonEdits: &cdi.ContainerEdits{ + ContainerEdits: &specs.ContainerEdits{ + Hooks: []*specs.Hook{ + { + HookName: "createContainer", + Path: "/usr/bin/nvidia-cdi-hook", + Args: []string{"nvidia-cdi-hook", "enable-cuda-compat", "--host-cuda-version=13.1", "--cuda-compat-container-root=/usr/local/cuda/compat-orin"}, + Env: []string{"NVIDIA_CTK_DEBUG=false"}, + }, + { + HookName: "createContainer", + Path: "/usr/bin/nvidia-cdi-hook", + Args: []string{"nvidia-cdi-hook", "update-ldcache"}, + Env: []string{"NVIDIA_CTK_DEBUG=false"}, + }, }, }, }, + }, + { + description: "single orin CSV device; custom container compat root", + rootfsFolder: "rootfs-orin", + lib: &csvlib{ + // test-case specific + infolib: &infoInterfaceMock{ + HasNvmlFunc: func() (bool, string) { return true, "forced" }, + }, + nvmllib: mockOrinServer(), + csv: csvOptions{ + CompatContainerRoot: "/another/compat/root", + }, + }, expectedDeviceSpecs: []specs.Device{ { Name: "0", @@ -112,7 +123,7 @@ func TestDeviceSpecGenerators(t *testing.T) { { HookName: "createContainer", Path: "/usr/bin/nvidia-cdi-hook", - Args: []string{"nvidia-cdi-hook", "enable-cuda-compat", "--host-driver-version=540.3.0", "--cuda-compat-container-root=/usr/local/cuda/compat-orin"}, + Args: []string{"nvidia-cdi-hook", "enable-cuda-compat", "--host-cuda-version=13.1", "--cuda-compat-container-root=/another/compat/root"}, Env: []string{"NVIDIA_CTK_DEBUG=false"}, }, { @@ -188,10 +199,13 @@ func TestDeviceSpecGenerators(t *testing.T) { tc.lib.driverRoot = driverRoot tc.lib.devRoot = driverRoot - tc.lib.csvFiles = []string{ + tc.lib.csv.Files = []string{ filepath.Join(driverRoot, "/etc/nvidia-container-runtime/host-files-for-container.d/devices.csv"), filepath.Join(driverRoot, "/etc/nvidia-container-runtime/host-files-for-container.d/drivers.csv"), } + if tc.lib.csv.CompatContainerRoot == "" { + tc.lib.csv.CompatContainerRoot = defaultOrinCompatContainerRoot + } t.Run(tc.description, func(t *testing.T) { generator, err := tc.lib.DeviceSpecGenerators("all") @@ -230,6 +244,44 @@ func stripRoot[T any](root string, v T) T { return modified } +// TODO: We should move this mock to go-nvml/mock +func mockOrinServer() nvml.Interface { + return &mock.Interface{ + InitFunc: func() nvml.Return { + return nvml.SUCCESS + }, + ShutdownFunc: func() nvml.Return { + return nvml.SUCCESS + }, + SystemGetDriverVersionFunc: func() (string, nvml.Return) { + return "540.3.0", nvml.SUCCESS + }, + SystemGetCudaDriverVersionFunc: func() (int, nvml.Return) { + return 13010, nvml.SUCCESS + }, + DeviceGetCountFunc: func() (int, nvml.Return) { + return 1, nvml.SUCCESS + }, + DeviceGetHandleByIndexFunc: func(n int) (nvml.Device, nvml.Return) { + if n != 0 { + return nil, nvml.ERROR_INVALID_ARGUMENT + } + device := &mock.Device{ + GetUUIDFunc: func() (string, nvml.Return) { + return "GPU-orin", nvml.SUCCESS + }, + GetNameFunc: func() (string, nvml.Return) { + return "Orin (nvgpu)", nvml.SUCCESS + }, + GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) { + return nvml.PciInfo{}, nvml.ERROR_NOT_SUPPORTED + }, + } + return device, nvml.SUCCESS + }, + } +} + // TODO: We should move this mock to go-nvml/mock func mockIGXServer() nvml.Interface { thor := &mock.Device{ diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 4369a7215..fe54540e0 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -27,7 +27,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" - "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) @@ -45,8 +44,7 @@ type nvcdilib struct { configSearchPaths []string librarySearchPaths []string - csvFiles []string - csvIgnorePatterns []string + csv csvOptions vendor string class string @@ -115,10 +113,7 @@ func New(opts ...Option) (Interface, error) { var factory deviceSpecGeneratorFactory switch l.resolveMode() { case ModeCSV: - if len(l.csvFiles) == 0 { - l.csvFiles = csv.DefaultFileList() - } - factory = (*csvlib)(l) + factory = l.asCSVLib() case ModeManagement: if l.vendor == "" { l.vendor = "management.nvidia.com" diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index eab27f05d..6af7a8492 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -131,14 +131,22 @@ func WithMergedDeviceOptions(opts ...transform.MergedDeviceOption) Option { // WithCSVFiles sets the CSV files for the library func WithCSVFiles(csvFiles []string) Option { return func(o *nvcdilib) { - o.csvFiles = csvFiles + o.csv.Files = csvFiles } } // WithCSVIgnorePatterns sets the ignore patterns for entries in the CSV files. func WithCSVIgnorePatterns(csvIgnorePatterns []string) Option { return func(o *nvcdilib) { - o.csvIgnorePatterns = csvIgnorePatterns + o.csv.IgnorePatterns = csvIgnorePatterns + } +} + +// WithCSVCompatContainerRoot sets the compat root to use for the container in +// the case of nvgpu-only devices. +func WithCSVCompatContainerRoot(csvCompatContainerRoot string) Option { + return func(o *nvcdilib) { + o.csv.CompatContainerRoot = csvCompatContainerRoot } }