Skip to content

Commit bee678e

Browse files
committed
Fix paddle.arange.
1 parent ca93d0c commit bee678e

File tree

1 file changed

+1
-1
lines changed
  • apps/protein_folding/helixfold/alphafold_paddle/model

1 file changed

+1
-1
lines changed

apps/protein_folding/helixfold/alphafold_paddle/model/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def batched_gather(params, indices, axis=0, batch_dims=0):
188188

189189
# indices = [Batch..., Index...]
190190
# Expand the index values across batch elements
191-
strides = paddle.arange(bn).unsqueeze(-1) * stride
191+
strides = paddle.arange(bn, dtype="int64").unsqueeze(-1) * stride
192192
i = i.reshape([bn, -1])
193193
flat_i = paddle.flatten(i + strides)
194194

0 commit comments

Comments
 (0)