@@ -48,8 +48,8 @@ def dist_all_reduce(x, return_num=False, distributed=False):
48
48
n = len (x )
49
49
x_sum = 0 if n == 0 else np .sum (x )
50
50
if distributed :
51
- n = dist .all_reduce (paddle .to_tensor (n , dtype = 'int64' )). numpy ()[ 0 ]
52
- x_sum = dist .all_reduce (paddle .to_tensor (x_sum , dtype = 'float32' )). numpy ()[ 0 ]
51
+ n = int ( dist .all_reduce (paddle .to_tensor (n , dtype = 'int64' )))
52
+ x_sum = float ( dist .all_reduce (paddle .to_tensor (x_sum , dtype = 'float32' )))
53
53
x_mean = 0 if n == 0 else x_sum / n
54
54
if return_num :
55
55
return x_mean , n
@@ -62,8 +62,8 @@ def dist_mean(x, distributed=False):
62
62
n = len (x )
63
63
x_sum = 0 if n == 0 else np .sum (x )
64
64
if distributed :
65
- n = dist .all_reduce (paddle .to_tensor (n , dtype = 'int64' )). numpy ()[ 0 ]
66
- x_sum = dist .all_reduce (paddle .to_tensor (x_sum , dtype = 'float32' )). numpy ()[ 0 ]
65
+ n = int ( dist .all_reduce (paddle .to_tensor (n , dtype = 'int64' )))
66
+ x_sum = float ( dist .all_reduce (paddle .to_tensor (x_sum , dtype = 'float32' )))
67
67
x_mean = 0 if n == 0 else x_sum / n
68
68
return x_mean
69
69
@@ -73,15 +73,15 @@ def dist_sum(x, distributed=False):
73
73
n = len (x )
74
74
x_sum = 0 if n == 0 else np .sum (x )
75
75
if distributed :
76
- x_sum = dist .all_reduce (paddle .to_tensor (x_sum , dtype = 'float32' )). numpy ()[ 0 ]
76
+ x_sum = float ( dist .all_reduce (paddle .to_tensor (x_sum , dtype = 'float32' )))
77
77
return x_sum
78
78
79
79
80
80
def dist_length (x , distributed = False ):
81
81
"""tbd"""
82
82
n = len (x )
83
83
if distributed :
84
- n = dist .all_reduce (paddle .to_tensor (n , dtype = 'int64' )). numpy ()[ 0 ]
84
+ n = int ( dist .all_reduce (paddle .to_tensor (n , dtype = 'int64' )))
85
85
return n
86
86
87
87
0 commit comments