diff --git a/internal/edits/device.go b/internal/edits/device.go index 2b37f34a0..dd55ed06a 100644 --- a/internal/edits/device.go +++ b/internal/edits/device.go @@ -17,6 +17,8 @@ package edits import ( + "os" + "tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/specs-go" @@ -28,15 +30,23 @@ import ( type device discover.Device // toEdits converts a discovered device to CDI Container Edits. -func (d device) toEdits() (*cdi.ContainerEdits, error) { +func (d device) toEdits(allowAdditionalGIDs bool) (*cdi.ContainerEdits, error) { deviceNode, err := d.toSpec() if err != nil { return nil, err } + var additionalGIDs []uint32 + if allowAdditionalGIDs { + if requiredGID := getRequiredGID(deviceNode); requiredGID != 0 { + additionalGIDs = append(additionalGIDs, requiredGID) + } + } + e := cdi.ContainerEdits{ ContainerEdits: &specs.ContainerEdits{ - DeviceNodes: []*specs.DeviceNode{deviceNode}, + DeviceNodes: []*specs.DeviceNode{deviceNode}, + AdditionalGIDs: additionalGIDs, }, } return &e, nil @@ -71,12 +81,38 @@ func (d device) fromPathOrDefault() *specs.DeviceNode { } } - return &specs.DeviceNode{ + deviceNode := &specs.DeviceNode{ HostPath: d.HostPath, Path: d.Path, Major: dn.Major, Minor: dn.Minor, FileMode: &dn.FileMode, Permissions: string(dn.Permissions), + GID: &dn.Gid, + UID: &dn.Uid, } + + return deviceNode +} + +// getRequiredGID returns the group id of the device if the device is not world read/writable. +// If the information cannot be extracted or an error occurs, 0 is returned. +func getRequiredGID(d *specs.DeviceNode) uint32 { + // Handle the underdefined cases where we do not have enough information to + // extract the GID for the device OR whether the additional GID is required. + if d.GID == nil { + return 0 + } + if d.FileMode == nil { + return 0 + } + if d.FileMode.Type() != os.ModeCharDevice { + return 0 + } + + if permissionsForOther := d.FileMode.Perm(); permissionsForOther&06 != 0 { + return *d.GID + } + + return 0 } diff --git a/internal/oci/spec_mock.go b/internal/oci/spec_mock.go index f004d69c3..ff8ff6476 100644 --- a/internal/oci/spec_mock.go +++ b/internal/oci/spec_mock.go @@ -4,9 +4,8 @@ package oci import ( - "sync" - "github.com/opencontainers/runtime-spec/specs-go" + "sync" ) // Ensure, that SpecMock does implement Spec. diff --git a/pkg/nvcdi/namer_nvml_mock.go b/pkg/nvcdi/namer_nvml_mock.go index 6a704b45c..f81a1eee1 100644 --- a/pkg/nvcdi/namer_nvml_mock.go +++ b/pkg/nvcdi/namer_nvml_mock.go @@ -4,9 +4,8 @@ package nvcdi import ( - "sync" - "github.com/NVIDIA/go-nvml/pkg/nvml" + "sync" ) // Ensure, that nvmlUUIDerMock does implement nvmlUUIDer. diff --git a/pkg/nvcdi/transform/deduplicate.go b/pkg/nvcdi/transform/deduplicate.go index 27be1b67b..c7eace03f 100644 --- a/pkg/nvcdi/transform/deduplicate.go +++ b/pkg/nvcdi/transform/deduplicate.go @@ -17,6 +17,8 @@ package transform import ( + "slices" + "tags.cncf.io/container-device-interface/specs-go" ) @@ -50,6 +52,12 @@ func (d dedupe) Transform(spec *specs.Spec) error { } func (d dedupe) transformEdits(edits *specs.ContainerEdits) error { + additionalGIDs, err := d.deduplicateAdditionalGIDs(edits.AdditionalGIDs) + if err != nil { + return err + } + edits.AdditionalGIDs = additionalGIDs + deviceNodes, err := d.deduplicateDeviceNodes(edits.DeviceNodes) if err != nil { return err @@ -150,3 +158,19 @@ func (d dedupe) deduplicateMounts(entities []*specs.Mount) ([]*specs.Mount, erro } return mounts, nil } + +func (d dedupe) deduplicateAdditionalGIDs(entities []uint32) ([]uint32, error) { + seen := make(map[uint32]bool) + var additionalGIDs []uint32 + for _, e := range entities { + if seen[e] { + continue + } + seen[e] = true + additionalGIDs = append(additionalGIDs, e) + } + + slices.Sort(additionalGIDs) + + return additionalGIDs, nil +}