Expr -> Symbol MethodError when combining mapslices and reshapes #67

@MilesCranmer

I'm trying to implement a multi-headed self-attention layer in as little code as possible. I want to do this with TensorCast.jl, both to learn the syntax better and perhaps as a nice demo of ML in Julia.

Right now I am trying to compute the query matrix for all heads at once. Doing it in two steps works:

using Flux, TensorCast  # assumed imports: Dense from Flux, @cast from TensorCast

batch = 4
length = 100
m = 32

# Data:
x = randn(Float32, length, m, batch)

heads = 10

# Layer to compute Q for all heads:
Q = Dense(m, m*heads)

# Computation:
@cast q1[ℓ,ch,n] := Q(x[ℓ,:,n])[ch]
@cast q2[ℓ,c,h,n] := q1[ℓ,(c,h),n] (h in 1:heads)
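For reference, the two working steps can be sketched in plain Julia as a `mapslices` call followed by a `reshape`, which is what the issue title alludes to. This is only a rough equivalent, not what the macro actually generates: `W`, `b`, and `Qfun` are hypothetical stand-ins for the `Dense` layer, `len` replaces `length` to avoid shadowing `Base.length`, and it assumes TensorCast's convention that in a combined index `(c,h)` the first index varies fastest.

```julia
# Sizes as in the example above.
batch, len, m, heads = 4, 100, 32, 10
x = randn(Float32, len, m, batch)

# Hypothetical stand-in for Dense(m, m*heads): an affine map on one vector.
W = randn(Float32, m * heads, m)
b = zeros(Float32, m * heads)
Qfun(v) = W * v .+ b

# Step 1, q1[ℓ,ch,n] := Q(x[ℓ,:,n])[ch]: apply Qfun to every slice along dim 2.
q1 = mapslices(Qfun, x; dims=2)          # size (len, m*heads, batch)

# Step 2, q2[ℓ,c,h,n] := q1[ℓ,(c,h),n]: split the combined axis, c fastest.
q2 = reshape(q1, len, m, heads, batch)   # size (len, m, heads, batch)

# Element (c, h) of q2 sits at position c + m*(h-1) along the combined axis of q1.
@assert size(q2) == (len, m, heads, batch)
@assert q2[3, 5, 2, 1] == q1[3, 5 + m * (2 - 1), 1]
```

Because `reshape` only reinterprets the same memory, the second step is free; all the work is in the slice-wise application of the layer.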

However, if I try to combine them in one go:

@cast q2[ℓ,c,h,n] := Q(x[ℓ,:,n])[(c,h)] (h in 1:heads)

it gives me the error:

ERROR: LoadError: MethodError: Cannot `convert` an object of type Expr to an object of type Symbol

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:64
  Symbol(::Any...)
   @ Base strings/basic.jl:229

Stacktrace:
  [1] setindex!(h::Dict{Symbol, Nothing}, v0::Nothing, key0::Expr)
    @ Base ./dict.jl:361
  [2] push!(s::Set{Symbol}, x::Expr)
    @ Base ./set.jl:103
  [3] checknorepeats(flat::Vector{Any}, call::TensorCast.CallInfo, msg::String)
    @ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:1484
  [4] standardglue(ex::Any, target::Vector{Any}, store::NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vararg{Vector{Any}, 6}}}, call::TensorCast.CallInfo)
    @ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:420
  [5] (::TensorCast.var"#3#5"{TensorCast.CallInfo, NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vararg{Vector{Any}, 6}}}})(x::Expr)
    @ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:189
  [6] walk(x::Expr, inner::Function, outer::TensorCast.var"#3#5"{TensorCast.CallInfo, NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vararg{Vector{Any}, 6}}}})
    @ MacroTools ~/.julia/packages/MacroTools/qijNY/src/utils.jl:112
  [7] postwalk(f::Function, x::Expr)
    @ MacroTools ~/.julia/packages/MacroTools/qijNY/src/utils.jl:122
  [8] _macro(exone::Expr, extwo::Expr, exthree::Nothing; call::TensorCast.CallInfo, dict::Dict{Any, Any})
    @ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:189
  [9] _macro
    @ ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:154 [inlined]
 [10] var"@cast"(__source__::LineNumberNode, __module__::Module, exs::Vararg{Any})
    @ TensorCast ~/.julia/packages/TensorCast/mQB8h/src/macro.jl:74

Is this sort of thing possible? Or maybe it's too tricky to stack notation like this?

Thanks,
Miles
