Skip to content

Commit 4a50321

Browse files
authored
Fix getindex (#202)
1 parent 553ea01 commit 4a50321

File tree

3 files changed

+72
-63
lines changed

3 files changed

+72
-63
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.14.0"
3+
version = "1.14.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/tracked.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,6 @@ Base.promote_rule(::Type{TrackedReal{V1,D1,O1}}, ::Type{TrackedReal{V2,D2,O2}})
280280
# AbstractArray Interface #
281281
###########################
282282

283-
Base.getindex(t::TrackedArray, i::Int) = TrackedReal(value(t)[i], deriv(t)[i], tape(t), i, t)
284-
285283
colon2range(s, i) = i
286284
colon2range(s, ::Colon) = s
287285

@@ -296,10 +294,10 @@ function index_iterable(shape::NTuple{N,Any}, i::NTuple{M,Any}) where {N,M}
296294
end
297295

298296
for T in (:AbstractRange, :Colon, :(Union{Colon,AbstractRange}))
299-
@eval function Base.getindex(t::TrackedArray, i::$(T)...)
297+
@eval Base.@propagate_inbounds function Base.getindex(t::TrackedArray, i1::$(T), is::$(T)...)
300298
tp = tape(t)
301-
out = TrackedArray(value(t)[i...], deriv(t)[i...], tp)
302-
idx = index_iterable(axes(t), i)
299+
out = TrackedArray(value(t)[i1, is...], deriv(t)[i1, is...], tp)
300+
idx = index_iterable(axes(t), (i1, is...))
303301
record!(tp, SpecialInstruction, getindex, (t, idx), out)
304302
return out
305303
end
@@ -329,24 +327,25 @@ end
329327
return nothing
330328
end
331329

332-
function Base.getindex(t::TrackedArray, inds::AbstractArray{<:CartesianIndex})
330+
Base.@propagate_inbounds function Base.getindex(t::TrackedArray, inds::AbstractArray{<:CartesianIndex})
333331
tp = tape(t)
334332
out = TrackedArray(value(t)[inds], deriv(t)[inds], tp)
335333
record!(tp, SpecialInstruction, getindex, (t, inds), out)
336334
return out
337335
end
338-
function Base.getindex(t::TrackedArray, i::Int...)
339-
ind = LinearIndices(t)[i...]
340-
return TrackedReal(value(t)[i...], deriv(t)[i...], tape(t), ind, t)
336+
Base.@propagate_inbounds function Base.getindex(t::TrackedArray, i1::Integer, is::Integer...)
337+
ind = LinearIndices(t)[i1, is...]
338+
return TrackedReal(value(t)[i1, is...], deriv(t)[i1, is...], tape(t), ind, t)
341339
end
342-
function Base.getindex(t::TrackedArray, _inds::Union{Integer, Colon, AbstractArray{<:Integer}}...)
343-
inds = ntuple(Val(length(_inds))) do i
344-
_inds[i] isa Colon && return firstindex(t,i):lastindex(t,i)
345-
return _inds[i]
340+
Base.@propagate_inbounds function Base.getindex(t::TrackedArray, _inds1::Union{Integer, Colon, AbstractArray{<:Integer}}, _inds2::Union{Integer, Colon, AbstractArray{<:Integer}}...)
341+
inds1 = _inds1 isa Colon ? axes(t, 1) : _inds1
342+
inds2 = ntuple(Val(length(_inds2))) do i
343+
_inds2[i] isa Colon && return axes(t, i+1)
344+
return _inds2[i]
346345
end
347346
tp = tape(t)
348-
out = TrackedArray(value(t)[inds...], deriv(t)[inds...], tp)
349-
record!(tp, SpecialInstruction, (getindex, Val(:generic)), (t, inds), out)
347+
out = TrackedArray(value(t)[inds1, inds2...], deriv(t)[inds1, inds2...], tp)
348+
record!(tp, SpecialInstruction, (getindex, Val(:generic)), (t, (inds1, inds2...)), out)
350349
return out
351350
end
352351
@noinline function special_reverse_exec!(instruction::SpecialInstruction{<:Tuple{typeof(getindex), Val{:generic}}})

test/TrackedTests.jl

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,12 @@ ta = TrackedArray(varr, darr, tp)
616616

617617
@test isa(similar(ta), Matrix{eltype(ta)})
618618

619-
@test samefields(ta[2], TrackedReal(varr[2], darr[2], tp, 2, ta))
619+
for T in (UInt, Int)
620+
@test samefields(ta[T(2)], TrackedReal(varr[2], darr[2], tp, 2, ta))
621+
@test samefields(ta[T(2), T(3)], TrackedReal(varr[2, 3], darr[2, 3], tp, 8, ta))
622+
S = T === UInt ? Int : UInt
623+
@test samefields(ta[S(2), T(3)], TrackedReal(varr[2, 3], darr[2, 3], tp, 8, ta))
624+
end
620625

621626
ta_sub = ta[:,:]
622627
idx = ReverseDiff.index_iterable(axes(ta), (:, :))
@@ -630,53 +635,58 @@ instr = tp[1]
630635
@test instr.cache === nothing
631636
empty!(tp)
632637

633-
ta_sub = ta[:,1:2]
634-
idx = ReverseDiff.index_iterable(axes(ta), (:, 1:2))
635-
@test collect(idx) == [(i, j) for i in 1:3, j in 1:2]
636-
@test samefields(ta_sub, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
637-
@test length(tp) == 1
638-
instr = tp[1]
639-
@test instr.func === getindex
640-
@test instr.input === (ta, idx)
641-
@test samefields(instr.output, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
642-
@test instr.cache === nothing
643-
empty!(tp)
644-
645-
ta_sub = ta[2:3,:]
646-
idx = ReverseDiff.index_iterable(axes(ta), (2:3, :))
647-
@test collect(idx) == [(i, j) for i in 2:3, j in 1:3]
648-
@test samefields(ta_sub, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
649-
@test length(tp) == 1
650-
instr = tp[1]
651-
@test instr.func === getindex
652-
@test instr.input === (ta, idx)
653-
@test samefields(instr.output, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
654-
@test instr.cache === nothing
655-
empty!(tp)
656-
657-
ta_sub = ta[1:2,2:3]
658-
idx = ReverseDiff.index_iterable(axes(ta), (1:2, 2:3))
659-
@test collect(idx) == [(i, j) for i in 1:2, j in 2:3]
660-
@test samefields(ta_sub, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
661-
@test length(tp) == 1
662-
instr = tp[1]
663-
@test instr.func === getindex
664-
@test instr.input === (ta, idx)
665-
@test samefields(instr.output, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
666-
@test instr.cache === nothing
667-
empty!(tp)
638+
for T in (UInt, Int)
639+
ta_sub = ta[:,T(1):T(2)]
640+
idx = ReverseDiff.index_iterable(axes(ta), (:, T(1):T(2)))
641+
@test collect(idx) == [(i, j) for i in 1:3, j in 1:2]
642+
@test samefields(ta_sub, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
643+
@test length(tp) == 1
644+
instr = tp[1]
645+
@test instr.func === getindex
646+
@test instr.input === (ta, idx)
647+
@test samefields(instr.output, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
648+
@test instr.cache === nothing
649+
empty!(tp)
650+
651+
ta_sub = ta[T(2):T(3),:]
652+
idx = ReverseDiff.index_iterable(axes(ta), (T(2):T(3), :))
653+
@test collect(idx) == [(i, j) for i in 2:3, j in 1:3]
654+
@test samefields(ta_sub, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
655+
@test length(tp) == 1
656+
instr = tp[1]
657+
@test instr.func === getindex
658+
@test instr.input === (ta, idx)
659+
@test samefields(instr.output, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
660+
@test instr.cache === nothing
661+
empty!(tp)
662+
663+
S = T === UInt ? Int : UInt
664+
for U in (S, T)
665+
ta_sub = ta[S(1):S(2),T(2):T(3)]
666+
idx = ReverseDiff.index_iterable(axes(ta), (S(1):S(2), T(2):T(3)))
667+
@test collect(idx) == [(i, j) for i in 1:2, j in 2:3]
668+
@test samefields(ta_sub, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
669+
@test length(tp) == 1
670+
instr = tp[1]
671+
@test instr.func === getindex
672+
@test instr.input === (ta, idx)
673+
@test samefields(instr.output, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
674+
@test instr.cache === nothing
675+
empty!(tp)
676+
end
668677

669-
ta_sub = ta[2:6]
670-
idx = ReverseDiff.index_iterable(axes(ta), (2:6,))
671-
@test collect(idx) == [(i,) for i in 2:6]
672-
@test samefields(ta_sub, TrackedArray(varr[2:6], darr[2:6], tp))
673-
@test length(tp) == 1
674-
instr = tp[1]
675-
@test instr.func === getindex
676-
@test instr.input === (ta, idx)
677-
@test samefields(instr.output, TrackedArray(varr[2:6], darr[2:6], tp))
678-
@test instr.cache === nothing
679-
empty!(tp)
678+
ta_sub = ta[T(2):T(6)]
679+
idx = ReverseDiff.index_iterable(axes(ta), (T(2):T(6),))
680+
@test collect(idx) == [(i,) for i in 2:6]
681+
@test samefields(ta_sub, TrackedArray(varr[2:6], darr[2:6], tp))
682+
@test length(tp) == 1
683+
instr = tp[1]
684+
@test instr.func === getindex
685+
@test instr.input === (ta, idx)
686+
@test samefields(instr.output, TrackedArray(varr[2:6], darr[2:6], tp))
687+
@test instr.cache === nothing
688+
empty!(tp)
689+
end
680690

681691
ta_sub = ta[:]
682692
idx = ReverseDiff.index_iterable(axes(ta), (:,))

0 commit comments

Comments
 (0)