Skip to content

Commit 03723d6

Browse files
committed
Make win_put/get/accum weights in double
1 parent ae493b7 commit 03723d6

File tree

9 files changed

+34
-35
lines changed

9 files changed

+34
-35
lines changed

bluefog/common/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ struct TensorTableEntry {
233233
int device = CPU_DEVICE_ID;
234234
// Source and destination of ranks used in win ops.
235235
// It maps the src(dst) rank to the weight.
236-
std::unordered_map<int, float> dst_weights = {};
237-
std::unordered_map<int, float> src_weights = {};
236+
std::unordered_map<int, double> dst_weights = {};
237+
std::unordered_map<int, double> src_weights = {};
238238

239239
// The ops requires the mutex.
240240
bool require_mutex = false;

bluefog/common/mpi_controller.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ int MPIController::SetTopology(int indegree, const int* sources, int outdegree,
257257
}
258258

259259
int MPIController::SetTopologyWeights(int indegree, const int* sources,
260-
float self_weight, const float* neighbor_weights) {
260+
double self_weight, const double* neighbor_weights) {
261261
// We assume when this function is called, the base topology has already
262262
// been set. Here the neighbor_weights specifies the weights from the sources.
263263
if (!mpi_ctx_.IsTopoSetup()) {
@@ -281,8 +281,8 @@ int MPIController::LoadTopology(int* indegree, int*& sources, int* outdegree,
281281
}
282282

283283
int MPIController::LoadTopologyWeights(
284-
float& self_weight,
285-
const std::unordered_map<int, float>*& neighbor_weights) {
284+
double& self_weight,
285+
const std::unordered_map<int, double>*& neighbor_weights) {
286286
if (!mpi_ctx_.IsWeighted()) {
287287
return 0;
288288
}
@@ -538,7 +538,7 @@ void MPIController::WinPut(TensorTableEntry& entry) {
538538
Status timeline_status = GetBluefogTimeline(timeline_ptr);
539539
for (auto kv : entry.dst_weights) {
540540
int target_rank = kv.first;
541-
float weight = kv.second;
541+
double weight = kv.second;
542542

543543
BFLOG(TRACE, rank_) << "Start MPI_Put for " << entry.tensor_name << " to " << target_rank;
544544

@@ -603,7 +603,7 @@ void MPIController::WinAccumulate(TensorTableEntry& entry) {
603603

604604
for (auto kv : entry.dst_weights) {
605605
int target_rank = kv.first;
606-
float weight = kv.second;
606+
double weight = kv.second;
607607
// avoid putting the tensor for itself (NOT valid).
608608
if (target_rank == rank_) continue;
609609

bluefog/common/mpi_controller.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ class MPIController {
6666
int SetTopology(int indegree, const int* sources, int outdegree,
6767
const int* destinations);
6868
int SetTopologyWeights(int indegree, const int* sources,
69-
float self_weight, const float* neighbor_weights);
69+
double self_weight, const double* neighbor_weights);
7070
int LoadTopology(int* indegree, int*& sources, int* outdegree,
7171
int*& destinations);
72-
int LoadTopologyWeights(float& self_weight,
73-
const std::unordered_map<int, float>*& neighbor_weights);
72+
int LoadTopologyWeights(double& self_weight,
73+
const std::unordered_map<int, double>*& neighbor_weights);
7474

7575
Status WinCreate(std::shared_ptr<Tensor> tensor,
7676
std::vector<std::shared_ptr<Tensor>> neighbor_tensors,
@@ -126,8 +126,8 @@ class MPIController {
126126
// COMM_WORLD ranks of processes running on this node.
127127
std::vector<int> local_comm_ranks_;
128128

129-
float self_weight_;
130-
std::unordered_map<int, float> neighbor_weights_;
129+
double self_weight_;
130+
std::unordered_map<int, double> neighbor_weights_;
131131
};
132132

133133
class WinMutexGuard {

bluefog/common/operations.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ int bluefog_set_topology(int indegree, const int* sources, int outdegree,
283283

284284
int bluefog_set_topology_with_weights(int indegree, const int* sources,
285285
int outdegree, const int* destinations,
286-
float self_weight, const float* neighbor_weights) {
286+
double self_weight, const double* neighbor_weights) {
287287
int ret = bluefog_set_topology(indegree, sources, outdegree, destinations);
288288
if (ret != 1) {
289289
return ret;
@@ -302,8 +302,8 @@ int bluefog_load_topology(int* indegree, int*& sources, int* outdegree,
302302
}
303303

304304
int bluefog_load_topology_weights(
305-
float& self_weight_,
306-
const std::unordered_map<int, float>*& neighbor_weights_) {
305+
double& self_weight_,
306+
const std::unordered_map<int, double>*& neighbor_weights_) {
307307
if (!bluefog_global.initialization_done) {
308308
return -1;
309309
}
@@ -434,7 +434,7 @@ Status EnqueueTensorNeighborAllreduce(std::shared_ptr<OpContext> context,
434434

435435
Status EnqueueTensorWindowPut(std::shared_ptr<Tensor> tensor,
436436
const std::string& name,
437-
const std::unordered_map<int, float>& dst_weights,
437+
const std::unordered_map<int, double>& dst_weights,
438438
const int device,
439439
const bool require_mutex,
440440
StatusCallback callback) {
@@ -456,7 +456,7 @@ Status EnqueueTensorWindowPut(std::shared_ptr<Tensor> tensor,
456456

457457
Status EnqueueTensorWindowAccumulate(
458458
std::shared_ptr<Tensor> tensor, const std::string& name,
459-
const std::unordered_map<int, float>& dst_weights, const int device,
459+
const std::unordered_map<int, double>& dst_weights, const int device,
460460
const bool require_mutex, StatusCallback callback) {
461461
TensorTableEntry e;
462462
e.tensor_name = name;
@@ -475,7 +475,7 @@ Status EnqueueTensorWindowAccumulate(
475475
}
476476

477477
Status EnqueueTensorWindowGet(const std::string& name,
478-
const std::unordered_map<int, float>& src_weights,
478+
const std::unordered_map<int, double>& src_weights,
479479
const bool require_mutex,
480480
StatusCallback callback) {
481481
TensorTableEntry e;

bluefog/common/operations.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ int bluefog_set_topology(int indegree, const int* sources,
7575
// Bluefog is not initialized or failed.
7676
int bluefog_set_topology_with_weights(int indegree, const int* sources,
7777
int outdegree, const int* destinations,
78-
float self_weight, const float* neighbor_weights);
78+
double self_weight, const double* neighbor_weights);
7979

8080
// C interface to load the virtual topology for MPI graph communicator.
8181
// Self-rank is never included no matter self-loop is presented in setup or not.
@@ -85,8 +85,8 @@ int bluefog_load_topology(int* indegree, int*& sources,
8585

8686
// Load the weights for neighbors.
8787
// TODO(ybc) Make it as C compatible interface.
88-
int bluefog_load_topology_weights(float& self_weight,
89-
const std::unordered_map<int, float>*& neighbor_weights);
88+
int bluefog_load_topology_weights(double& self_weight,
89+
const std::unordered_map<int, double>*& neighbor_weights);
9090

9191

9292
// C interface to allow python to call timeline.
@@ -124,17 +124,17 @@ Status EnqueueTensorNeighborAllreduce(std::shared_ptr<OpContext> context,
124124

125125
Status EnqueueTensorWindowPut(std::shared_ptr<Tensor> tensor,
126126
const std::string& name,
127-
const std::unordered_map<int, float>& dst_ranks,
127+
const std::unordered_map<int, double>& dst_ranks,
128128
const int device, const bool require_mutex,
129129
StatusCallback callback);
130130

131131
Status EnqueueTensorWindowAccumulate(
132132
std::shared_ptr<Tensor> tensor, const std::string& name,
133-
const std::unordered_map<int, float>& dst_ranks, const int device,
133+
const std::unordered_map<int, double>& dst_ranks, const int device,
134134
const bool require_mutex, StatusCallback callback);
135135

136136
Status EnqueueTensorWindowGet(const std::string& name,
137-
const std::unordered_map<int, float>& src_ranks,
137+
const std::unordered_map<int, double>& src_ranks,
138138
const bool require_mutex,
139139
StatusCallback callback);
140140

bluefog/common/topology_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def PowerTwoRingGraph(size: int) -> nx.DiGraph:
5353
"""Generate graph topology such that each points only
5454
connected to a point such that the index difference is power of 2.
5555
56-
Example: A PowerTwoRingGraph with 16 nodes:
56+
Example: A PowerTwoRingGraph with 12 nodes:
5757
5858
.. plot::
5959
:context: close-figs

bluefog/torch/mpi_win_ops.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ int DoWinFree(const std::string& name) {
350350
}
351351

352352
int DoWinPut(::torch::Tensor tensor, const std::string& name,
353-
const std::unordered_map<int, float>& dst_weights,
353+
const std::unordered_map<int, double>& dst_weights,
354354
const bool require_mutex) {
355355
ThrowIfError(common::CheckInitialized());
356356

@@ -386,7 +386,7 @@ int DoWinPut(::torch::Tensor tensor, const std::string& name,
386386
}
387387

388388
int DoWinAccumulate(::torch::Tensor tensor, const std::string& name,
389-
const std::unordered_map<int, float>& dst_weights,
389+
const std::unordered_map<int, double>& dst_weights,
390390
const bool require_mutex) {
391391
ThrowIfError(common::CheckInitialized());
392392

@@ -421,7 +421,7 @@ int DoWinAccumulate(::torch::Tensor tensor, const std::string& name,
421421
}
422422

423423
int DoWinGet(const std::string& name,
424-
const std::unordered_map<int, float>& src_weights,
424+
const std::unordered_map<int, double>& src_weights,
425425
const bool require_mutex) {
426426
ThrowIfError(common::CheckInitialized());
427427

bluefog/torch/mpi_win_ops.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ WIN_SYNC_H(torch_cuda_DoubleTensor, THCudaDoubleTensor)
144144
#define WIN_PUT_H(torch_Tensor, THTensor) \
145145
extern "C" int bluefog_torch_win_put_##torch_Tensor( \
146146
THTensor* tensor, char* name, \
147-
const std::unordered_map<int, float>& dst_weights, \
147+
const std::unordered_map<int, double>& dst_weights, \
148148
const bool require_mutex);
149149

150150
WIN_PUT_H(torch_IntTensor, THIntTensor)
@@ -162,7 +162,7 @@ WIN_PUT_H(torch_cuda_DoubleTensor, THCudaDoubleTensor)
162162
#define WIN_ACCUMULATE_H(torch_Tensor, THTensor) \
163163
extern "C" int bluefog_torch_win_accumulate_##torch_Tensor( \
164164
THTensor* tensor, char* name, \
165-
const std::unordered_map<int, float>& dst_weights, \
165+
const std::unordered_map<int, double>& dst_weights, \
166166
const bool require_mutex);
167167

168168
WIN_ACCUMULATE_H(torch_IntTensor, THIntTensor)
@@ -178,7 +178,7 @@ WIN_ACCUMULATE_H(torch_cuda_DoubleTensor, THCudaDoubleTensor)
178178
#endif
179179

180180
extern "C" int bluefog_torch_win_GET(
181-
char* name, const std::unordered_map<int, float>& src_weights,
181+
char* name, const std::unordered_map<int, double>& src_weights,
182182
const bool require_mutex);
183183

184184
extern "C" int bluefog_torch_win_free(char* name);

examples/pytorch_logistic_regression.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,10 @@ def logistic_loss_step(x_, tensor_name):
277277
alpha_pd = 1e-1 # step-size for Push-DIGing
278278
mse_pd = []
279279
for i in range(maxite):
280-
if i % 10 == 0:
281-
bf.barrier()
282-
283280
w[:n] = w[:n] - alpha_pd*w[n:2*n]
284281
bf.win_accumulate(
285282
w, name="w_buff",
286-
dst_weights={rank: 0.5 / (outdegree)
283+
dst_weights={rank: 1.0 / (outdegree*2)
287284
for rank in bf.out_neighbor_ranks()},
288285
require_mutex=True)
289286
w.div_(2)
@@ -296,6 +293,8 @@ def logistic_loss_step(x_, tensor_name):
296293

297294
w[n:2*n] += grad - grad_prev
298295
grad_prev = grad
296+
if i % 10 == 0:
297+
bf.barrier()
299298
if bf.rank() == 0:
300299
mse_pd.append(torch.norm(x.data - w_opt, p=2))
301300

0 commit comments

Comments
 (0)