Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ import (
"fmt"
"os"
"slices"
"strings"

"golang.org/x/mod/semver"
)

type compatElfHeader struct {
Format int
CUDAVersion string `json:"CUDA Version"`
CUDAVersion cudaVersion `json:"CUDA Version"`
Driver []int
Device []int
}
Expand Down Expand Up @@ -114,10 +117,35 @@ func getCUDAFwdCompatibilitySection(lib *elf.File) *elf.Section {
return nil
}

func (h *compatElfHeader) UseCompat(driverMajor int) bool {
// UseCompat checks whether the CUDA compat libraries with the specified elf
// header should be used given the specified host versions.
// If the hostDriverVersion is specified and the ELF header includes a list of
// driver verions, this is checked, otherwise the CUDA version specified in the
// ELF section is checked.
func (h *compatElfHeader) UseCompat(hostDriverMajor int, hostCUDAVersion string) bool {
if h == nil {
return false
}

return slices.Contains(h.Driver, driverMajor)
if hostDriverMajor != 0 && len(h.Driver) > 0 {
return slices.Contains(h.Driver, hostDriverMajor)
}

return h.CUDAVersion.UseCompat(hostCUDAVersion)
}

type cudaVersion string

// UseCompat is true if the container CUDA version is strictly greater than the
// host CUDA version.
func (containerVersion cudaVersion) UseCompat(hostVersion string) bool {
if containerVersion == "" || hostVersion == "" {
return false
}

return semver.Compare(normalizeVersion(containerVersion), normalizeVersion(hostVersion)) > 0
}

func normalizeVersion[T string | cudaVersion](v T) string {
return "v" + strings.TrimPrefix(string(v), "v")
}
33 changes: 26 additions & 7 deletions cmd/nvidia-cdi-hook/cudacompat/cudacompat.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type command struct {
type options struct {
cudaCompatContainerRoot string
hostDriverVersion string
hostCudaVersion string
// containerSpec allows the path to the container spec to be specified for
// testing.
containerSpec string
Expand Down Expand Up @@ -79,6 +80,11 @@ func (m command) build() *cli.Command {
Usage: "Specify the host driver version. If the CUDA compat libraries detected in the container do not have a higher MAJOR version, the hook is a no-op.",
Destination: &options.hostDriverVersion,
},
&cli.StringFlag{
Name: "host-cuda-version",
Usage: "Specify the CUDA version supported by the host driver.",
Destination: &options.hostCudaVersion,
},
&cli.StringFlag{
Name: "cuda-compat-container-root",
Usage: "Specify the folder in which CUDA compat libraries are located in the container",
Expand All @@ -103,7 +109,9 @@ func (m command) validateFlags(_ *cli.Command, _ *options) error {
}

func (m command) run(_ *cli.Command, o *options) error {
if o.hostDriverVersion == "" {
// If neither the host driver version nor the host cuda version is specified
// the hook is a no-op.
if o.hostDriverVersion == "" && o.hostCudaVersion == "" {
return nil
}

Expand All @@ -129,8 +137,8 @@ func (m command) run(_ *cli.Command, o *options) error {
}

func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *options) (string, error) {
if o.hostDriverVersion == "" {
m.logger.Debugf("Host driver version not specified")
if o.hostDriverVersion == "" && o.hostCudaVersion == "" {
m.logger.Debugf("Neither a host driver version nor a host CUDA version was specified")
return "", nil
}
if !containerRoot.hasPath(o.cudaCompatContainerRoot) {
Expand All @@ -156,7 +164,7 @@ func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *op

libCudaCompatPath := libs[0]

useCompatLibs, err := m.useCompatLibraries(libCudaCompatPath, o.hostDriverVersion)
useCompatLibs, err := m.useCompatLibraries(libCudaCompatPath, o.hostDriverVersion, o.hostCudaVersion)
if err != nil {
return "", err
}
Expand All @@ -168,16 +176,24 @@ func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *op
return resolvedCompatDir, nil
}

func (m command) useCompatLibraries(libcudaCompatPath string, hostDriverVersion string) (bool, error) {
func (m command) useCompatLibraries(libcudaCompatPath string, hostDriverVersion string, hostCUDAVersion string) (bool, error) {
driverMajor, err := extractMajorVersion(hostDriverVersion)
if err != nil {
return false, fmt.Errorf("failed to extract major version from %q: %v", hostDriverVersion, err)
}

// First check the elf header.
// First check the ELF header. If this is present, we use the ELF header to
// determine whether the CUDA compat libraries in the container should be
// used.
cudaCompatHeader, _ := GetCUDACompatElfHeader(libcudaCompatPath)
if cudaCompatHeader != nil {
return cudaCompatHeader.UseCompat(driverMajor), nil
return cudaCompatHeader.UseCompat(driverMajor, hostCUDAVersion), nil
}

// If no CUDA Compat ELF header is available, and NO host driver version
// was specified, we don't use the CUDA compat libraries in the container.
if hostDriverVersion == "" {
return false, nil
}

compatDriverVersion := strings.TrimPrefix(filepath.Base(libcudaCompatPath), "libcuda.so.")
Expand Down Expand Up @@ -241,6 +257,9 @@ func (m command) createLdsoconfdFile(in containerRoot, pattern string, dirs ...s

// extractMajorVersion parses a version string and returns the major version as an int.
func extractMajorVersion(version string) (int, error) {
if version == "" {
return 0, nil
}
majorString := strings.SplitN(version, ".", 2)[0]
return strconv.Atoi(majorString)
}
12 changes: 10 additions & 2 deletions cmd/nvidia-ctk/cdi/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ type options struct {
featureFlags []string

csv struct {
files []string
ignorePatterns []string
files []string
ignorePatterns []string
CompatContainerRoot string
}

noAllDevice bool
Expand Down Expand Up @@ -212,6 +213,12 @@ func (m command) build() *cli.Command {
Destination: &opts.csv.ignorePatterns,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_CSV_IGNORE_PATTERNS"),
},
&cli.StringFlag{
Name: "csv.compat-container-root",
Usage: "specify the container folder to use for CUDA Forward Compatibility in non-standard containers",
Destination: &opts.csv.CompatContainerRoot,
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_CSV_CONTAINER_COMPAT_ROOT"),
},
&cli.StringSliceFlag{
Name: "disable-hook",
Aliases: []string{"disable-hooks"},
Expand Down Expand Up @@ -384,6 +391,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths),
nvcdi.WithCSVFiles(opts.csv.files),
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns),
nvcdi.WithCSVCompatContainerRoot(opts.csv.CompatContainerRoot),
nvcdi.WithDisabledHooks(opts.disabledHooks...),
nvcdi.WithEnabledHooks(opts.enabledHooks...),
nvcdi.WithFeatureFlags(opts.featureFlags...),
Expand Down
3 changes: 3 additions & 0 deletions internal/config/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ type jitCDIModeConfig struct {

type csvModeConfig struct {
MountSpecPath string `toml:"mount-spec-path"`
// CompatContainerRoot specifies the compat root used when the the standard
// CUDA compat libraries should not be used.
CompatContainerRoot string `toml:"compat-container-root,omitempty"`
}

type legacyModeConfig struct {
Expand Down
31 changes: 24 additions & 7 deletions internal/discover/compat_libs.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,33 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// EnableCUDACompatHookOptions defines the options that can be specified
// when creating the enable-cuda-compat hook.
type EnableCUDACompatHookOptions struct {
HostDriverVersion string
HostCUDAVersion string
CUDACompatContainerRoot string
}

// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, version string, cudaCompatContainerRoot string) Discover {
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, o *EnableCUDACompatHookOptions) Discover {
return hookCreator.Create(EnableCudaCompatHook, o.args()...)
}

func (o *EnableCUDACompatHookOptions) args() []string {
if o == nil {
return nil
}
var args []string
if version != "" && !strings.Contains(version, "*") {
args = append(args, "--host-driver-version="+version)
if o.HostDriverVersion != "" && !strings.Contains(o.HostDriverVersion, "*") {
args = append(args, "--host-driver-version="+o.HostDriverVersion)
}
if cudaCompatContainerRoot != "" {
args = append(args, "--cuda-compat-container-root="+cudaCompatContainerRoot)
if o.HostCUDAVersion != "" {
args = append(args, "--host-cuda-version="+o.HostCUDAVersion)
}

return hookCreator.Create("enable-cuda-compat", args...)
if o.CUDACompatContainerRoot != "" {
args = append(args, "--cuda-compat-container-root="+o.CUDACompatContainerRoot)
}
return args
}
1 change: 1 addition & 0 deletions internal/modifier/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithMode(nvcdi.ModeCSV),
nvcdi.WithCSVFiles(csvFiles),
nvcdi.WithCSVCompatContainerRoot(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion internal/modifier/gated.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, dr
return nil, fmt.Errorf("failed to get driver version: %w", err)
}

compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, version, "")
compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version})
// For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook.
if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" {
return compatLibHookDiscoverer, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/nvcdi/driver-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string, libcudaSoParentDir
)
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)

cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, version, "")
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version})
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)

updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
Expand Down
Loading