gridwise_normalization_bwd_data.hpp Source File

gridwise_normalization_bwd_data.hpp Source File#

Composable Kernel: gridwise_normalization_bwd_data.hpp Source File
gridwise_normalization_bwd_data.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13// Tensor Shape
14// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
15
16// Flow:
17// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
18// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
19// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
20// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
21// c = -b * x_mean - db * inv_std / reduce_size
22// dx = inv_std * dy * gamma + b * x + c
23// return dx
24
// Grid-level routine computing dx for normalization backward, following the "Flow"
// pseudo-code above: sums ds and db are accumulated per M index and reduced across the
// block, then dx is evaluated element-wise from them.
//
// Template parameters:
//   *DataType              - element types of the global tensors; ComputeDataType is the
//                            element type of all per-thread buffers and of the arithmetic.
//   GridDesc_M_K           - descriptor type shared by the six tensors passed to Run().
//   BlockSize              - element count of the shared reduction work buffer in Run().
//   M/KThreadClusterSize,
//   M/KThreadSliceSize     - their products give M_BlockTileSize and K_BlockTileSize.
//   *VectorDim/*VectorSize - vector access settings passed to the threadwise transfers;
//                            constrained by the static_asserts below.
//   SweepOnce              - selects the single-pass branch of Run() instead of the
//                            two-pass (looping) branch.
//
// NOTE(review): this listing is missing a number of original lines (the struct declaration
// line, compile-time loop headers over I / iM / iK, some type names). The comments below
// describe only what is visible; confirm details against the original header.
25template <typename DYDataType,
26 typename XDataType,
27 typename GammaDataType,
28 typename MeanInvStdDataType,
29 typename ComputeDataType,
30 typename DXDataType,
31 typename GridDesc_M_K,
32 index_t BlockSize,
33 index_t MThreadClusterSize,
34 index_t KThreadClusterSize,
35 index_t MThreadSliceSize,
36 index_t KThreadSliceSize,
37 index_t DYSrcVectorDim,
38 index_t DYSrcVectorSize,
39 index_t XSrcVectorDim,
40 index_t XSrcVectorSize,
41 index_t GammaSrcVectorDim,
42 index_t GammaSrcVectorSize,
43 index_t MeanInvStdSrcVectorDim,
44 index_t MeanInvStdSrcVectorSize,
45 index_t DXDstVectorDim,
46 index_t DXDstVectorSize,
47 bool SweepOnce>
49{
 // Each tensor must be vectorized either along dimension 0 with MThreadSliceSize equal to
 // its vector size, or along dimension 1 with KThreadSliceSize equal to its vector size.
50 // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
51 static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
52 (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
53 "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
54
55 static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
56 (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
57 "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
58
59 static_assert(
60 ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
61 (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
62 "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
63
64 static_assert(
65 ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
66 (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
67 "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
68
69 static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
70 (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
71 "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
72
 // NOTE(review): the definitions in this region are incomplete in this listing. Names used
 // by Run() below (thread_buffer_desc_m_k, BlockwiseSumReduce, PassThroughOp) are not
 // visibly defined here - confirm against the original header.
74
85
87 static constexpr auto thread_cluster_desc =
89
91
94
95 static constexpr auto thread_buffer_desc_m =
97
99
101 BlockSize,
105 true>;
106
 // Compile-time index constants.
107 static constexpr auto I0 = Number<0>{};
108 static constexpr auto I1 = Number<1>{};
109 static constexpr auto I2 = Number<2>{};
110
 // Block tile extents: thread-cluster size times per-thread slice size in each dimension.
111 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
112 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
113
 // Computes dx for the M tile handled by this block and stores it through p_dx_global.
 //
 // Parameters:
 //   *_grid_desc_m_k            - descriptors of the corresponding global tensors; used for
 //                                the buffer sizes and by the threadwise loads/stores.
 //   num_k_block_tile_iteration - number of K_BlockTileSize-wide tiles visited by each loop
 //                                of the two-pass branch; not read when SweepOnce is true.
 //   p_*_global                 - pointers wrapped as global-address-space buffers: dy, x,
 //                                gamma, mean, inv_std are read, dx is written.
 //
 // Every load/store window of a thread starts at
 //   (block_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
 //    thread_k_cluster_id * KThreadSliceSize).
114 __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
115 const GridDesc_M_K& x_grid_desc_m_k,
116 const GridDesc_M_K& gamma_grid_desc_m_k,
117 const GridDesc_M_K& mean_grid_desc_m_k,
118 const GridDesc_M_K& inv_std_grid_desc_m_k,
119 const GridDesc_M_K& dx_grid_desc_m_k,
120 index_t num_k_block_tile_iteration,
121 const DYDataType* const __restrict__ p_dy_global,
122 const XDataType* const __restrict__ p_x_global,
123 const GammaDataType* const __restrict__ p_gamma_global,
124 const MeanInvStdDataType* const __restrict__ p_mean_global,
125 const MeanInvStdDataType* const __restrict__ p_inv_std_global,
126 DXDataType* const __restrict__ p_dx_global)
127 {
 // Shared work buffer (BlockSize elements of ComputeDataType) handed to
 // BlockwiseSumReduce::Reduce below.
128 // LDS
129 __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
130
131 auto reduce_work_buf =
132 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
133
 // Wrap each raw pointer in a buffer sized by its descriptor's element space size.
134 // Global
135 const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
136 p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
137
138 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
139 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
140
141 auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
142 p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
143
144 const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
145 p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
146
147 const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
148 p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
149
150 auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
151 p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());
152
 // Per-thread tiles of MThreadSliceSize * KThreadSliceSize ComputeDataType elements; note
 // that mean and inv_std use the same full tile size as dy/x/gamma.
153 // VGPR
154 auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
155 ComputeDataType,
156 MThreadSliceSize * KThreadSliceSize,
157 true>{};
158
159 auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
160 ComputeDataType,
161 MThreadSliceSize * KThreadSliceSize,
162 true>{};
163
164 auto gamma_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
165 ComputeDataType,
166 MThreadSliceSize * KThreadSliceSize,
167 true>{};
168
169 auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
170 ComputeDataType,
171 MThreadSliceSize * KThreadSliceSize,
172 true>{};
173
174 auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
175 ComputeDataType,
176 MThreadSliceSize * KThreadSliceSize,
177 true>{};
178
179 auto dx_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
180 ComputeDataType,
181 MThreadSliceSize * KThreadSliceSize,
182 true>{};
183
 // Accumulators for the ds and db sums, addressed below through thread_buffer_desc_m
 // offsets. NOTE(review): their initializers are cut off in this listing.
184 auto ds_thread_buf =
186
187 auto db_thread_buf =
189
190 // thread id
191 const index_t thread_local_id = get_thread_local_1d_id();
192 const index_t block_global_id = get_block_1d_id();
193
 // Split the 1-D thread id into (M, K) cluster coordinates.
194 const auto thread_cluster_idx =
195 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
196
197 const auto thread_m_cluster_id = thread_cluster_idx[I0];
198 const auto thread_k_cluster_id = thread_cluster_idx[I1];
199
 // All five loaders and the dx store are constructed with the same starting index
 // (see the function comment).
200 // IO
201 auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
202 ComputeDataType,
203 GridDesc_M_K,
204 decltype(thread_buffer_desc_m_k),
207 DYSrcVectorDim,
208 DYSrcVectorSize,
209 1,
210 false>(
211 dy_grid_desc_m_k,
212 make_multi_index(block_global_id * M_BlockTileSize +
213 thread_m_cluster_id * MThreadSliceSize,
214 thread_k_cluster_id * KThreadSliceSize));
215
216 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
217 ComputeDataType,
218 GridDesc_M_K,
219 decltype(thread_buffer_desc_m_k),
222 XSrcVectorDim,
223 XSrcVectorSize,
224 1,
225 false>(
226 x_grid_desc_m_k,
227 make_multi_index(block_global_id * M_BlockTileSize +
228 thread_m_cluster_id * MThreadSliceSize,
229 thread_k_cluster_id * KThreadSliceSize));
230
 // NOTE(review): the transfer class name line of this declaration is missing in this listing.
231 auto threadwise_gamma_load =
233 ComputeDataType,
234 GridDesc_M_K,
235 decltype(thread_buffer_desc_m_k),
238 GammaSrcVectorDim,
239 GammaSrcVectorSize,
240 1,
241 false>(
242 gamma_grid_desc_m_k,
243 make_multi_index(block_global_id * M_BlockTileSize +
244 thread_m_cluster_id * MThreadSliceSize,
245 thread_k_cluster_id * KThreadSliceSize));
246
247 auto threadwise_mean_load =
248 ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
249 ComputeDataType,
250 GridDesc_M_K,
251 decltype(thread_buffer_desc_m_k),
254 MeanInvStdSrcVectorDim,
255 MeanInvStdSrcVectorSize,
256 1,
257 false>(
258 mean_grid_desc_m_k,
259 make_multi_index(block_global_id * M_BlockTileSize +
260 thread_m_cluster_id * MThreadSliceSize,
261 thread_k_cluster_id * KThreadSliceSize));
262
263 auto threadwise_inv_std_load =
264 ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
265 ComputeDataType,
266 GridDesc_M_K,
267 decltype(thread_buffer_desc_m_k),
270 MeanInvStdSrcVectorDim,
271 MeanInvStdSrcVectorSize,
272 1,
273 false>(
274 inv_std_grid_desc_m_k,
275 make_multi_index(block_global_id * M_BlockTileSize +
276 thread_m_cluster_id * MThreadSliceSize,
277 thread_k_cluster_id * KThreadSliceSize));
278
 // NOTE(review): the transfer class name line of this declaration is missing in this listing.
279 auto threadwise_dx_store =
281 DXDataType,
282 decltype(thread_buffer_desc_m_k),
283 GridDesc_M_K,
287 DXDstVectorDim,
288 DXDstVectorSize,
290 1,
291 false>(
292 dx_grid_desc_m_k,
293 make_multi_index(block_global_id * M_BlockTileSize +
294 thread_m_cluster_id * MThreadSliceSize,
295 thread_k_cluster_id * KThreadSliceSize),
296 PassThroughOp{});
297
 // Divisor used in the b and c terms below. NOTE(review): it is read from the upper
 // length of transform I2 of the dy descriptor - presumably the K (reduced) length;
 // confirm against how the descriptor is built.
298 ComputeDataType reduce_size = type_convert<ComputeDataType>(
299 dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
300
 // Zero the ds/db accumulators before any accumulation.
302 ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
303 db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
304 });
305
306 // Separate sweep once and sweep twice pipeline
307 // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
308 // we don't need to use loop to read x, dy, gamma twice
309 if constexpr(SweepOnce)
310 {
 // Single pass: load every operand once, reduce, then compute and store dx.
311 threadwise_dy_load.Run(dy_grid_desc_m_k,
312 dy_global_val_buf,
314 make_tuple(I0, I0),
315 dy_thread_buf);
316
317 threadwise_x_load.Run(x_grid_desc_m_k,
318 x_global_val_buf,
320 make_tuple(I0, I0),
321 x_thread_buf);
322
323 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
324 gamma_global_val_buf,
326 make_tuple(I0, I0),
327 gamma_thread_buf);
328
329 threadwise_mean_load.Run(mean_grid_desc_m_k,
330 mean_global_val_buf,
332 make_tuple(I0, I0),
333 mean_thread_buf);
334
335 threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
336 inv_std_global_val_buf,
338 make_tuple(I0, I0),
339 inv_std_thread_buf);
340
 // Per-thread partial sums: ds += dy * gamma * x, db += dy * gamma.
342 constexpr auto offset_m =
343 Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
344
346 constexpr auto offset_m_k =
347 Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
348
349 ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
350 gamma_thread_buf[offset_m_k] *
351 x_thread_buf[offset_m_k];
352
353 db_thread_buf(offset_m) +=
354 dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
355 });
356 });
357
 // Reduce ds and db through the shared work buffer. NOTE(review): the statement guarded
 // by `if constexpr(I > 0)` is missing in this listing - presumably a barrier protecting
 // reuse of reduce_work_buf between reductions; confirm.
359 if constexpr(I > 0)
361
362 BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
364 BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
365 });
366
 // Element-wise dx from the reduced ds/db, using the formulas quoted below.
