Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/header/TransferBench.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ namespace TransferBench
#define hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle
#define hipMemAccessDesc CUmemAccessDesc
#define hipMemFabricHandle_t CUmemFabricHandle
#define hipMemLocation CUmemLocation

// Enumerations
#define hipDeviceAttributeClockRate cudaDevAttrClockRate
Expand All @@ -628,6 +629,7 @@ namespace TransferBench
#define hipMemcpyHostToDevice cudaMemcpyHostToDevice
#define hipSuccess cudaSuccess
#define hipMemLocationTypeDevice CU_MEM_LOCATION_TYPE_DEVICE
#define hipMemLocationTypeHostNuma CU_MEM_LOCATION_TYPE_HOST_NUMA
#define hipMemAllocationTypePinned CU_MEM_ALLOCATION_TYPE_PINNED
#define hipMemHandleTypeFabric CU_MEM_HANDLE_TYPE_FABRIC
#define hipMemAllocationGranularityRecommended CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
Expand Down Expand Up @@ -1411,6 +1413,25 @@ namespace {
}

#ifdef POD_COMM_ENABLED
// Translate a TransferBench MemDevice into the driver-level hipMemLocation
// used by hipMemCreate / hipMemSetAccess.
//
// CPU-class memory types resolve to a host NUMA node location; GPU memory
// (excluding MEM_MANAGED) resolves to a device ordinal. Any other memory
// type is rejected with ERR_FATAL.
//
// @param memDevice  [in]  memory type + index describing where memory lives
// @param location   [out] populated with the matching driver location
// @return ERR_NONE on success, ERR_FATAL for unsupported memory types
static ErrResult GetMemLocation(MemDevice const& memDevice, hipMemLocation& location)
{
  bool const isHostNuma  = IsCpuMemType(memDevice.memType);
  bool const isPlainGpu  = IsGpuMemType(memDevice.memType) &&
                           memDevice.memType != MEM_MANAGED;

  if (!isHostNuma && !isPlainGpu)
    return {ERR_FATAL, "Unsupported memory location"};

  location.type = isHostNuma ? hipMemLocationTypeHostNuma
                             : hipMemLocationTypeDevice;

  // MEM_CPU_CLOSEST stores a GPU index, so it must be mapped to the nearest
  // host NUMA node; every other type's memIndex is already the location id.
  location.id = (memDevice.memType == MEM_CPU_CLOSEST)
                  ? GetClosestCpuNumaToGpu(memDevice.memIndex)
                  : memDevice.memIndex;

  return ERR_NONE;
}

static ErrResult GetMemAllocationProp(MemDevice const& memDevice, hipMemAllocationProp& prop)
{

Expand All @@ -1428,10 +1449,7 @@ namespace {
}

prop.requestedHandleTypes = hipMemHandleTypeFabric;
// at this point shouldn't have any memtype other than device
// ERR_CHECK(GetMemLocation(memDevice, prop.location));
prop.location.type = hipMemLocationTypeDevice;
prop.location.id = memDevice.memIndex;
ERR_CHECK(GetMemLocation(memDevice, prop.location));
return ERR_NONE;
}
#endif
Expand Down Expand Up @@ -1519,19 +1537,20 @@ namespace {

// Specify memory access descriptor to enable local read/write
hipMemAccessDesc desc;
// ERR_CHECK(GetMemLocation(memDevice, desc.location));
desc.location.type = hipMemLocationTypeDevice;
desc.location.id = memDevice.memIndex;
ERR_CHECK(GetMemLocation(memDevice, desc.location));
desc.flags = hipMemAccessFlagsProtReadWrite;

// Set access flags for virtual address range
ERR_CHECK(hipMemSetAccess((gpu_device_ptr)*memPtr, roundedUpBytes, &desc, 1));

// Clear the memory
if (IsCpuMemType(memType)) {
// Note: CheckPages() / move_pages() is intentionally NOT called here.
// For fabric-exportable HOST_NUMA memory the VA is owned by the driver
// (not a normal anonymous mmap VMA), so move_pages() returns
// -EFAULT/-EINVAL and would falsely trip a fatal error. NUMA placement
// should be already enforced by the prop.location passed to hipMemCreate().
memset(*memPtr, 0, roundedUpBytes);
// Check that the allocated pages are actually on the correct NUMA node
ERR_CHECK(CheckPages((char*)*memPtr, roundedUpBytes, deviceIdx));
} else if (IsGpuMemType(memType)) {
ERR_CHECK(hipSetDevice(memDevice.memIndex));
ERR_CHECK(hipMemset(*memPtr, 0, numBytes));
Expand Down Expand Up @@ -8064,6 +8083,7 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
#undef hipMemGenericAllocationHandle_t
#undef hipMemAccessDesc
#undef hipMemFabricHandle_t
#undef hipMemLocation

// Enumerations
#undef hipDeviceAttributeClockRate
Expand All @@ -8077,6 +8097,7 @@ static bool IsConfiguredGid(union ibv_gid const& gid)
#undef hipMemcpyHostToDevice
#undef hipSuccess
#undef hipMemLocationTypeDevice
#undef hipMemLocationTypeHostNuma
#undef hipMemAllocationTypePinned
//#undef hipMemAllocationTypeUncached
#undef hipMemHandleTypeFabric
Expand Down