I'm trying to write an absolute minimal amount of code to implement a multi-headed self-attention layer. I want to try to do this with TensorCast.jl, both to learn the syntax better, and perhaps as a nice demo of ML in Julia.
Right now I am trying to compute the query matrix in one go. This works:
using Flux, TensorCast  # Dense comes from Flux, @cast from TensorCast
batch = 4
length = 100
m = 32
# Data:
x = randn(Float32, length, m, batch)
heads = 10
# Layer to compute Q for all heads:
Q = Dense(m, m*heads)
# Computation:
@cast q1[ℓ,ch,n] := Q(x[ℓ,:,n])[ch]
@cast q2[ℓ,c,h,n] := q1[ℓ,(c,h),n] (h in 1:heads)
However, if I try to combine them in one go:
@cast q2[ℓ,c,h,n] := Q(x[ℓ,:,n])[(c,h)] (h in 1:heads)
it gives me the error:
ERROR: LoadError: MethodError: Cannot `convert` an object of type Expr to an object of type Symbol
Closest candidates are:
convert(::Type{T}, ::T) where T
@ Base Base.jl:64
Symbol(::Any...)
@ Base strings/basic.jl:229
Stacktrace:
[1] setindex!(h::Dict{Symbol, Nothing}, v0::Nothing, key0::Expr)
@ Base ./dict.jl:361
[2] push!(s::Set{Symbol}, x::Expr)
@ Base ./set.jl:103
[3] checknorepeats(flat::Vector{Any}, call::TensorCast.CallInfo, msg::String)
@ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:1484
[4] standardglue(ex::Any, target::Vector{Any}, store::NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vararg{Vector{Any}, 6}}}, call::TensorCast.CallInfo)
@ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:420
[5] (::TensorCast.var"#3#5"{TensorCast.CallInfo, NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vararg{Vector{Any}, 6}}}})(x::Expr)
@ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:189
[6] walk(x::Expr, inner::Function, outer::TensorCast.var"#3#5"{TensorCast.CallInfo, NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vararg{Vector{Any}, 6}}}})
@ MacroTools ~/.julia/packages/MacroTools/qijNY/src/utils.jl:112
[7] postwalk(f::Function, x::Expr)
@ MacroTools ~/.julia/packages/MacroTools/qijNY/src/utils.jl:122
[8] _macro(exone::Expr, extwo::Expr, exthree::Nothing; call::TensorCast.CallInfo, dict::Dict{Any, Any})
@ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:189
[9] _macro
@ ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:154 [inlined]
[10] var"@cast"(__source__::LineNumberNode, __module__::Module, exs::Vararg{Any})
@ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:74
Is this sort of thing possible? Or maybe it's too tricky to stack notation like this?
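For reference, here is a minimal sketch of what the working two-step version computes, written with plain `reshape` and no TensorCast. It is only an illustration under some assumptions: `W` is a hypothetical bias-free stand-in for the `Dense` layer, `seqlen` is a renamed `length` (to avoid shadowing `Base.length`), and the combined index `(c,h)` is taken to split in column-major order with `c` varying fastest.

```julia
# Sizes from the snippet above; `seqlen` is a renamed `length` (assumption, see lead-in).
batch, seqlen, m, heads = 4, 100, 32, 10

x = randn(Float32, seqlen, m, batch)

# Hypothetical stand-in for the Dense layer: a bias-free linear map v -> W * v.
W = randn(Float32, m * heads, m)

# Step 1: apply the map to every slice x[ℓ, :, n], giving q1[ℓ, ch, n].
q1 = Array{Float32}(undef, seqlen, m * heads, batch)
for n in 1:batch, ℓ in 1:seqlen
    q1[ℓ, :, n] = W * x[ℓ, :, n]
end

# Step 2: split ch into (c, h), with c varying fastest (column-major order).
q2 = reshape(q1, seqlen, m, heads, batch)

# Entry (c, h) of q2 is entry c + m*(h-1) along the combined axis of q1.
@assert q2[3, 5, 2, 1] == q1[3, 5 + m * (2 - 1), 1]
```

Since `reshape` shares memory with `q1`, the second step costs no copy; the one-line `@cast` form asked about above would presumably need to produce this same layout.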
Thanks,
Miles