forked from dmlc/rabit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathallreduce_mock.h
More file actions
175 lines (172 loc) · 5.89 KB
/
allreduce_mock.h
File metadata and controls
175 lines (172 loc) · 5.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
/*!
* \file allreduce_mock.h
* \brief Mock test module of the AllReduce engine:
*  inserts failures at certain call points to test whether the engine is robust to failure
*
* \author Ignacio Cano, Tianqi Chen
*/
#ifndef RABIT_ALLREDUCE_MOCK_H
#define RABIT_ALLREDUCE_MOCK_H
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <sstream>
#include <vector>
#include "../include/rabit/engine.h"
#include "../include/rabit/timer.h"
#include "./allreduce_robust.h"
namespace rabit {
namespace engine {
/*!
 * \brief AllReduce engine that injects failures for robustness testing.
 *
 *  Before each Allreduce / Broadcast / CheckPoint / LazyCheckPoint call it
 *  looks up the key (rank, version_number, seq_counter, num_trial) among the
 *  failure points registered through the "mock" parameter; on a match it
 *  prints a message to stderr and terminates the process with exit(-2).
 *  The actual work is delegated to AllreduceRobust.
 */
class AllreduceMock : public AllreduceRobust {
 public:
  // constructor: every mock option starts disabled
  AllreduceMock(void) {
    num_trial = 0;
    force_local = 0;
    report_stats = 0;
    tsum_allreduce = 0.0;
    // CheckPoint() reads time_checkpoint to compute the time elapsed since the
    // previous checkpoint. It used to be left uninitialized, so a CheckPoint()
    // issued before any LoadCheckPoint() read an indeterminate value.
    time_checkpoint = utils::GetTime();
  }
  // destructor
  virtual ~AllreduceMock(void) {}
  /*!
   * \brief set a parameter; forwarded to AllreduceRobust::SetParam first,
   *  then the mock-specific keys are handled:
   *   - "rabit_num_trial": trial counter that is part of every mock key
   *   - "report_stats":    nonzero enables statistics printing on rank 0
   *   - "force_local":     nonzero stores both models in the local checkpoint
   *   - "mock":            "rank,version,seqno,ntrial"; registers a failure point
   * \param name name of the parameter
   * \param val  value of the parameter
   */
  virtual void SetParam(const char *name, const char *val) {
    AllreduceRobust::SetParam(name, val);
    if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
    if (!strcmp(name, "report_stats")) report_stats = atoi(val);
    if (!strcmp(name, "force_local")) force_local = atoi(val);
    if (!strcmp(name, "mock")) {
      MockKey k;
      // all four fields must parse; assumes utils::Check aborts on failure,
      // so a partially-filled key is never registered
      utils::Check(sscanf(val, "%d,%d,%d,%d",
                          &k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
                   "invalid mock parameter");
      mock_map[k] = 1;
    }
  }
  /*!
   * \brief failure check, then AllreduceRobust::Allreduce; the wall time of
   *  the underlying call is accumulated into tsum_allreduce
   */
  virtual void Allreduce(void *sendrecvbuf_,
                         size_t type_nbytes,
                         size_t count,
                         ReduceFunction reducer,
                         PreprocFunction prepare_fun,
                         void *prepare_arg) {
    this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
    double tstart = utils::GetTime();
    AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
                               count, reducer, prepare_fun, prepare_arg);
    tsum_allreduce += utils::GetTime() - tstart;
  }
  /*! \brief failure check, then AllreduceRobust::Broadcast */
  virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
    this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
    AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
  }
  /*!
   * \brief reset the timing counters, then load a checkpoint.
   *  With force_local != 0 the global slot receives a no-op serializer and
   *  both models are read, in order, from the local slot.
   * \return whatever AllreduceRobust::LoadCheckPoint returns
   */
  virtual int LoadCheckPoint(ISerializable *global_model,
                             ISerializable *local_model) {
    tsum_allreduce = 0.0;
    time_checkpoint = utils::GetTime();
    if (force_local == 0) {
      return AllreduceRobust::LoadCheckPoint(global_model, local_model);
    } else {
      DummySerializer dum;
      ComboSerializer com(global_model, local_model);
      return AllreduceRobust::LoadCheckPoint(&dum, &com);
    }
  }
  /*!
   * \brief failure check, then checkpoint the models.
   *  With force_local != 0 both models are saved into the local slot.
   *  When report_stats != 0, rank 0 sends checkpoint sizes and timings to the
   *  tracker. The accumulated allreduce time is reset afterwards.
   */
  virtual void CheckPoint(const ISerializable *global_model,
                          const ISerializable *local_model) {
    this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
    double tstart = utils::GetTime();
    double tbet_chkpt = tstart - time_checkpoint;
    if (force_local == 0) {
      AllreduceRobust::CheckPoint(global_model, local_model);
    } else {
      DummySerializer dum;
      ComboSerializer com(global_model, local_model);
      AllreduceRobust::CheckPoint(&dum, &com);
    }
    // sample the clock once so the reported checkpoint cost and the start of
    // the next between-checkpoint interval refer to the same instant
    double tend = utils::GetTime();
    double tcost = tend - tstart;
    time_checkpoint = tend;
    if (report_stats != 0 && rank == 0) {
      std::stringstream ss;
      ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
         << ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
         << ",check_tcost="<< tcost <<" sec"
         << ",allreduce_tcost=" << tsum_allreduce << " sec"
         << ",between_chpt=" << tbet_chkpt << "sec\n";
      this->TrackerPrint(ss.str());
    }
    tsum_allreduce = 0.0;
  }
  /*! \brief failure check, then AllreduceRobust::LazyCheckPoint */
  virtual void LazyCheckPoint(const ISerializable *global_model) {
    this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
    AllreduceRobust::LazyCheckPoint(global_model);
  }

