Skip to content
76 changes: 58 additions & 18 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -14,6 +16,3 @@ julia = "≥ 1.0.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
Comment thread
kraftpunk97-zz marked this conversation as resolved.
130 changes: 130 additions & 0 deletions src/Access.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
using Flux: softmax

cosine_sim(u, v) = (u'v)/(norm(u)*norm(v))

# content-based addressing
"""
memprobdistrib(MemMat, key, keystrength)

Defines a normalized probability distribution over the memory locations.
"""
function memprobdistrib(M, k, β)
out = [cosine_sim(k, M[i, :]) for i in 1:size(M, 1)] .* β
out = softmax(out)
end

oneplus(x::AbstractVecOrMat) = log.(exp.(x) .+ 1)

mutable struct Access
MemMat # R^{N*W}
LinkMat # R^{N*N}
readWts # RH element Array of [0, 1]^N
wrtWt # WH element Array of (0, 1)^{N}
usageVec # (0, 1)^{N}
precedenceWt # (0, 1)^{N}
readVecs
numWriteHeads
numReadHeads
wordSize
memorySize
end

function Access(memorySize=128, wordSize=20, numReadHeads=1, numWriteHeads=1)
MemMat = zeros(Float32, memorySize, wordSize)
LinkMat = zeros(Float32, memorySize, memorySize)

readWts = rand(Float32, memorySize, numReadHeads)
fill!(readWts, 1f-6)

wrtWt = rand(Float32, memorySize, numWriteHeads)
fill!(wrtWt, 1f-6)

usageVec = zeros(Float32, memorySize, numWriteHeads)
fill!(usageVec, 1f-6)
precedenceWt = zeros(Float32, memorySize)

readVecs = zeros(Float32, wordSize, numReadHeads)
fill!(readVecs, 1f-6)

Access(MemMat, LinkMat, readWts, wrtWt, usageVec, precedenceWt, readVecs,
numWriteHeads, numReadHeads, wordSize, memorySize)
end

"""
interfacedisect(interfaceVec, writeHead, wordSize, readHeads)

Disects the interface vector obtained as the output to obtain various memory
access controls for the DNC.
"""
function interfacedisect(interfaceVec, writeHeads, wordSize, readHeads)
demarcations = cumsum([0, # Starting Index
readHeads*wordSize, # read keys
readHeads, # read strengths
writeHeads*wordSize, # write keys
writeHeads, # write strengths
writeHeads*wordSize, # erase vectors
writeHeads*wordSize, # write vectors
readHeads, # free gates
writeHeads, # allocation gates
writeHeads, # write gates
readHeads * (1 + 2writeHeads) # read modes
])


readkeys = interfaceVec[demarcations[1]+1:demarcations[2]]
readstrengths = oneplus(interfaceVec[demarcations[2]+1:demarcations[3]])
writekeys = interfaceVec[demarcations[3]+1:demarcations[4]]
writestrengths = oneplus(interfaceVec[demarcations[4]+1:demarcations[5]])
eraseVec = σ.(interfaceVec[demarcations[5]+1:demarcations[6]])
writeVec = interfaceVec[demarcations[6]+1:demarcations[7]]
freeGts = σ.(interfaceVec[demarcations[7]+1:demarcations[8]])
allocGt = σ.(interfaceVec[demarcations[8]+1:demarcations[9]])
writeGt = σ.(interfaceVec[demarcations[9]+1:demarcations[10]])
readmodes = softmax(interfaceVec[demarcations[10]+1:demarcations[11]])

readkeys = reshape(readkeys, wordSize, readHeads) # W * RH
writekeys = reshape(writekeys, wordSize, writeHeads) # W * WH
eraseVec = reshape(eraseVec, wordSize, writeHeads) # W * WH
writeVec = reshape(writeVec, wordSize, writeHeads) # W * WH
readmodes = reshape(readmodes, (1+2writeHeads), readHeads) # (WH for backward + WH for forward + 1 for content lookup) * RH

return (readkeys=readkeys, readstrengths=readstrengths, writekeys=writekeys,
writestrengths=writestrengths, eraseVec=eraseVec, writeVec=writeVec,
freeGts=freeGts, allocGts=allocGt, writeGts=writeGt, readmodes=readmodes)
end

function (access::Access)(interfaceVec)
# dynamic memory allocation

interface = interfacedisect(interfaceVec, access.numWriteHeads, access.wordSize, access.numReadHeads)

memRetVec = prod(1 .- interface[:freeGts]' .* access.readWts, dims=2) # Memory Retention Vector = [0, 1]^{N}
access.usageVec = (access.usageVec .+ access.wrtWt .- access.usageVec .* access.wrtWt) .* memRetVec
freelist = sortperm(access.usageVec) # Z^{N}
allocWt = zeros(access.usageVec)
@. allocWt[freelist] = (1 - access.usageVec[freelist]) * cumprod([1; access.usageVec[freelist]][1:end-1]) # (0, 1)^{N}

# writing
wrtcntWt = memprobdistrib(access.MemMat, interface[:writekeys], interface[:writestrengths]) # Write content weighting = (0, 1)^{N}
access.wrtWt .= interface[:writeGts] * (interface[:allocGts] * allocWt + (1 - interface[:allocGts])*wrtcntWt)
@. access.MemMat *= (ones(access.MemMat) - access.wrtWt*interface[:eraseVec]') # First we erase...
@. access.MemMat += access.wrtWt*interface[:writeVec]' # Then we write.

# temporal linkage
eye = Matrix{Float32}(I, size(access.LinkMat)...)
prevlinkscale = @. 1 - access.wrtWt - access.wrtWt'
newlink = @. access.wrtWt * access.precedenceWt'
@. access.LinkMat = (1 - eye) * (prevlinkscale * access.LinkMat + newlink)

access.precedenceWt = (1 - sum(access.wrtWt)) .* access.precedenceWt .+ access.wrtWt

# reading
forwardWts = [access.LinkMat * readWt for readWt in access.readWts]
backwardWts = [access.LinkMat' * readWt for readWt in access.readWts]
readcntWts = memprobdistrib.([access.MemMat], readkeys, readstrengths) # Read content weightings

access.readWts = [π[1].*b .+ π[2].*readcntWts .+ π[3].*f for (π, b, f) in (interface[:readmodes], backwardWts, forwardWts)]
access.readvecs = [access.MemMat' * W_r for W_r in access.readWts]

return(readvecs)
end
44 changes: 44 additions & 0 deletions src/DNC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include("DNCLSTM.jl")
include("Access.jl")

using Distributions: TruncatedNormal
using Flux

mutable struct DNC
controller
access::Access
interfaceVec
end

function DNC(memory_size=16, word_size=16, num_read_heads=4, num_write_heads=1,
hidden_size=64, output_size=4, input_size=4)
controller_input_size = input_size + word_size * num_read_heads

# A truncated normal distribution with no elements further than 2σ of μ.
dist = TruncatedNormal(0, 0.01, -0.02, 0.02)

# The interface vector of this dimension will only work when the number of
# write heads is equal to 1. I have yet to figure out how this value changes
# as the number of write heads change.
interface_vec_dimensions = word_size * num_read_heads + 3word_size + 5num_read_heads + 3
controller_output_size = output_size + interface_vec_dimensions
controller = LSTM(controller_input_size, controller_output_size)
access = Access(memory_size, word_size, num_read_heads, num_write_heads)

interface_vec = rand(dist, interface_vec_dimensions)
DNC(controller, access, interface_vec)
end

function (dnc::DNC)(input)
readVecs = dnc.access.readVecs

# flattening readVecs
readVecs = reshape(readVecs, readVecs |> size |> prod)

# concatinating them with input to form controller input
controller_input = [input; readVecs]

controller_output = dnc.controller(controller_input)


end
2 changes: 1 addition & 1 deletion src/DNCLSTM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
@treelike DNCLSTMCell

function DNCLSTMCell(in::Integer, hidden::Integer; init=glorot_uniform)
DNCLSTMCell([init(hidden, in+2*hidden) for i in 1:4]...,
DNCLSTMCell([init(hidden, in) for i in 1:4]...,
[zeros(hidden) for i in 1:6]...)
end

Expand Down
13 changes: 3 additions & 10 deletions src/DifferentiableNeuralComputer.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
module DifferentiableNeuralComputer

import Flux
import Zygote

include("DNCLSTM.jl")
include("Access.jl")
include("DNC.jl")

struct DNC end

function (dnc::DNC)(x)
return x
end

end # module
end # module DifferentiableNeuralComputer
Loading