368 constexpr auto offset_m =
369 Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
370
372 constexpr auto offset_m_k =
373 Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
374
375 // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
376 // c = -b * x_mean - db * rstd / reduce_size
377 // dx = rstd * dy * gamma + b * x + c
378
379 ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
380 ds_thread_buf[offset_m];
381
382 b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
383 inv_std_thread_buf[offset_m_k] / reduce_size;
384
385 ComputeDataType c = -b * mean_thread_buf(offset_m_k);
386
387 c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
388
389 dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
390 gamma_thread_buf[offset_m_k] *
391 inv_std_thread_buf[offset_m_k] +
392 b * x_thread_buf[offset_m_k] + c;
393 });
394 });
395
 // Store the computed dx tile.
396 threadwise_dx_store.Run(thread_buffer_desc_m_k,
397 make_tuple(I0, I0),
398 dx_thread_buf,
399 dx_grid_desc_m_k,
400 dx_global_val_buf);
401
402 } // end of sweep once
403 else // Sweep Twice pipeline
404 {
 // Window step of one block tile along K (no movement along M).
405 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
406
 // First sweep: walk forward over the K tiles accumulating ds and db; only dy, x and
 // gamma are loaded here.
407 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
408 {
409 threadwise_dy_load.Run(dy_grid_desc_m_k,
410 dy_global_val_buf,
412 make_tuple(I0, I0),
413 dy_thread_buf);
414
415 threadwise_x_load.Run(x_grid_desc_m_k,
416 x_global_val_buf,
418 make_tuple(I0, I0),
419 x_thread_buf);
420
421 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
422 gamma_global_val_buf,
424 make_tuple(I0, I0),
425 gamma_thread_buf);
426
 // Advance the dy/x/gamma windows to the next K tile.
427 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
428 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
429 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
430 thread_copy_fwd_step_m_k);
431
433 constexpr auto offset_m =
434 Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
435
437 constexpr auto offset_m_k =
438 Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
439
440 ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
441 gamma_thread_buf[offset_m_k] *
442 x_thread_buf[offset_m_k];
443
444 db_thread_buf(offset_m) +=
445 dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
446 });
447 });
448 } // end of first sweep
449
 // Reduce ds and db through the shared work buffer. NOTE(review): the statement guarded
 // by `if constexpr(I > 0)` is missing in this listing - presumably a barrier protecting
 // reuse of reduce_work_buf between reductions; confirm.
