
Commit fe9fa37

fix: jldoc error; Number => Float64; remove useless overload show
1 parent 7ee2a6f commit fe9fa37
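
The "remove useless overload show" part of the message refers to the hand-written `Base.show` methods deleted below: Julia's default struct printing already produces essentially the same text. A minimal sketch of that behaviour, using stand-in structs that mirror the Flux definitions (not code from this commit):

```julia
# Stand-ins mirroring the Flux structs, only to show Julia's default printing.
struct GlobalMaxPool end

struct GlobalLPNormPool
  p::Float64
end

repr(GlobalMaxPool())        # "GlobalMaxPool()" -- identical to the deleted overload's output
repr(GlobalLPNormPool(2.0))  # "GlobalLPNormPool(2.0)" -- the deleted overload printed "GlobalLPNormPool(p=2.0)"
```

The remaining `LPNormPool` show method is kept, but this commit drops its `p=` label so the printed form matches the positional `p` argument of the constructor.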

File tree

2 files changed: +14 -25 lines

src/layers/conv.jl

Lines changed: 11 additions & 22 deletions
@@ -595,10 +595,6 @@ function (g::GlobalMaxPool)(x)
   return maxpool(x, pdims)
 end

-function Base.show(io::IO, g::GlobalMaxPool)
-  print(io, "GlobalMaxPool()")
-end
-
 """
     GlobalMeanPool()

@@ -629,12 +625,8 @@ function (g::GlobalMeanPool)(x)
   return meanpool(x, pdims)
 end

-function Base.show(io::IO, g::GlobalMeanPool)
-  print(io, "GlobalMeanPool()")
-end
-
 """
-    GlobalLPNormPool
+    GlobalLPNormPool(p::Float64)

 Global lp norm pooling layer.

@@ -646,14 +638,14 @@ See also [`LPNormPool`](@ref).
 ```jldoctest
 julia> xs = rand(Float32, 100, 100, 3, 50)

-julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool())
+julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0))

 julia> m(xs) |> size
 (1, 1, 7, 50)
 ```
 """
 struct GlobalLPNormPool
-  p::Number
+  p::Float64
 end

 function (g::GlobalLPNormPool)(x)
@@ -663,10 +655,6 @@ function (g::GlobalLPNormPool)(x)
   return lpnormpool(x, pdims; p=g.p)
 end

-function Base.show(io::IO, g::GlobalLPNormPool)
-  print(io, "GlobalLPNormPool(p=", g.p, ")")
-end
-
 """
     MaxPool(window::NTuple; pad=0, stride=window)

@@ -790,7 +778,7 @@ function Base.show(io::IO, m::MeanPool)
 end

 """
-    LPNormPool(window::NTuple, p::Number; pad=0, stride=window)
+    LPNormPool(window::NTuple, p::Float64; pad=0, stride=window)

 Lp norm pooling layer, calculating p-norm distance for each window,
 also known as LPPool in pytorch.
@@ -802,14 +790,15 @@ By default the window size is also the stride in each dimension.
 The keyword `pad` accepts the same options as for the `Conv` layer,
 including `SamePad()`.

-See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalLPNormPool`](@ref).
+See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalLPNormPool`](@ref),
+[`pytorch LPPool`](https://pytorch.org/docs/stable/generated/torch.nn.LPPool2d.html).

 # Examples

 ```jldoctest
 julia> xs = rand(Float32, 100, 100, 3, 50);

-julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2; pad=SamePad()))
+julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2.0; pad=SamePad()))
 Chain(
   Conv((5, 5), 3 => 7),  # 532 parameters
   LPNormPool((5, 5), p=2, pad=2),
@@ -821,7 +810,7 @@ julia> m[1](xs) |> size
 julia> m(xs) |> size
 (20, 20, 7, 50)

-julia> layer = LPNormPool((5,), 2, pad=2, stride=(3,)) # one-dimensional window
+julia> layer = LPNormPool((5,), 2.0, pad=2, stride=(3,)) # one-dimensional window
 LPNormPool((5,), p=2, pad=2, stride=3)

 julia> layer(rand(Float32, 100, 7, 50)) |> size
@@ -830,12 +819,12 @@ julia> layer(rand(Float32, 100, 7, 50)) |> size
 """
 struct LPNormPool{N,M}
   k::NTuple{N,Int}
-  p::Number
+  p::Float64
   pad::NTuple{M,Int}
   stride::NTuple{N,Int}
 end

-function LPNormPool(k::NTuple{N,Integer}, p::Number; pad = 0, stride = k) where N
+function LPNormPool(k::NTuple{N,Integer}, p::Float64; pad = 0, stride = k) where N
   stride = expand(Val(N), stride)
   pad = calc_padding(LPNormPool, pad, k, 1, stride)
   return LPNormPool(k, p, pad, stride)
@@ -847,7 +836,7 @@ function (l::LPNormPool)(x)
 end

 function Base.show(io::IO, l::LPNormPool)
-  print(io, "LPNormPool(", l.k, ", p=", l.p)
+  print(io, "LPNormPool(", l.k, ", ", l.p)
   all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
   l.stride == l.k || print(io, ", stride=", _maybetuple_string(l.stride))
   print(io, ")")
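
For context, the tightened docstrings above now construct the pooling layers with a `Float64` exponent. A minimal usage sketch along the lines of those doctests, assuming `Flux` exports `LPNormPool` and `GlobalLPNormPool` as the diffed source suggests:

```julia
using Flux  # Chain, Conv, SamePad, LPNormPool, GlobalLPNormPool

xs = rand(Float32, 100, 100, 3, 50)

# p is now declared as Float64, so the exponent is written 2.0 rather than 2
m = Chain(Conv((5, 5), 3 => 7), LPNormPool((5, 5), 2.0; pad=SamePad()))
size(m(xs))  # (20, 20, 7, 50), as in the docstring example

g = Chain(Conv((3, 3), 3 => 7), GlobalLPNormPool(2.0))
size(g(xs))  # (1, 1, 7, 50)
```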

test/layers/conv.jl

Lines changed: 3 additions & 3 deletions
@@ -17,14 +17,14 @@ using Flux: gradient
   @test size(gmp(x)) == (1, 1, 3, 2)
   gmp = GlobalMeanPool()
   @test size(gmp(x)) == (1, 1, 3, 2)
-  glmp = GlobalLPNormPool(2)
+  glmp = GlobalLPNormPool(2.0)
   @test size(glmp(x)) == (1, 1, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, PoolDims(x, 2))
   mp = MeanPool((2, 2))
   @test mp(x) == meanpool(x, PoolDims(x, 2))
-  lnp = LPNormPool((2,2), 2)
-  @test lnp(x) == lpnormpool(x, PoolDims(x, 2); p=2)
+  lnp = LPNormPool((2,2), 2.0)
+  @test lnp(x) == lpnormpool(x, PoolDims(x, 2); p=2.0)
 end

 @testset "CNN" begin
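
The updated tests compare the layers against the underlying NNlib calls with a matching `Float64` exponent. A small sketch of that equivalence, assuming a `Float32` input like the test's `x` (names here are illustrative, not taken from the test file):

```julia
using Flux, NNlib  # LPNormPool wraps NNlib.lpnormpool; PoolDims also comes from NNlib

x = rand(Float32, 10, 10, 3, 2)

lnp = LPNormPool((2, 2), 2.0)
lnp(x) == lpnormpool(x, PoolDims(x, 2); p=2.0)  # expected: true

glmp = GlobalLPNormPool(2.0)
size(glmp(x))  # (1, 1, 3, 2)
```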

0 commit comments
