Skip to content

Commit 86dc920

Browse files
committed
Embedding and autosize
1 parent fa9279c commit 86dc920

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

src/outputsize.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,11 @@ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return
288288
"""
289289
autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1))
290290
autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)
291-
autosizefor(::Type{<:Embedding}, x::AbstractArray) = size(x, 1)
292291
autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1)
293292

293+
autosizefor(::Type{<:Embedding}, x::AbstractArray) = error(
294+
"@autosize Embeeding(_ => n) cannot work, as this _ is the size of the vocabulary, not an array size")
295+
294296
_replaceunderscore(e, s) = e === :_ ? s : e
295297
_replaceunderscore(ex::Expr, s) = Expr(ex.head, map(a -> _replaceunderscore(a, s), ex.args)...)
296298

test/outputsize.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,6 @@ end
190190
m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
191191
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
192192

193-
@test_broken begin # outputsize fails on Embedding
194-
m = @autosize (2, 3, 4, 5) Embedding(_ => 10) # goes by first dim, not 2nd-last
195-
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
196-
end
197-
198193
m = @autosize (9,) Dense(_ => div(_,2))
199194
@test randn(9) |> m |> size == (4,)
200195

@@ -249,6 +244,11 @@ end
249244
# https://github.com/FluxML/Flux.jl/issues/2086
250245
m = @autosize (3, 1) Chain(; c = Dense(_ => 2, sigmoid), b = BatchNorm(_, affine=false))
251246
@test randn(Float32, 3, 32) |> m |> size == (2, 32)
247+
248+
# Embedding takes a vocab size, not an array size
249+
@test_throws ErrorException @autosize (2, 3) Embedding(_ => 10)
250+
m = @autosize (3,) Chain(Embedding(26 => 10), Dense(_, 4))
251+
@test rand(1:26, 3) |> m |> size == (4, 3)
252252
end
253253

254254
@testset "LazyLayer" begin

0 commit comments

Comments
 (0)