-
Notifications
You must be signed in to change notification settings - Fork 195
Fix some issues with sampling #879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
6c4c119
d8dfe60
d56a40a
4a1e261
faa481f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,39 @@ | ||
using Base: mightalias | ||
|
||
# Compatibility shim: reuse Base's `require_one_based_indexing` when this Julia version defines it. | ||
if isdefined(Base, :require_one_based_indexing) # TODO: use this directly once we require Julia 1.2+ | ||
using Base: require_one_based_indexing | ||
else | ||
# Fallback definition: throw an ArgumentError if any argument's `firstindex` is not 1. | ||
require_one_based_indexing(xs...) = | ||
any((!) ∘ isone ∘ firstindex, xs) && throw(ArgumentError("non 1-based arrays are not supported")) | ||
end | ||
|
||
# Validate a source (`input`) / destination (`output`) pair before sampling. | ||
# Throws ArgumentError if `input` and `output` might share memory, rejects arrays that | ||
# fail `require_one_based_indexing`, and throws DimensionMismatch when `replace` is false | ||
# and `output` is longer than `input`. | ||
# Returns the tuple `(n, k) = (length(input), length(output))`. | ||
function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool) | ||
mightalias(input, output) && | ||
throw(ArgumentError("destination array must not share memory with the source array")) | ||
require_one_based_indexing(input, output) | ||
n = length(input) | ||
k = length(output) | ||
# Without replacement, at most `n` distinct values can be drawn. | ||
if !replace && k > n | ||
throw(DimensionMismatch("cannot draw $k samples of $n values without replacement")) | ||
end | ||
return (n, k) | ||
end | ||
|
||
# Validate source, weights and destination together for weighted sampling. | ||
# Throws ArgumentError if `output` might share memory with `weights`, then delegates to | ||
# the (input, weights) method and to the (input, output, replace) method for the | ||
# remaining checks. Returns the `(n, k)` tuple produced by the (input, output, replace) method. | ||
function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights, | ||
output::AbstractArray, replace::Bool) | ||
mightalias(output, weights) && | ||
throw(ArgumentError("destination array must not share memory with weights array")) | ||
# Return value (the length) is intentionally discarded; only the checks are needed here. | ||
_validate_sample_inputs(input, weights) | ||
return _validate_sample_inputs(input, output, replace) | ||
end | ||
|
||
# Validate that `weights` is 1-based and has the same length as `input`. | ||
# Throws DimensionMismatch when the lengths differ; returns `n = length(input)`. | ||
# NOTE(review): only `weights` is passed to `require_one_based_indexing` here; `input` | ||
# is checked only by the (input, output, replace) method. Confirm whether callers that | ||
# use this method alone (e.g. `sample(rng, a, wv)`) also need `input` checked. | ||
function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights) | ||
require_one_based_indexing(weights) | ||
n = length(input) | ||
nw = length(weights) | ||
nw == n || throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw")) | ||
return n | ||
end | ||
|
||
########################################################### | ||
# | ||
|
@@ -10,16 +46,15 @@ using Random: Sampler, Random.GLOBAL_RNG | |
### Algorithms for sampling with replacement | ||
|
||
function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
s = Sampler(rng, 1:length(a)) | ||
n, k = _validate_sample_inputs(a, x, true) | ||
s = Sampler(rng, 1:n) | ||
b = a[1] - 1 | ||
if b == 0 | ||
for i = 1:length(x) | ||
for i = 1:k | ||
@inbounds x[i] = rand(rng, s) | ||
end | ||
else | ||
for i = 1:length(x) | ||
for i = 1:k | ||
@inbounds x[i] = b + rand(rng, s) | ||
end | ||
end | ||
|
@@ -36,12 +71,9 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`. | |
This algorithm consumes `k` random numbers. | ||
""" | ||
function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
s = Sampler(rng, 1:length(a)) | ||
for i = 1:length(x) | ||
n, k = _validate_sample_inputs(a, x, true) | ||
s = Sampler(rng, 1:n) | ||
for i = 1:k | ||
ararslan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@inbounds x[i] = a[rand(rng, s)] | ||
end | ||
return x | ||
|
@@ -61,11 +93,7 @@ storeindices(n, k, T) = false | |
|
||
# order results of a sampler that does not order automatically | ||
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n, k = length(a), length(x) | ||
n, k = _validate_sample_inputs(a, x, true) | ||
# todo: if eltype(x) <: Real && eltype(a) <: Real, | ||
# in some cases it might be faster to check | ||
# issorted(a) to see if we can just sort x | ||
|
@@ -140,13 +168,7 @@ memory space. Suitable for the case where memory is tight. | |
""" | ||
function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; | ||
initshuffle::Bool=true) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
# initialize | ||
for i = 1:k | ||
|
@@ -200,13 +222,7 @@ faster than Knuth's algorithm especially when `n` is greater than `k`. | |
It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling | ||
""" | ||
function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
inds = Vector{Int}(undef, n) | ||
for i = 1:n | ||
|
@@ -240,13 +256,7 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase | |
drastically, resulting in poorer performance. | ||
""" | ||
function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
s = Set{Int}() | ||
sizehint!(s, k) | ||
|
@@ -282,13 +292,7 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`. | |
The outputs are ordered. | ||
""" | ||
function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -324,13 +328,7 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`. | |
The outputs are ordered. | ||
""" | ||
function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -370,13 +368,7 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`. | |
The outputs are ordered. | ||
""" | ||
function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
N = length(a) | ||
n = length(x) | ||
n <= N || error("length(x) should not exceed length(a)") | ||
N, n = _validate_sample_inputs(a, x, false) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -485,10 +477,7 @@ nor share memory with them, or the result may be incorrect. | |
""" | ||
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; | ||
replace::Bool=true, ordered::Bool=false) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, x, replace) | ||
k == 0 && return x | ||
|
||
if replace # with replacement | ||
|
@@ -499,8 +488,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; | |
end | ||
|
||
else # without replacement | ||
k <= n || error("Cannot draw more samples without replacement.") | ||
|
||
if ordered | ||
if n > 10 * k * k | ||
seqsample_c!(rng, a, x) | ||
|
@@ -582,8 +569,7 @@ Optionally specify a random number generator `rng` as the first argument | |
(defaults to `Random.GLOBAL_RNG`). | ||
""" | ||
function sample(rng::AbstractRNG, wv::AbstractWeights) | ||
1 == firstindex(wv) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
require_one_based_indexing(wv) | ||
t = rand(rng) * sum(wv) | ||
n = length(wv) | ||
i = 1 | ||
|
@@ -596,7 +582,10 @@ function sample(rng::AbstractRNG, wv::AbstractWeights) | |
end | ||
# Convenience method: forwards to `sample(rng, wv)` with `Random.GLOBAL_RNG` as the generator. | ||
sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv) | ||
|
||
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)] | ||
function sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) | ||
_validate_sample_inputs(a, wv) | ||
return a[sample(rng, wv)] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's weird that this line isn't tested. |
||
end | ||
# Convenience method: forwards to `sample(rng, a, wv)` with `Random.GLOBAL_RNG` as the generator. | ||
sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv) | ||
|
||
""" | ||
|
@@ -613,15 +602,8 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm: | |
""" | ||
function direct_sample!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) | ||
for i = 1:length(x) | ||
_, k = _validate_sample_inputs(a, wv, x, true) | ||
for i = 1:k | ||
x[i] = a[sample(rng, wv)] | ||
end | ||
return x | ||
|
@@ -702,14 +684,7 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti | |
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers. | ||
""" | ||
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) | ||
n, k = _validate_sample_inputs(a, wv, x, true) | ||
|
||
# create alias table | ||
ap = Vector{Float64}(undef, n) | ||
|
@@ -718,7 +693,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, | |
|
||
# sampling | ||
s = Sampler(rng, 1:n) | ||
for i = 1:length(x) | ||
for i = 1:k | ||
j = rand(rng, s) | ||
x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]] | ||
end | ||
|
@@ -740,15 +715,8 @@ and has overall time complexity ``O(n k)``. | |
""" | ||
function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
k > 0 || return x | ||
|
||
w = Vector{Float64}(undef, n) | ||
copyto!(w, wv) | ||
|
@@ -786,20 +754,13 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. | |
""" | ||
function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this function really part of the official API and needs checks of the arguments? IIRC I had never intended it to be called by any user directly. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And if users use the (IMO) intended There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wouldn't have thought it was intended to be user-facing at all except, as pointed out in #876, it's included in the manual. 😕 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But it's not exported, is it? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's not, no. That said, there are implementations of three different Efraimidis-Spirakis algorithms (A, A-Res, and AExpJ), only one of which (AExpJ) is actually used internally by a function like There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Docs were added in #254. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
LOL amazing. My brain runs the GC often so 7 years ago is long gone. I definitely buy the argument that the separate, non-exported functions that each implement specific algorithms should not be considered user-facing and thus shouldn't need to perform the same kind of safety checks as those intended to be called directly. What gets me nervous is that there's nothing saying they aren't user-facing, hence issues like #876 and #877. Perhaps we could add admonitions to the docstrings, e.g.
? A bit tangential to this discussion but in the future we could do something for sampling algorithms as is done for sorting algorithms in Base: each algorithm gets a type that subtypes some abstract sampling algorithm type, then the user may select a particular algorithm via a keyword argument to There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, passing types via an Better perform checks anyway, except if this means we run checks twice when called from There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Currently yes. I can add a flag to the internal checking function that makes it a no-op if called from There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How expensive are the checks? Is there a noticeable performance difference between calling The alg keyword argument seems a reasonable suggestion for future refactorings. For the time being I would prefer adding a warning or note to the docstrings of these internal functions. I think it was a mistake to add them to the docs at all (also based on the initial + follow-up PRs), so I would be fine even with just removing them from the docs. They're not exported and IMO have never been part of the official API (or at least they were not supposed to be). |
||
k > 0 || return x | ||
|
||
# calculate keys for all items | ||
keys = randexp(rng, n) | ||
for i in 1:n | ||
@inbounds keys[i] = wv.values[i]/keys[i] | ||
@inbounds keys[i] = wv[i]/keys[i] | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
# return items with largest keys | ||
|
@@ -827,15 +788,7 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. | |
""" | ||
function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
k > 0 || return x | ||
|
||
# initialize priority queue | ||
|
@@ -844,7 +797,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
s = 0 | ||
@inbounds for _s in 1:n | ||
s = _s | ||
w = wv.values[s] | ||
w = wv[s] | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
w < 0 && error("Negative weight found in weight vector at index $s") | ||
if w > 0 | ||
i += 1 | ||
|
@@ -859,7 +812,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
@inbounds threshold = pq[1].first | ||
|
||
@inbounds for i in s+1:n | ||
w = wv.values[i] | ||
w = wv[i] | ||
w < 0 && error("Negative weight found in weight vector at index $i") | ||
w > 0 || continue | ||
key = w/randexp(rng) | ||
|
@@ -900,15 +853,7 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random | |
function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray; | ||
ordered::Bool=false) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
k > 0 || return x | ||
|
||
# initialize priority queue | ||
|
@@ -917,7 +862,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
s = 0 | ||
@inbounds for _s in 1:n | ||
s = _s | ||
w = wv.values[s] | ||
w = wv[s] | ||
w < 0 && error("Negative weight found in weight vector at index $s") | ||
if w > 0 | ||
i += 1 | ||
|
@@ -933,7 +878,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
X = threshold*randexp(rng) | ||
|
||
@inbounds for i in s+1:n | ||
w = wv.values[i] | ||
w = wv[i] | ||
w < 0 && error("Negative weight found in weight vector at index $i") | ||
w > 0 || continue | ||
X -= w | ||
|
@@ -968,10 +913,8 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra | |
|
||
function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; | ||
replace::Bool=true, ordered::Bool=false) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, replace) | ||
k > 0 || return x | ||
|
||
if replace | ||
if ordered | ||
|
@@ -991,7 +934,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs | |
end | ||
end | ||
else | ||
k <= n || error("Cannot draw $k samples from $n samples without replacement.") | ||
efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered) | ||
end | ||
return x | ||
|
Uh oh!
There was an error while loading. Please reload this page.