Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 227 additions & 8 deletions src/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,37 @@ function soft_memory_limit()
SOFT_MEMORY_LIMIT[] = soft_limit
end


## allocation statistics

"""
Thread-safe counters describing memory-pool activity: how many allocations
and frees happened, their byte volume, and the total time spent in the
allocator. All fields are atomic so they can be updated from any task.
"""
mutable struct AllocStats
    Base.@atomic alloc_count::Int
    Base.@atomic alloc_bytes::Int

    Base.@atomic free_count::Int
    Base.@atomic free_bytes::Int

    Base.@atomic total_time::Float64
end

# Fresh statistics start with every counter zeroed.
AllocStats() = AllocStats(0, 0, 0, 0, 0.0)

# Snapshot the current counter values into an independent instance.
function Base.copy(stats::AllocStats)
    return AllocStats(
        stats.alloc_count, stats.alloc_bytes,
        stats.free_count, stats.free_bytes,
        stats.total_time)
end

# Difference between two snapshots, returned as a plain named tuple
# (the result is immutable, so no atomics are needed).
function Base.:(-)(a::AllocStats, b::AllocStats)
    return (;
        alloc_count = a.alloc_count - b.alloc_count,
        alloc_bytes = a.alloc_bytes - b.alloc_bytes,
        free_count = a.free_count - b.free_count,
        free_bytes = a.free_bytes - b.free_bytes,
        total_time = a.total_time - b.total_time)
end

# Process-wide allocation statistics, updated by `pool_alloc`/`pool_free`.
const alloc_stats = AllocStats()


## memory accounting

mutable struct MemoryStats
# Maximum size of the heap.
# Estimated during `maybe_collect` stage.
Expand Down Expand Up @@ -134,7 +165,7 @@ function account!(stats::MemoryStats, bytes::Integer)
Base.@atomic stats.live += bytes
end

# Whether to eagerly trigger GC when allocation pressure is high.
# Overridable at load time via the "eager_gc" preference.
# (The diff paste had retained both the old and new `const` lines, which
# would be a duplicate constant definition; only the updated one is kept.)
const EAGER_GC::Ref{Bool} = Ref{Bool}(@load_preference("eager_gc", true))

function eager_gc!(flag::Bool)
global EAGER_GC[] = flag
Expand Down Expand Up @@ -210,6 +241,169 @@ function maybe_collect(; blocking::Bool = false)
return
end


## pool activity tracking

# Maps device id => "pool saw activity since the last cleanup sweep" box.
# NOTE(review): the Dict lives behind `POOL_STATUS.lock`; accesses should
# hold that lock — verify all call sites do.
const POOL_STATUS = AMDGPU.LockedObject(Dict{Int, Ref{Bool}}())

"""
    pool_mark(dev::HIPDevice)

Return the activity flag for `dev`'s memory pool, or `nothing` if the device
has never been marked (see [`pool_mark!`](@ref)).
"""
function pool_mark(dev::HIPDevice)
    did = HIP.device_id(dev)
    # Hold the lock for the Dict lookup: `pool_mark!` may insert (and rehash)
    # concurrently under the same lock, so an unlocked read would race.
    Base.@lock POOL_STATUS.lock begin
        box = get(POOL_STATUS.payload, did, nothing)
        return box === nothing ? nothing : box[]
    end
end

"""
    pool_mark!(dev::HIPDevice, val::Bool)

Set the activity flag for `dev`'s memory pool to `val`, creating the flag on
first use.
"""
function pool_mark!(dev::HIPDevice, val::Bool)
    did = HIP.device_id(dev)
    # Take the lock for every Dict access. The previous check-then-lock
    # pattern performed the initial `get` without the lock, racing any
    # concurrent insertion/rehash done under the lock.
    box = Base.@lock POOL_STATUS.lock begin
        get!(() -> Ref{Bool}(val), POOL_STATUS.payload, did)
    end
    # The box itself is a plain Ref; writing it outside the lock matches the
    # original behavior (last writer wins).
    box[] = val
    return
end


## reclaim hooks

"""
reclaim_hooks

A list of callables that are invoked when memory needs to be reclaimed.
Downstream packages can push functions into this list to free cached resources
(e.g., workspace buffers, FFT plans, etc.) when GPU memory is scarce.
"""
const reclaim_hooks = Any[]


## pool cleanup

# Handle of the background cleanup task, started lazily by `pool_alloc`.
const _pool_cleanup_task = Ref{Task}()

# Background task: once per minute, inspect every device's activity flag and
# reclaim cached pool memory from devices that have been idle for a while.
function pool_cleanup()
    # device id => number of consecutive sweeps with no pool activity
    idle_counters = Dict{Int, Int}()
    while true
        try
            sleep(60)
        catch ex
            # NOTE(review): presumably `sleep` throws EOFError when the event
            # loop shuts down at process exit — confirm; that is treated as
            # the signal to terminate this task.
            if ex isa EOFError
                break
            else
                rethrow()
            end
        end

        for dev in HIP.devices()
            did = HIP.device_id(dev)
            status = pool_mark(dev)
            # Device never used the pool: nothing to track for it.
            status === nothing && continue

            if status
                # Pool saw activity since the last sweep: reset idle streak.
                idle_counters[did] = 0
            else
                idle_counters[did] = get(idle_counters, did, 0) + 1
            end
            # Re-arm the flag so the next sweep detects fresh activity.
            pool_mark!(dev, false)

            # After 5 consecutive idle sweeps (~5 minutes), release cached
            # pool memory on that device.
            # NOTE(review): the counter is not reset after reclaiming, so an
            # idle device is reclaimed again on every later sweep — verify
            # this is intended (repeat trims should be cheap no-ops).
            if get(idle_counters, did, 0) >= 5
                HIP.device!(dev) do
                    reclaim()
                end
            end
        end
    end