451 if constexpr(I > 0)
453
454 BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
456 BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
457 });
458
 // thread_copy_tail_m_k is the offset from the first K tile to the last one.
459 // reverse read for using dy, gamma and x in the cache
460 constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
461 auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
462
 // After the first sweep the dy/x/gamma windows sit one tile past the last tile, so a
 // single backward step places them on the last tile.
463 // move to tail
464 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
465 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
466 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
467
 // The mean/inv_std/dx windows have not moved yet, so they jump from the first tile
 // directly to the last one.
468 // move from start to tail
469 threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
470 threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
471 threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
472
 // Second sweep: walk backward over the K tiles, computing and storing dx per tile.
473 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
474 {
475 threadwise_dy_load.Run(dy_grid_desc_m_k,
476 dy_global_val_buf,
478 make_tuple(I0, I0),
479 dy_thread_buf);
480
481 threadwise_x_load.Run(x_grid_desc_m_k,
482 x_global_val_buf,
484 make_tuple(I0, I0),
485 x_thread_buf);
486
487 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
488 gamma_global_val_buf,
490 make_tuple(I0, I0),
491 gamma_thread_buf);
492
493 threadwise_mean_load.Run(mean_grid_desc_m_k,
494 mean_global_val_buf,
496 make_tuple(I0, I0),
497 mean_thread_buf);
498
499 threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
500 inv_std_global_val_buf,
502 make_tuple(I0, I0),
503 inv_std_thread_buf);
504
 // Same element-wise dx evaluation as in the sweep-once branch.