 protected:
  // nonzero: route checkpoints through the local slot (see CheckPoint)
  int force_local;
  // nonzero: rank 0 reports checkpoint statistics
  int report_stats;
  // accumulated Allreduce wall time since the last (load) checkpoint
  double tsum_allreduce;
  // time stamp of the last checkpoint (or load / construction)
  double time_checkpoint;

 private:
  // serializer that reads and writes nothing
  struct DummySerializer : public ISerializable {
    virtual void Load(IStream & /*fi*/) {
    }
    virtual void Save(IStream & /*fo*/) const {
    }
  };
  // serializer that chains two models: lhs first, then rhs.
  // Built from const pointers it can only Save (Load becomes a no-op).
  struct ComboSerializer : public ISerializable {
    ISerializable *lhs;
    ISerializable *rhs;
    const ISerializable *c_lhs;
    const ISerializable *c_rhs;
    ComboSerializer(ISerializable *lhs, ISerializable *rhs)
        : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
    }
    ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
        : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
    }
    virtual void Load(IStream &fi) {
      if (lhs != NULL) lhs->Load(fi);
      if (rhs != NULL) rhs->Load(fi);
    }
    virtual void Save(IStream &fo) const {
      if (c_lhs != NULL) c_lhs->Save(fo);
      if (c_rhs != NULL) c_rhs->Save(fo);
    }
  };
  // key identifying a mock failure point; ordered lexicographically
  // by (rank, version, seqno, ntrial) so it can be used in std::map
  struct MockKey {
    int rank;
    int version;
    int seqno;
    int ntrial;
    // fields are left unset; SetParam fills all four before use
    MockKey(void) {}
    MockKey(int rank, int version, int seqno, int ntrial)
        : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
    inline bool operator==(const MockKey &b) const {
      return rank == b.rank &&
          version == b.version &&
          seqno == b.seqno &&
          ntrial == b.ntrial;
    }
    inline bool operator<(const MockKey &b) const {
      if (rank != b.rank) return rank < b.rank;
      if (version != b.version) return version < b.version;
      if (seqno != b.seqno) return seqno < b.seqno;
      return ntrial < b.ntrial;
    }
  };
  // trial counter, part of every mock key
  int num_trial;
  // registered failure points (the mapped value is unused)
  std::map<MockKey, int> mock_map;
  /*!
   * \brief terminate the process if key is a registered failure point:
   *  prints "[rank]@@@Hit Mock Error:<name>" to stderr, then exit(-2)
   * \param key  identifies the current call
   * \param name operation name used in the message
   */
  inline void Verify(const MockKey &key, const char *name) {
    if (mock_map.count(key) != 0) {
      num_trial += 1;
      fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name);
      exit(-2);
    }
  }
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H