end


## reclaim

"""
reclaim([sz=typemax(Int)])

Reclaims `sz` bytes of cached memory. Use this to free GPU memory before
calling into functionality that does not use the memory pool. Returns the
number of bytes actually reclaimed.
"""
function reclaim(sz::Int=typemax(Int))
dev = AMDGPU.device()
for hook in reclaim_hooks
hook()
end
HIP.device_synchronize()
pool = Mem.pool_create(dev)
before = HIP.reserved_memory(pool)
HIP.trim(pool)
after = HIP.reserved_memory(pool)
return Int(before - after)
end


## pool status & queries

"""
used_memory()

Returns the amount of memory from the HIP memory pool that is currently
in use by the application.
"""
function used_memory()
pool = Mem.pool_create(AMDGPU.device())
Int(HIP.used_memory(pool))
end

"""
cached_memory()

Returns the amount of backing memory currently allocated (reserved) for the
HIP memory pool.
"""
function cached_memory()
pool = Mem.pool_create(AMDGPU.device())
Int(HIP.reserved_memory(pool))
end

"""
pool_status([io=stdout])

Report to `io` on the memory status of the current GPU and the active memory pool.
"""
function pool_status(io::IO=stdout)
free_bytes, total_bytes = info()
used_bytes = total_bytes - free_bytes
used_ratio = used_bytes / total_bytes
@printf(io, "Effective GPU memory usage: %.2f%% (%s/%s)\n",
100*used_ratio, Base.format_bytes(used_bytes),
Base.format_bytes(total_bytes))

pool = Mem.pool_create(AMDGPU.device())
pool_used = HIP.used_memory(pool)
pool_reserved = HIP.reserved_memory(pool)
@printf(io, "Memory pool usage: %s (%s reserved)\n",
Base.format_bytes(pool_used),
Base.format_bytes(pool_reserved))

hard_limit = hard_memory_limit()
soft_limit = soft_memory_limit()
if hard_limit != typemax(UInt64) || soft_limit != typemax(UInt64)
print(io, "Memory limit: ")
parts = String[]
if soft_limit != typemax(UInt64)
push!(parts, "soft = $(Base.format_bytes(soft_limit))")
end
if hard_limit != typemax(UInt64)
push!(parts, "hard = $(Base.format_bytes(hard_limit))")
end
println(io, join(parts, ", "))
end
end


# TODO handle stream capturing when we support HIP graphs
mutable struct Managed{M}
const mem::M
Expand Down Expand Up @@ -275,16 +469,41 @@ function Base.convert(::Type{Mem.AbstractAMDBuffer}, managed::Managed{M}) where
end

"""
    pool_alloc(::Type{B}, bytesize) where B

Allocate `bytesize` bytes of GPU memory of buffer type `B` on the current
stream, returning it wrapped in a `Managed`. Updates `alloc_stats`, marks the
device's pool as active, and lazily starts the background cleanup task.
"""
function pool_alloc(::Type{B}, bytesize) where B
    # The diff paste had retained the removed old body here, which performed
    # a second, discarded `Managed` allocation; only the new body is kept.

    # Give the GC a chance to run first if memory pressure is high.
    maybe_collect()

    time = Base.@elapsed begin
        s = AMDGPU.stream()
        managed = Managed(B(bytesize; stream=s); stream=s)
    end

    Base.@atomic alloc_stats.alloc_count += 1
    Base.@atomic alloc_stats.alloc_bytes += bytesize
    Base.@atomic alloc_stats.total_time += time

    # Record pool activity so the cleanup task resets its idle counter.
    pool_mark!(AMDGPU.device(), true)

    # Start the background cleanup task on first allocation
    # (interactive sessions only).
    if isinteractive() && !isassigned(_pool_cleanup_task)
        _pool_cleanup_task[] = errormonitor(Threads.@spawn pool_cleanup())
    end

    return managed
end

"""
    pool_free(managed::Managed{M}) where M

Release the memory held by `managed` back to the pool, updating `free` counts
in `alloc_stats`. Errors are reported (not thrown), as this may run from a
finalizer where throwing is not allowed.
"""
function pool_free(managed::Managed{M}) where M
    # The diff paste had retained the removed unconditional `_pool_free`
    # call here, which would have freed the buffer twice; it is dropped.
    sz = Int(sizeof(managed.mem))
    sz == 0 && return

    try
        time = Base.@elapsed _pool_free(managed.mem, managed.stream)
        Base.@atomic alloc_stats.free_count += 1
        Base.@atomic alloc_stats.free_bytes += sz
        Base.@atomic alloc_stats.total_time += time
    catch ex
        # Use the nostdio/Core variants: regular I/O may be unavailable
        # inside a finalizer.
        Base.showerror_nostdio(ex,
            "WARNING: Error while freeing $(Base.format_bytes(sz)) of GPU memory")
        Base.show_backtrace(Core.stdout, catch_backtrace())
        Core.println()
    end
    return
end

function _pool_free(buf, stream::HIPStream)
Expand Down
Loading