From 6c4c119904f9739dc9e19ad380929fc8908fb703 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Wed, 5 Jul 2023 13:30:25 -0700 Subject: [PATCH 1/5] Unify argument checking for sampling and check more cases Currently sampling functions need to perform the same set of checks on the inputs and those checks are copied and pasted for each method. We can instead define a simple input validation function that can be used by all sampling functions so that any additional corner cases that need to be caught can be fixed in one place and propagated elsewhere. Relatedly, this adds checks for agreement between the length of the source array to be sampled and the array of weights (issue 871) as well as that the destination array is not larger than the source when sampling without replacement (issue 877). --- src/sampling.jl | 194 ++++++++++++++++------------------------------ test/wsampling.jl | 14 +++- 2 files changed, 81 insertions(+), 127 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 357d8d9c4..b904ef898 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,3 +1,39 @@ +using Base: mightalias + +if isdefined(Base, :require_one_based_indexing) # TODO: use this directly once we require Julia 1.2+ + using Base: require_one_based_indexing +else + require_one_based_indexing(xs...) = + any((!) ∘ isone ∘ firstindex, xs) && throw(ArgumentError("non 1-based arrays are not supported")) +end + +function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool) + mightalias(input, output) && + throw(ArgumentError("destination array must not share memory with the source array")) + require_one_based_indexing(input, output) + n = length(input) + k = length(output) + if !replace && k > n + throw(DimensionMismatch("cannot draw $k samples of $n values without replacement")) + end + return (n, k) +end + +function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights, + output::AbstractArray, replace::Bool) + mightalias(output, weights) && + throw(ArgumentError("destination array must not share memory with weights array")) + _validate_sample_inputs(input, weights) + return _validate_sample_inputs(input, output, replace) +end + +function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights) + require_one_based_indexing(weights) + n = length(input) + nw = length(weights) + nw == n || throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw")) + return n +end ########################################################### # @@ -10,16 +46,15 @@ using Random: Sampler, Random.GLOBAL_RNG ### Algorithms for sampling with replacement function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - s = Sampler(rng, 1:length(a)) + n, k = _validate_sample_inputs(a, x, true) + s = Sampler(rng, 1:n) b = a[1] - 1 if b == 0 - for i = 1:length(x) + for i = 1:k @inbounds x[i] = rand(rng, s) end else - for i = 1:length(x) + for i = 1:k @inbounds x[i] = b + rand(rng, s) end end @@ -36,12 +71,9 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`. This algorithm consumes `k` random numbers. """ function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - s = Sampler(rng, 1:length(a)) - for i = 1:length(x) + n, k = _validate_sample_inputs(a, x, true) + s = Sampler(rng, 1:n) + for i = 1:k @inbounds x[i] = a[rand(rng, s)] end return x @@ -61,11 +93,7 @@ storeindices(n, k, T) = false # order results of a sampler that does not order automatically function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n, k = length(a), length(x) + n, k = _validate_sample_inputs(a, x, true) # todo: if eltype(x) <: Real && eltype(a) <: Real, # in some cases it might be faster to check # issorted(a) to see if we can just sort x @@ -140,13 +168,7 @@ memory space. Suitable for the case where memory is tight. """ function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; initshuffle::Bool=true) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n = length(a) - k = length(x) - k <= n || error("length(x) should not exceed length(a)") + n, k = _validate_sample_inputs(a, x, false) # initialize for i = 1:k @@ -200,13 +222,7 @@ faster than Knuth's algorithm especially when `n` is greater than `k`. It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling """ function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n = length(a) - k = length(x) - k <= n || error("length(x) should not exceed length(a)") + n, k = _validate_sample_inputs(a, x, false) inds = Vector{Int}(undef, n) for i = 1:n @@ -240,13 +256,7 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase drastically, resulting in poorer performance. """ function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n = length(a) - k = length(x) - k <= n || error("length(x) should not exceed length(a)") + n, k = _validate_sample_inputs(a, x, false) s = Set{Int}() sizehint!(s, k) @@ -282,13 +292,7 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`. The outputs are ordered. """ function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n = length(a) - k = length(x) - k <= n || error("length(x) should not exceed length(a)") + n, k = _validate_sample_inputs(a, x, false) i = 0 j = 0 @@ -324,13 +328,7 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`. The outputs are ordered. """ function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n = length(a) - k = length(x) - k <= n || error("length(x) should not exceed length(a)") + n, k = _validate_sample_inputs(a, x, false) i = 0 j = 0 @@ -370,13 +368,7 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`. The outputs are ordered. """ function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - N = length(a) - n = length(x) - n <= N || error("length(x) should not exceed length(a)") + N, n = _validate_sample_inputs(a, x, false) i = 0 j = 0 @@ -485,10 +477,7 @@ nor share memory with them, or the result may be incorrect. """ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; replace::Bool=true, ordered::Bool=false) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - k = length(x) + n, k = _validate_sample_inputs(a, x, replace) k == 0 && return x if replace # with replacement @@ -499,8 +488,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; end else # without replacement - k <= n || error("Cannot draw more samples without replacement.") - if ordered if n > 10 * k * k seqsample_c!(rng, a, x) @@ -582,8 +569,7 @@ Optionally specify a random number generator `rng` as the first argument (defaults to `Random.GLOBAL_RNG`). """ function sample(rng::AbstractRNG, wv::AbstractWeights) - 1 == firstindex(wv) || - throw(ArgumentError("non 1-based arrays are not supported")) + require_one_based_indexing(wv) t = rand(rng) * sum(wv) n = length(wv) i = 1 @@ -596,7 +582,10 @@ function sample(rng::AbstractRNG, wv::AbstractWeights) end sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv) -sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)] +function sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) + _validate_sample_inputs(a, wv) + return a[sample(rng, wv)] +end sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv) """ @@ -613,15 +602,8 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm: """ function direct_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) - for i = 1:length(x) + _, k = _validate_sample_inputs(a, wv, x, true) + for i = 1:k x[i] = a[sample(rng, wv)] end return x @@ -702,14 +684,7 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers. """ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + n, k = _validate_sample_inputs(a, wv, x, true) # create alias table ap = Vector{Float64}(undef, n) @@ -718,7 +693,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, # sampling s = Sampler(rng, 1:n) - for i = 1:length(x) + for i = 1:k j = rand(rng, s) x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]] end @@ -740,15 +715,8 @@ and has overall time complexity ``O(n k)``. """ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) - k = length(x) + n, k = _validate_sample_inputs(a, wv, x, false) + k > 0 || return x w = Vector{Float64}(undef, n) copyto!(w, wv) @@ -786,15 +754,8 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. """ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) - k = length(x) + n, k = _validate_sample_inputs(a, wv, x, false) + k > 0 || return x # calculate keys for all items keys = randexp(rng, n) @@ -827,15 +788,7 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. """ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) - k = length(x) + n, k = _validate_sample_inputs(a, wv, x, false) k > 0 || return x # initialize priority queue @@ -900,15 +853,7 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; ordered::Bool=false) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) - k = length(x) + n, k = _validate_sample_inputs(a, wv, x, false) k > 0 || return x # initialize priority queue @@ -968,10 +913,8 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; replace::Bool=true, ordered::Bool=false) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - k = length(x) + n, k = _validate_sample_inputs(a, wv, x, replace) + k > 0 || return x if replace if ordered @@ -991,7 +934,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs end end else - k <= n || error("Cannot draw $k samples from $n samples without replacement.") efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered) end return x diff --git a/test/wsampling.jl b/test/wsampling.jl index d1de4c855..d13854ff5 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -143,6 +143,7 @@ end oz = OffsetArray(z, -4:5) @test_throws ArgumentError sample(weights(ox)) + @test_throws DimensionMismatch sample(x, weights(1:5)) for f in (sample!, wsample!, naive_wsample_norep!, efraimidis_a_wsample_norep!, efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep!) @@ -158,8 +159,19 @@ end @test_throws ArgumentError f(x, weights(x), x) @test_throws ArgumentError f(y, weights(view(x, 3:5)), view(x, 2:4)) @test_throws ArgumentError f(view(x, 2:4), weights(view(x, 3:5)), view(x, 1:2)) + + # Test that source and weight lengths agree + @test_throws DimensionMismatch f(x, weights(1:5), z) + + # Test that sampling without replacement can't draw more than what's available + if endswith(String(nameof(f)), "_norep!") + @test_throws DimensionMismatch f(x, weights(y), vcat(z, z)) + else + @test_throws DimensionMismatch f(x, weights(y), vcat(z, z); replace=false) + end + # This corner case should theoretically succeed # but it currently fails as Base.mightalias is not smart enough @test_broken f(y, weights(view(x, 5:6)), view(x, 2:4)) end -end \ No newline at end of file +end From d8dfe60d3ac2147b19cfa8d0faa8d118ae07d175 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Wed, 5 Jul 2023 13:53:38 -0700 Subject: [PATCH 2/5] Use testsets and loops to avoid copy-pastas in test/wsampling.jl --- test/wsampling.jl | 58 +++++++++++++++++------------------------------ 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/test/wsampling.jl b/test/wsampling.jl index d13854ff5..0391e2614 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -93,45 +93,29 @@ end import StatsBase: naive_wsample_norep!, efraimidis_a_wsample_norep!, efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep! -n = 10^5 -wv = weights([0.2, 0.8, 0.4, 0.6]) - -a = zeros(Int, 3, n) -for j = 1:n - naive_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(naive_wsample_norep!, 4:7, wv, zeros(Int, 2)) - -a = zeros(Int, 3, n) -for j = 1:n - efraimidis_a_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(efraimidis_a_wsample_norep!, 4:7, wv, zeros(Int, 2)) - -a = zeros(Int, 3, n) -for j = 1:n - efraimidis_ares_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(efraimidis_ares_wsample_norep!, 4:7, wv, zeros(Int, 2)) - -a = zeros(Int, 3, n) -for j = 1:n - efraimidis_aexpj_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2)) +@testset "Weighted sampling without replacement" begin + n = 10^5 + wv = weights([0.2, 0.8, 0.4, 0.6]) + + @testset "$f" for f in (naive_wsample_norep!, efraimidis_a_wsample_norep!, + efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep!) + a = zeros(Int, 3, n) + for j = 1:n + f(4:7, wv, view(a,:,j)) + end + check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) + test_rng_use(f, 4:7, wv, zeros(Int, 2)) + end -a = sample(4:7, wv, 3; replace=false, ordered=false) -check_wsample_norep(a, (4, 7), wv, -1; ordered=false) + a = sample(4:7, wv, 3; replace=false, ordered=false) + check_wsample_norep(a, (4, 7), wv, -1; ordered=false) -for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) - r = rev ? reverse(4:7) : (4:7) - r = T===Int ? r : T.(r) - aa = Int.(sample(r, wv, 3; replace=false, ordered=true)) - check_wsample_norep(aa, (4, 7), wv, -1; ordered=true, rev=rev) + for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) + r = rev ? reverse(4:7) : (4:7) + r = T===Int ? r : T.(r) + aa = Int.(sample(r, wv, 3; replace=false, ordered=true)) + check_wsample_norep(aa, (4, 7), wv, -1; ordered=true, rev=rev) + end end @testset "validation of inputs" begin From d56a40a9a23bf1cd17bcfd787f3b079046269c2d Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Wed, 5 Jul 2023 13:55:03 -0700 Subject: [PATCH 3/5] Don't access weight vector `.values` unnecessarily Not all `AbstractWeights` subtypes have that field, e.g. `UnitWeights`, but all have indexing defined, so that can be used instead of trying to index into the underlying array. --- src/sampling.jl | 10 +++++----- test/wsampling.jl | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index b904ef898..728e87fdf 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -760,7 +760,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, # calculate keys for all items keys = randexp(rng, n) for i in 1:n - @inbounds keys[i] = wv.values[i]/keys[i] + @inbounds keys[i] = wv[i]/keys[i] end # return items with largest keys @@ -797,7 +797,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, s = 0 @inbounds for _s in 1:n s = _s - w = wv.values[s] + w = wv[s] w < 0 && error("Negative weight found in weight vector at index $s") if w > 0 i += 1 @@ -812,7 +812,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds threshold = pq[1].first @inbounds for i in s+1:n - w = wv.values[i] + w = wv[i] w < 0 && error("Negative weight found in weight vector at index $i") w > 0 || continue key = w/randexp(rng) @@ -862,7 +862,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, s = 0 @inbounds for _s in 1:n s = _s - w = wv.values[s] + w = wv[s] w < 0 && error("Negative weight found in weight vector at index $s") if w > 0 i += 1 @@ -878,7 +878,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, X = threshold*randexp(rng) @inbounds for i in s+1:n - w = wv.values[i] + w = wv[i] w < 0 && error("Negative weight found in weight vector at index $i") w > 0 || continue X -= w diff --git a/test/wsampling.jl b/test/wsampling.jl index 0391e2614..e0ff9cc47 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -105,6 +105,11 @@ import StatsBase: naive_wsample_norep!, efraimidis_a_wsample_norep!, end check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) test_rng_use(f, 4:7, wv, zeros(Int, 2)) + # Check that the function is using the weight vector's own indexing method(s) + # by trying with `UnitWeights`, which doesn't store an underlying array and thus + # doesn't have a `values` field to access. Here we're effectively just ensuring + # there's no error thrown. + @test length(f(rand(4), uweights(Float64, 4), zeros(2))) == 2 end a = sample(4:7, wv, 3; replace=false, ordered=false) From 4a1e261f1074de473c1fa4d0460dc29693791268 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 10 Jul 2023 08:38:47 -0700 Subject: [PATCH 4/5] Improve error message Co-authored-by: Milan Bouchet-Valat --- src/sampling.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index 728e87fdf..90c151929 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -14,7 +14,8 @@ function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, re n = length(input) k = length(output) if !replace && k > n - throw(DimensionMismatch("cannot draw $k samples of $n values without replacement")) + throw(DimensionMismatch("cannot draw a sample of $k values from an array " * + "with $n values without replacement")) end return (n, k) end From faa481f2b3c83f368796a70677f4354cb8ad4fa7 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 10 Jul 2023 08:48:40 -0700 Subject: [PATCH 5/5] Don't have `_validate_sample_inputs` return lengths --- src/sampling.jl | 78 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 90c151929..f185d37e8 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -17,7 +17,7 @@ function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, re throw(DimensionMismatch("cannot draw a sample of $k values from an array " * "with $n values without replacement")) end - return (n, k) + return nothing end function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights, @@ -25,7 +25,8 @@ function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights, mightalias(output, weights) && throw(ArgumentError("destination array must not share memory with weights array")) _validate_sample_inputs(input, weights) - return _validate_sample_inputs(input, output, replace) + _validate_sample_inputs(input, output, replace) + return nothing end function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights) @@ -33,7 +34,7 @@ function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights) n = length(input) nw = length(weights) nw == n || throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw")) - return n + return nothing end ########################################################### @@ -47,8 +48,9 @@ using Random: Sampler, Random.GLOBAL_RNG ### Algorithms for sampling with replacement function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, true) - s = Sampler(rng, 1:n) + _validate_sample_inputs(a, x, true) + k = length(x) + s = Sampler(rng, 1:length(a)) b = a[1] - 1 if b == 0 for i = 1:k @@ -72,9 +74,9 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`. This algorithm consumes `k` random numbers. """ function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, true) - s = Sampler(rng, 1:n) - for i = 1:k + _validate_sample_inputs(a, x, true) + s = Sampler(rng, 1:length(a)) + for i = 1:length(x) @inbounds x[i] = a[rand(rng, s)] end return x @@ -94,7 +96,9 @@ storeindices(n, k, T) = false # order results of a sampler that does not order automatically function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, true) + _validate_sample_inputs(a, x, true) + n = length(a) + k = length(x) # todo: if eltype(x) <: Real && eltype(a) <: Real, # in some cases it might be faster to check # issorted(a) to see if we can just sort x @@ -169,7 +173,9 @@ memory space. Suitable for the case where memory is tight. """ function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; initshuffle::Bool=true) - n, k = _validate_sample_inputs(a, x, false) + _validate_sample_inputs(a, x, false) + n = length(a) + k = length(x) # initialize for i = 1:k @@ -223,7 +229,9 @@ faster than Knuth's algorithm especially when `n` is greater than `k`. It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling """ function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, false) + _validate_sample_inputs(a, x, false) + n = length(a) + k = length(x) inds = Vector{Int}(undef, n) for i = 1:n @@ -257,7 +265,9 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase drastically, resulting in poorer performance. """ function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, false) + _validate_sample_inputs(a, x, false) + n = length(a) + k = length(x) s = Set{Int}() sizehint!(s, k) @@ -293,7 +303,9 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`. The outputs are ordered. """ function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, false) + _validate_sample_inputs(a, x, false) + n = length(a) + k = length(x) i = 0 j = 0 @@ -329,7 +341,9 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`. The outputs are ordered. """ function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - n, k = _validate_sample_inputs(a, x, false) + _validate_sample_inputs(a, x, false) + n = length(a) + k = length(x) i = 0 j = 0 @@ -369,7 +383,9 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`. The outputs are ordered. """ function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - N, n = _validate_sample_inputs(a, x, false) + _validate_sample_inputs(a, x, false) + N = length(a) + n = length(x) i = 0 j = 0 @@ -478,8 +494,10 @@ nor share memory with them, or the result may be incorrect. """ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; replace::Bool=true, ordered::Bool=false) - n, k = _validate_sample_inputs(a, x, replace) + _validate_sample_inputs(a, x, replace) + k = length(x) k == 0 && return x + n = length(a) if replace # with replacement if ordered @@ -603,8 +621,8 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm: """ function direct_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - _, k = _validate_sample_inputs(a, wv, x, true) - for i = 1:k + _validate_sample_inputs(a, wv, x, true) + for i = 1:length(x) x[i] = a[sample(rng, wv)] end return x @@ -685,7 +703,9 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers. """ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - n, k = _validate_sample_inputs(a, wv, x, true) + _validate_sample_inputs(a, wv, x, true) + n = length(a) + k = length(x) # create alias table ap = Vector{Float64}(undef, n) @@ -716,8 +736,10 @@ and has overall time complexity ``O(n k)``. """ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - n, k = _validate_sample_inputs(a, wv, x, false) + _validate_sample_inputs(a, wv, x, false) + k = length(x) k > 0 || return x + n = length(a) w = Vector{Float64}(undef, n) copyto!(w, wv) @@ -755,8 +777,10 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. """ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - n, k = _validate_sample_inputs(a, wv, x, false) + _validate_sample_inputs(a, wv, x, false) + k = length(x) k > 0 || return x + n = length(a) # calculate keys for all items keys = randexp(rng, n) @@ -789,8 +813,10 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. """ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - n, k = _validate_sample_inputs(a, wv, x, false) + _validate_sample_inputs(a, wv, x, false) + k = length(x) k > 0 || return x + n = length(a) # initialize priority queue pq = Vector{Pair{Float64,Int}}(undef, k) @@ -854,8 +880,10 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; ordered::Bool=false) - n, k = _validate_sample_inputs(a, wv, x, false) + _validate_sample_inputs(a, wv, x, false) + k = length(x) k > 0 || return x + n = length(a) # initialize priority queue pq = Vector{Pair{Float64,Int}}(undef, k) @@ -914,8 +942,10 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; replace::Bool=true, ordered::Bool=false) - n, k = _validate_sample_inputs(a, wv, x, replace) + _validate_sample_inputs(a, wv, x, replace) + k = length(x) k > 0 || return x + n = length(a) if replace if ordered