506 constexpr auto offset_m =
507 Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
508
510 constexpr auto offset_m_k =
511 Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
512
513 // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
514 // c = -b * x_mean - db * rstd / reduce_size
515 // dx = rstd * dy * gamma + b * x + c
516
517 ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
518 ds_thread_buf[offset_m];
519
520 b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
521 inv_std_thread_buf[offset_m_k] / reduce_size;
522
523 ComputeDataType c = -b * mean_thread_buf(offset_m_k);
524
525 c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
526
527 dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
528 gamma_thread_buf[offset_m_k] *
529 inv_std_thread_buf[offset_m_k] +
530 b * x_thread_buf[offset_m_k] + c;
531 });
532 });
533
534 threadwise_dx_store.Run(thread_buffer_desc_m_k,
535 make_tuple(I0, I0),
536 dx_thread_buf,
537 dx_grid_desc_m_k,
538 dx_global_val_buf);
539
 // Step every window back by one K tile. This also runs after the final iteration,
 // but the windows are not used again afterwards.
540 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
541 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
542 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
543 thread_copy_bwd_step_m_k);
544 threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
545 thread_copy_bwd_step_m_k);
546 threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
547 thread_copy_bwd_step_m_k);
548 threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
549 }
550 }
551 }
552};
553
554} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_normalization_bwd_data.hpp:49
static __device__ void Run(const GridDesc_M_K &dy_grid_desc_m_k, const GridDesc_M_K &x_grid_desc_m_k, const GridDesc_M_K &gamma_grid_desc_m_k, const GridDesc_M_K &mean_grid_desc_m_k, const GridDesc_M_K &inv_std_grid_desc_m_k, const GridDesc_M_K &dx_grid_desc_m_k, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DXDataType *const __restrict__ p_dx_global)
Definition gridwise_normalization_bwd_data.hpp:114
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, ComputeDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340