Skip to content

Commit 57b38a8

Browse files
committed
move code
1 parent 3f57415 commit 57b38a8

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/layers/basic.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,6 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
687687
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
688688
(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)
689689

690-
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
691-
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
692-
693690
function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
694691
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
695692
return m(onecold(x))

src/outputsize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ end
163163

164164
## fixes for layers that don't work out of the box
165165

166+
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
167+
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
168+
166169
for (fn, Dims) in ((:conv, DenseConvDims),)
167170
@eval begin
168171
function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)

0 commit comments

Comments
 (0)