@@ -139,8 +139,7 @@ void* allocate_device_mem(const size_t num_bytes, L0Device& device) {
139
139
return mem;
140
140
}
141
141
142
- L0DataFetcher::L0DataFetcher (const L0Driver& driver, ze_device_handle_t device)
143
- : device_(device), driver_(driver) {
142
+ L0Device::L0DataFetcher::L0DataFetcher (L0Device& device) : my_device_(device) {
144
143
ze_command_queue_desc_t command_queue_fetch_desc = {
145
144
ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
146
145
nullptr ,
@@ -149,77 +148,87 @@ L0DataFetcher::L0DataFetcher(const L0Driver& driver, ze_device_handle_t device)
149
148
0 ,
150
149
ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
151
150
ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
152
- L0_SAFE_CALL (zeCommandQueueCreate (
153
- driver.ctx (), device_, &command_queue_fetch_desc, &queue_handle_));
154
- current_cl_bytes = {{}, 0 };
155
- L0_SAFE_CALL (
156
- zeCommandListCreate (driver.ctx (), device_, &cl_desc, ¤t_cl_bytes.first ));
157
- }
158
-
159
- L0DataFetcher::~L0DataFetcher () {
151
+ L0_SAFE_CALL (zeCommandQueueCreate (my_device_.driver_ .ctx (),
152
+ my_device_.device_ ,
153
+ &command_queue_fetch_desc,
154
+ &queue_handle_));
155
+ cur_cl_bytes_ = {{}, 0 };
156
+ L0_SAFE_CALL (zeCommandListCreate (my_device_.driver_ .ctx (),
157
+ my_device_.device_ ,
158
+ &cl_desc_,
159
+ &cur_cl_bytes_.cl_handle_ ));
160
+ }
161
+
162
+ L0Device::L0DataFetcher::~L0DataFetcher () {
160
163
zeCommandQueueDestroy (queue_handle_);
161
- zeCommandListDestroy (current_cl_bytes. first );
162
- for (auto & dead_handle : graveyard ) {
164
+ zeCommandListDestroy (cur_cl_bytes_. cl_handle_ );
165
+ for (auto & dead_handle : graveyard_ ) {
163
166
zeCommandListDestroy (dead_handle);
164
167
}
165
- for (auto & cl_handle : recycled ) {
168
+ for (auto & cl_handle : recycled_ ) {
166
169
zeCommandListDestroy (cl_handle);
167
170
}
168
171
}
169
172
170
- void L0DataFetcher::recycleGraveyard () {
171
- while (recycled .size () < GRAVEYARD_LIMIT && graveyard .size ()) {
172
- recycled .push_back (graveyard .front ());
173
- graveyard .pop_front ();
174
- L0_SAFE_CALL (zeCommandListReset (recycled .back ()));
173
+ void L0Device:: L0DataFetcher::recycleGraveyard () {
174
+ while (recycled_ .size () < GRAVEYARD_LIMIT && graveyard_ .size ()) {
175
+ recycled_ .push_back (graveyard_ .front ());
176
+ graveyard_ .pop_front ();
177
+ L0_SAFE_CALL (zeCommandListReset (recycled_ .back ()));
175
178
}
176
- for (auto & dead_handle : graveyard) {
177
- L0_SAFE_CALL (zeCommandListDestroy (recycled.back ()));
179
+ for (auto & dead_handle : graveyard_) {
180
+ L0_SAFE_CALL (zeCommandListDestroy (dead_handle));
181
+ }
182
+ graveyard_.clear ();
183
+ }
184
+
185
+ void L0Device::L0DataFetcher::setCLRecycledOrNew () {
186
+ cur_cl_bytes_ = {{}, 0 };
187
+ if (recycled_.size ()) {
188
+ cur_cl_bytes_.cl_handle_ = recycled_.front ();
189
+ recycled_.pop_front ();
190
+ } else {
191
+ L0_SAFE_CALL (zeCommandListCreate (my_device_.driver_ .ctx (),
192
+ my_device_.device_ ,
193
+ &cl_desc_,
194
+ &cur_cl_bytes_.cl_handle_ ));
178
195
}
179
- graveyard.clear ();
180
196
}
181
197
182
- void L0DataFetcher::appendCopyCommand (void * dst,
183
- const void * src,
184
- const size_t num_bytes) {
185
- std::unique_lock<std::mutex> cl_lock (current_cl_lock );
198
+ void L0Device:: L0DataFetcher::appendCopyCommand (void * dst,
199
+ const void * src,
200
+ const size_t num_bytes) {
201
+ std::unique_lock<std::mutex> cl_lock (cur_cl_lock_ );
186
202
L0_SAFE_CALL (zeCommandListAppendMemoryCopy (
187
- current_cl_bytes.first , dst, src, num_bytes, nullptr , 0 , nullptr ));
188
- current_cl_bytes.second += num_bytes;
189
- if (current_cl_bytes.second >= 128 * 1024 * 1024 ) {
190
- ze_command_list_handle_t cl_h_copy = current_cl_bytes.first ;
191
- graveyard.push_back (current_cl_bytes.first );
192
- current_cl_bytes = {{}, 0 };
193
- if (recycled.size ()) {
194
- current_cl_bytes.first = recycled.front ();
195
- recycled.pop_front ();
196
- } else {
197
- L0_SAFE_CALL (
198
- zeCommandListCreate (driver_.ctx (), device_, &cl_desc, ¤t_cl_bytes.first ));
199
- }
203
+ cur_cl_bytes_.cl_handle_ , dst, src, num_bytes, nullptr , 0 , nullptr ));
204
+ cur_cl_bytes_.bytes_ += num_bytes;
205
+ if (cur_cl_bytes_.bytes_ >= CL_BYTES_LIMIT) {
206
+ ze_command_list_handle_t cl_h_copy = cur_cl_bytes_.cl_handle_ ;
207
+ graveyard_.push_back (cur_cl_bytes_.cl_handle_ );
208
+ setCLRecycledOrNew ();
200
209
cl_lock.unlock ();
201
210
L0_SAFE_CALL (zeCommandListClose (cl_h_copy));
202
211
L0_SAFE_CALL (
203
212
zeCommandQueueExecuteCommandLists (queue_handle_, 1 , &cl_h_copy, nullptr ));
204
213
}
205
214
}
206
215
207
- void L0DataFetcher::sync () {
208
- if (current_cl_bytes. second ) {
209
- L0_SAFE_CALL (zeCommandListClose (current_cl_bytes. first ));
216
+ void L0Device:: L0DataFetcher::sync () {
217
+ if (cur_cl_bytes_. bytes_ ) {
218
+ L0_SAFE_CALL (zeCommandListClose (cur_cl_bytes_. cl_handle_ ));
210
219
L0_SAFE_CALL (zeCommandQueueExecuteCommandLists (
211
- queue_handle_, 1 , ¤t_cl_bytes. first , nullptr ));
220
+ queue_handle_, 1 , &cur_cl_bytes_. cl_handle_ , nullptr ));
212
221
}
213
222
L0_SAFE_CALL (
214
223
zeCommandQueueSynchronize (queue_handle_, std::numeric_limits<uint32_t >::max ()));
215
- L0_SAFE_CALL (zeCommandListReset (current_cl_bytes. first ));
216
- if (graveyard .size () > GRAVEYARD_LIMIT) {
224
+ L0_SAFE_CALL (zeCommandListReset (cur_cl_bytes_. cl_handle_ ));
225
+ if (graveyard_ .size () > GRAVEYARD_LIMIT) {
217
226
recycleGraveyard ();
218
227
}
219
228
}
220
229
221
230
L0Device::L0Device (const L0Driver& driver, ze_device_handle_t device)
222
- : device_(device), driver_(driver), data_fetcher(driver, device ) {
231
+ : device_(device), driver_(driver), data_fetcher_(* this ) {
223
232
ze_command_queue_handle_t queue_handle;
224
233
ze_command_queue_desc_t command_queue_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
225
234
nullptr ,
@@ -271,6 +280,14 @@ unsigned L0Device::maxSharedLocalMemory() const {
271
280
return compute_props_.maxSharedLocalMemory ;
272
281
}
273
282
283
+ void L0Device::transferToDevice (void * dst, const void * src, const size_t num_bytes) {
284
+ data_fetcher_.appendCopyCommand (dst, src, num_bytes);
285
+ }
286
+
287
+ void L0Device::syncDataTransfers () {
288
+ data_fetcher_.sync ();
289
+ }
290
+
274
291
L0CommandQueue::L0CommandQueue (ze_command_queue_handle_t handle) : handle_(handle) {}
275
292
276
293
ze_command_queue_handle_t L0CommandQueue::handle () const {
@@ -420,7 +437,7 @@ void L0Manager::copyHostToDeviceAsync(int8_t* device_ptr,
420
437
CHECK_LT (device_num, drivers_[0 ]->devices ().size ());
421
438
422
439
auto & device = drivers ()[0 ]->devices ()[device_num];
423
- device->data_fetcher . appendCopyCommand (device_ptr, host_ptr, num_bytes);
440
+ device->transferToDevice (device_ptr, host_ptr, num_bytes);
424
441
}
425
442
426
443
void L0Manager::copyHostToDeviceAsyncIfPossible (int8_t * device_ptr,
@@ -438,7 +455,7 @@ void L0Manager::synchronizeDeviceDataStream(const int device_num) {
438
455
CHECK_GE (device_num, 0 );
439
456
CHECK_LT (device_num, drivers_[0 ]->devices ().size ());
440
457
auto & device = drivers ()[0 ]->devices ()[device_num];
441
- device->data_fetcher . sync ();
458
+ device->syncDataTransfers ();
442
459
}
443
460
444
461
void L0Manager::copyDeviceToHost (int8_t * host_ptr,
0 commit comments