@@ -15,16 +15,25 @@ extractargs!(arguments::Vector{Symbol}, defined::Dict{Symbol,Symbol}, sym, mod)
15
15
# end
16
16
# nothing
17
17
# end
18
+
19
+ function _symbols (args)
20
+ s = Vector {Symbol} (undef, length (args))
21
+ for i in eachindex (args, s)
22
+ s[i] = args[i]
23
+ end
24
+ return s
25
+ end
26
+
18
27
function define_tup! (arguments:: Vector{Symbol} , defined:: Dict{Symbol,Symbol} , ex:: Expr , mod)
19
28
for (i, a) ∈ enumerate (ex. args)
20
29
if a isa Symbol
21
30
ex. args[i] = getgensym! (defined, a)
22
31
elseif Meta. isexpr (a, :tuple )
23
- define_tup! (Symbol[ a. args... ] , defined, a, mod)
32
+ define_tup! (_symbols ( a. args) , defined, a, mod)
24
33
elseif Meta. isexpr (a, :ref )
25
34
extractargs! (arguments, defined, a, mod)
26
35
elseif Meta. isexpr (a, :parameters )
27
- define_tup! (Symbol[ a. args... ] , defined, a, mod)
36
+ define_tup! (_symbols ( a. args) , defined, a, mod)
28
37
else
29
38
throw (" Don't know how to handle:\n $a " )
30
39
end
@@ -171,7 +180,12 @@ function extractargs!(
171
180
end
172
181
173
182
function symbolsubs (e:: Expr , old:: Symbol , new:: Symbol )
174
- return Expr (e. head, (symbolsubs (a, old, new) for a in e. args). .. )
183
+ ex = Expr (e. head)
184
+ resize! (ex. args, length (e. args))
185
+ for i in eachindex (e. args, ex. args)
186
+ ex. args[i] = symbolsubs (e. args[i], old, new)
187
+ end
188
+ return ex
175
189
end
176
190
function symbolsubs (e:: Symbol , old:: Symbol , new:: Symbol )
177
191
e == old ? new : e
@@ -365,7 +379,8 @@ function enclose(exorig::Expr, minbatchsize, per, threadlocal, reduction, stride
365
379
# threadlocal stuff
366
380
threadlocal_var_single = gensym (threadlocal_var)
367
381
threadlocal_val, threadlocal_type = threadlocal
368
- q_single = threadlocal_val === Symbol (" " ) ? exorig :
382
+ q_single =
383
+ threadlocal_val === Symbol (" " ) ? exorig :
369
384
symbolsubs (exorig, threadlocal_var, threadlocal_var_single)
370
385
# threadlocal_type = getfield(mod, threadlocal_type)
371
386
threadlocal_accum = Symbol (" ##THREADLOCAL##ACCUM##" )
@@ -378,10 +393,11 @@ function enclose(exorig::Expr, minbatchsize, per, threadlocal, reduction, stride
378
393
threadlocal_val === Symbol (" " ) ? donothing :
379
394
:($ (esc (threadlocal_var)) = [single_thread_result])
380
395
threadlocal_init =
381
- threadlocal_val === Symbol (" " ) ? donothing : quote
382
- $ (esc (threadlocal_accum)) =
383
- Vector {$threadlocal_type} (undef, max (1 , $ (threadtup. args[2 ])))
384
- end
396
+ threadlocal_val === Symbol (" " ) ? donothing :
397
+ quote
398
+ $ (esc (threadlocal_accum)) =
399
+ Vector {$threadlocal_type} (undef, max (1 , $ (threadtup. args[2 ])))
400
+ end
385
401
threadlocal_vect =
386
402
threadlocal_val === Symbol (" " ) ? donothing :
387
403
:($ (esc (threadlocal_var)) = multi_thread_result)
@@ -391,8 +407,7 @@ function enclose(exorig::Expr, minbatchsize, per, threadlocal, reduction, stride
391
407
threadlocal_set =
392
408
threadlocal_val === Symbol (" " ) ? donothing :
393
409
:($ threadlocal_accum[var"##THREAD##" ] = $ threadlocal_var_gen)
394
- threadlocal_return =
395
- threadlocal_val === Symbol (" " ) ? donothing : :($ threadlocal_accum)
410
+ threadlocal_return = threadlocal_val === Symbol (" " ) ? donothing : :($ threadlocal_accum)
396
411
threadlocal_val != = Symbol (" " ) && push! (q. args, threadlocal_init)
397
412
# reduction stuff
398
413
reduction_ops = Expr (:tuple )
@@ -650,7 +665,7 @@ macro batch(arg1, ex)
650
665
threadlocal,
651
666
reduction,
652
667
stride,
653
- __module__
668
+ __module__,
654
669
)
655
670
end
656
671
macro batch (arg1, arg2, ex)
@@ -665,7 +680,7 @@ macro batch(arg1, arg2, ex)
665
680
threadlocal,
666
681
reduction,
667
682
stride,
668
- __module__
683
+ __module__,
669
684
)
670
685
end
671
686
macro batch (arg1, arg2, arg3, ex)
@@ -682,7 +697,7 @@ macro batch(arg1, arg2, arg3, ex)
682
697
threadlocal,
683
698
reduction,
684
699
stride,
685
- __module__
700
+ __module__,
686
701
)
687
702
end
688
703
macro batch (arg1, arg2, arg3, arg4, ex)
@@ -701,7 +716,7 @@ macro batch(arg1, arg2, arg3, arg4, ex)
701
716
threadlocal,
702
717
reduction,
703
718
stride,
704
- __module__
719
+ __module__,
705
720
)
706
721
end
707
722
macro batch (arg1, arg2, arg3, arg4, arg5, ex)
@@ -722,6 +737,6 @@ macro batch(arg1, arg2, arg3, arg4, arg5, ex)
722
737
threadlocal,
723
738
reduction,
724
739
stride,
725
- __module__
740
+ __module__,
726
741
)
727
742
end
0 commit comments