forked from dmlc/rabit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathallreduce_robust-inl.h
More file actions
161 lines (160 loc) · 5.92 KB
/
allreduce_robust-inl.h
File metadata and controls
161 lines (160 loc) · 5.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust
*
* \author Tianqi Chen
*/
#ifndef RABIT_ENGINE_ROBUST_INL_H_
#define RABIT_ENGINE_ROBUST_INL_H_
#include <vector>
namespace rabit {
namespace engine {
/*!
 * \brief two-pass message passing over the links in tree_links
 *
 *  Every link carries exactly one EdgeType message in each direction.
 *  The exchange proceeds through four stages:
 *    0. receive one message from every non-parent link (the children)
 *    1. send one message to the parent link
 *    2. receive one message from the parent link
 *    3. send one message to every non-parent link
 *  A node without a parent goes from stage 0 directly to stage 3; a node
 *  whose only link is its parent starts at stage 1.
 *
 * \param node_value the value associated with the current node; forwarded
 *        unchanged to func
 * \param p_edge_in resized to the number of links; entry i receives the
 *        message read from link i
 * \param p_edge_out resized to the number of links; entry i holds the
 *        message computed by func that is written to link i
 * \param func the message passing rule, called as
 *        func(node_value, edge_in, out_index):
 *          - node_value same object as the node_value argument above
 *          - edge_in    all input messages, indexed by link; this includes
 *                       the slot of the output edge
 *          - out_index  index of the edge the message is produced for; the
 *                       rule should leave edge_in[out_index] out of its
 *                       computation
 *        and returning the message to send over edge out_index
 * \return kSuccess when there are no links or once every message has been
 *         sent; otherwise the value returned by ReportError for the link
 *         that failed
 *
 * \tparam NodeType type of node value
 * \tparam EdgeType type of edge message, must be a simple struct: it is
 *         handed to ReadToArray/WriteFromArray as a sizeof(EdgeType) buffer
 */
template<typename NodeType, typename EdgeType>
inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value,
                            std::vector<EdgeType> *p_edge_in,
                            std::vector<EdgeType> *p_edge_out,
                            EdgeType (*func)
                            (const NodeType &node_value,
                             const std::vector<EdgeType> &edge_in,
                             size_t out_index)) {
  RefLinkVector &peers = tree_links;
  // nothing to exchange when this node has no tree links
  if (peers.size() == 0) return kSuccess;

  const int num_peers = static_cast<int>(peers.size());
  // index of the parent link, -1 when this node has no parent
  const int parent = this->parent_index;
  const bool has_parent = (parent != -1);
  // every message on every link has this size
  const size_t msg_bytes = sizeof(EdgeType);

  // names for the stages of the exchange
  const int recv_children = 0;
  const int send_parent = 1;
  const int recv_parent = 2;
  const int send_children = 3;

  // restart the per-link read/write progress counters
  for (int i = 0; i < num_peers; ++i) {
    peers[i].ResetSize();
  }
  std::vector<EdgeType> &inbox = *p_edge_in;
  std::vector<EdgeType> &outbox = *p_edge_out;
  inbox.resize(num_peers);
  outbox.resize(num_peers);

  int stage = recv_children;
  // the parent is the only link: there is no child to wait for, so the
  // upward message can be computed right away
  if (num_peers == (has_parent ? 1 : 0)) {
    utils::Assert(parent == 0, "parent must be 0");
    outbox[parent] = func(node_value, inbox, parent);
    stage = send_parent;
  }

  for (;;) {
    // a node without parent must never be in one of the parent stages
    if (!has_parent) {
      utils::Assert(!(stage == send_parent || stage == recv_parent),
                    "invalie stage id");
    }
    // register the sockets that matter for the current stage
    utils::SelectHelper watcher;
    bool all_sent = (stage == send_children);
    for (int i = 0; i < num_peers; ++i) {
      watcher.WatchException(peers[i].sock);
      const bool is_parent_link = (i == parent);
      if (stage == recv_children) {
        if (!is_parent_link && peers[i].size_read != msg_bytes) {
          watcher.WatchRead(peers[i].sock);
        }
      } else if (stage == send_parent) {
        if (is_parent_link) watcher.WatchWrite(peers[i].sock);
      } else if (stage == recv_parent) {
        if (is_parent_link) watcher.WatchRead(peers[i].sock);
      } else if (stage == send_children) {
        if (!is_parent_link && peers[i].size_write != msg_bytes) {
          watcher.WatchWrite(peers[i].sock);
          // at least one child still has unsent bytes
          all_sent = false;
        }
      } else {
        utils::Error("invalid stage");
      }
    }
    // last stage reached and nothing left to write: the exchange is complete
    if (all_sent) break;
    watcher.Select();

    // an exceptional condition on any link aborts the whole exchange
    for (int i = 0; i < num_peers; ++i) {
      if (watcher.CheckExcept(peers[i].sock)) {
        return ReportError(&peers[i], kGetExcept);
      }
    }

    // The stage checks below are deliberately independent ifs: when one
    // stage completes, the next one is attempted within the same pass.
    if (stage == recv_children) {
      bool all_received = true;
      for (int i = 0; i < num_peers; ++i) {
        if (i == parent) continue;
        if (watcher.CheckRead(peers[i].sock)) {
          ReturnType status = peers[i].ReadToArray(&inbox[i], msg_bytes);
          if (status != kSuccess) return ReportError(&peers[i], status);
        }
        if (peers[i].size_read != msg_bytes) all_received = false;
      }
      if (all_received) {
        if (has_parent) {
          // children are complete: compute the message for the parent
          outbox[parent] = func(node_value, inbox, parent);
          stage = send_parent;
        } else {
          // no parent: compute the message for every link and go
          // straight to the final stage
          for (int i = 0; i < num_peers; ++i) {
            outbox[i] = func(node_value, inbox, i);
          }
          stage = send_children;
        }
      }
    }
    if (stage == send_parent) {
      utils::Assert(parent != -1, "MsgPassing invalid stage");
      // the write is attempted on every pass, the watcher is not consulted
      ReturnType status = peers[parent].WriteFromArray(&outbox[parent], msg_bytes);
      if (status != kSuccess) return ReportError(&peers[parent], status);
      if (peers[parent].size_write == msg_bytes) stage = recv_parent;
    }
    if (stage == recv_parent) {
      utils::Assert(parent != -1, "MsgPassing invalid stage");
      // the read is attempted on every pass, the watcher is not consulted
      ReturnType status = peers[parent].ReadToArray(&inbox[parent], msg_bytes);
      if (status != kSuccess) return ReportError(&peers[parent], status);
      if (peers[parent].size_read == msg_bytes) {
        // the parent message is complete: compute the message for each child
        for (int i = 0; i < num_peers; ++i) {
          if (i != parent) outbox[i] = func(node_value, inbox, i);
        }
        stage = send_children;
      }
    }
    if (stage == send_children) {
      for (int i = 0; i < num_peers; ++i) {
        if (i == parent) continue;
        if (peers[i].size_write == msg_bytes) continue;
        ReturnType status = peers[i].WriteFromArray(&outbox[i], msg_bytes);
        if (status != kSuccess) return ReportError(&peers[i], status);
      }
    }
  }
  return kSuccess;
}
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_ROBUST_INL_H_