Skip to content

Commit 322e9fe

Browse files
authored
Merge pull request #13 from JuliaReinforcementLearning/fix_12
fix #12
2 parents cf91dd6 + 9c5ce7d commit 322e9fe

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

src/CircularArrayBuffers.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,25 @@ Base.size(cb::CircularArrayBuffer{T,N}, i::Integer) where {T,N} = i == N ? cb.nf
4747
Base.size(cb::CircularArrayBuffer{T,N}) where {T,N} = ntuple(i -> size(cb, i), N)
4848
Base.getindex(cb::CircularArrayBuffer{T,N}, i::Int) where {T,N} = getindex(cb.buffer, _buffer_index(cb, i))
4949
Base.getindex(cb::CircularArrayBuffer{T,N}, I...) where {T,N} = getindex(cb.buffer, Base.front(I)..., _buffer_frame(cb, Base.last(I)))
50+
51+
# !!!
52+
# strange, but we need this function to show `CircularVectorBuffer` correctly
53+
# `Base.print_array` will try to use `isassigned(cb, i, j)` to print elements
54+
# And, `X::AbstractVector[2, 1]` is valid !!!
55+
# without this line
56+
# ```julia
57+
# julia> cb = CircularArrayBuffer([1., 2.])
58+
# CircularVectorBuffer(::Vector{Float64}) with eltype Float64:
59+
# 1.0
60+
# 2.0
61+
62+
# julia> push!(cb, 3)
63+
# CircularVectorBuffer(::Vector{Float64}) with eltype Float64:
64+
# #undef
65+
# #undef
66+
# ```
67+
Base.getindex(cb::CircularVectorBuffer, i, j) = getindex(cb.buffer, _buffer_frame(cb, i), j)
68+
5069
Base.setindex!(cb::CircularArrayBuffer{T,N}, v, i::Int) where {T,N} = setindex!(cb.buffer, v, _buffer_index(cb, i))
5170
Base.setindex!(cb::CircularArrayBuffer{T,N}, v, I...) where {T,N} = setindex!(cb.buffer, v, Base.front(I)..., _buffer_frame(cb, Base.last(I)))
5271

@@ -92,11 +111,7 @@ function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
92111
end
93112
if N == 1
94113
i = _buffer_frame(cb, cb.nframes)
95-
if ndims(data) == 0
96-
cb.buffer[i:i] .= data[]
97-
else
98-
cb.buffer[i:i] .= data
99-
end
114+
cb.buffer[i:i] .= Ref(data)
100115
else
101116
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
102117
end

test/runtests.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,23 @@ CUDA.allowscalar(false)
1111
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/551
1212
@testset "1D with 0d data" begin
1313
b = CircularArrayBuffer{Int}(3)
14-
push!(b, zeros(Int, ()))
14+
append!(b, zeros(Int, ())) # !!! not push!
1515
@test length(b) == 1
1616
@test b[1] == 0
1717
end
1818

19+
@testset "1D vector" begin
20+
b = CircularArrayBuffer([[1], [2, 3]])
21+
push!(b, [4, 5, 6])
22+
@test b == [[2, 3], [4, 5, 6]]
23+
end
24+
25+
@testset "1D Symbol" begin
26+
b = CircularArrayBuffer([:a, :b])
27+
push!(b, :c)
28+
@test b == [:b, :c]
29+
end
30+
1931
@testset "1D Int" begin
2032
b = CircularArrayBuffer{Int}(3)
2133

@@ -189,7 +201,7 @@ if CUDA.functional()
189201
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/551
190202
@testset "1D with 0d data" begin
191203
b = adapt(CuArray, CircularArrayBuffer{Int}(3))
192-
CUDA.@allowscalar push!(b, CUDA.zeros(Int, ()))
204+
append!(b, CUDA.zeros(Int, ())) # !!! not push!
193205
@test length(b) == 1
194206
@test CUDA.@allowscalar b[1] == 0
195207
end

0 commit comments

Comments
 (0)