gridwise_batchnorm_backward_blockwise_welford.hpp Source File

gridwise_batchnorm_backward_blockwise_welford.hpp Source File#

Composable Kernel: gridwise_batchnorm_backward_blockwise_welford.hpp Source File
gridwise_batchnorm_backward_blockwise_welford.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
15
16namespace ck {
17
// Kernel entry point for blockwise-Welford batchnorm backward.
// It does no work of its own: every argument is forwarded, unchanged and in the same
// order, to the static Run() of the gridwise implementation passed as the first
// template parameter.
// NOTE(review): this listing skips source line 31, which carries the
// `__global__ void kernel_batchnorm_backward_with_blockwise_welford(` declaration;
// the complete signature is shown in the symbol index at the end of this page.
18template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
19 typename XDataType,
20 typename DyDataType,
21 typename DxDataType,
22 typename AccDataType,
23 typename ScaleDataType,
24 typename DscaleDbiasDataType,
25 typename MeanVarDataType,
26 typename DyElementwiseOp,
27 typename XYGridDesc_M_K,
28 typename ScaleBiasGridDesc_M,
29 typename MeanVarGridDesc_M,
30 typename GetReduceCountPerThreadFunctor>
32 const XYGridDesc_M_K x_grid_desc_m_k,
33 const XYGridDesc_M_K dy_grid_desc_m_k,
34 const XYGridDesc_M_K dx_grid_desc_m_k,
35 const ScaleBiasGridDesc_M scale_grid_desc_m,
36 const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
37 const MeanVarGridDesc_M mean_var_grid_desc_m,
38 const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
39 long_index_t reduce_size,
40 index_t num_k_block_tile_iteration,
41 AccDataType epsilon,
42 const XDataType* const __restrict__ p_x,
43 const DyDataType* const __restrict__ p_dy,
44 const ScaleDataType* const __restrict__ p_scale,
45 bool haveSavedMeanInvVar,
46 const MeanVarDataType* const __restrict__ p_savedMean,
47 const MeanVarDataType* const __restrict__ p_savedInvVar,
48 const DyElementwiseOp dy_elementwise_op,
49 DxDataType* const __restrict__ p_dx,
50 DscaleDbiasDataType* const __restrict__ p_dscale,
51 DscaleDbiasDataType* const __restrict__ p_dbias)
52{
// Pure pass-through to the gridwise implementation; x, dy and dx each keep their own
// descriptor even though the three descriptors share one type.
53 GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
54 dy_grid_desc_m_k,
55 dx_grid_desc_m_k,
56 scale_grid_desc_m,
57 dscale_dbias_grid_desc_m,
58 mean_var_grid_desc_m,
59 get_reduce_count_per_thread,
60 reduce_size,
61 num_k_block_tile_iteration,
62 epsilon,
63 p_x,
64 p_dy,
65 p_scale,
66 haveSavedMeanInvVar,
67 p_savedMean,
68 p_savedInvVar,
69 dy_elementwise_op,
70 p_dx,
71 p_dscale,
72 p_dbias);
73};
74
// Template header and compile-time configuration of the gridwise batchnorm-backward
// implementation whose static Run() is invoked by the kernel wrapper above.
// NOTE(review): this listing skips source line 99 (between the template header and the
// opening brace) -- presumably the `struct` declaration; confirm against the real header.
// Several alias declarations below are also only partially present in the listing.
75template <typename XDataType,
76 typename DyDataType,
77 typename DxDataType,
78 typename AccDataType,
79 typename ScaleDataType,
80 typename DscaleDbiasDataType,
81 typename MeanVarDataType,
82 typename DyElementwiseOp,
83 typename XYGridDesc_M_K,
84 typename ScaleBiasGridDesc_M,
85 typename MeanVarGridDesc_M,
86 typename GetReduceCountPerThreadFunctor,
87 index_t BlockSize,
88 index_t MThreadClusterSize,
89 index_t KThreadClusterSize,
90 index_t MThreadSliceSize,
91 index_t KThreadSliceSize,
92 index_t XDyDxVectorDim,
93 index_t XSrcVectorSize,
94 index_t DySrcVectorSize,
95 index_t DxDstVectorSize,
96 index_t ScaleSrcVectorSize,
97 index_t DscaleDbiasDstVectorSize,
98 index_t MeanVarSrcVectorSize>
100{
// The per-thread slice along the vectorized dimension (M when XDyDxVectorDim == 0,
// K when XDyDxVectorDim == 1) must be a whole multiple of each of the X, Dy and Dx
// vector sizes; any other XDyDxVectorDim value fails the assertion.
101 static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
102 MThreadSliceSize % DySrcVectorSize == 0 &&
103 MThreadSliceSize % DxDstVectorSize == 0) ||
104 (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
105 KThreadSliceSize % DySrcVectorSize == 0 &&
106 KThreadSliceSize % DxDstVectorSize == 0),
107 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
108
// True when vectorization is along M. Per the symbol index at the end of this page it
// selects Sequence<1, 0> instead of Sequence<0, 1> for ThreadClusterArrangeOrder and
// ThreadBufferDimAccessOrder.
109 static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);
110
112
115
118
119 static constexpr auto thread_cluster_desc =
121
126
129
131 BlockSize,
134
136 BlockSize,
140 false>;
141
146 false>;
147
149
// Compile-time index constants used to address dimension 0 (M) and dimension 1 (K).
150 static constexpr auto I0 = Number<0>{};
151 static constexpr auto I1 = Number<1>{};
152
// Extent of one block tile in M and K: thread-cluster size times per-thread slice size.
// Run() uses M_BlockTileSize to place each block along M and K_BlockTileSize as the
// step between consecutive K tiles.
153 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
154 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
155
156 // clang-format off
157 // Blockwise BatchNorm Backward
158 // Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
159 // Output: dx, dscale, dbias
160 // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
161 // Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance)
162 // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) elementwise
163 // clang-format on
// Runs blockwise batchnorm backward for the M rows owned by the calling thread block.
//
// Parameters:
//   x_grid_desc_m_k, dy_grid_desc_m_k, dx_grid_desc_m_k
//       Descriptors of x, dy and dx. All three have the same type, so the compiler
//       cannot catch a mix-up: every transfer must receive the descriptor of the
//       tensor it actually reads or writes.
//   scale_grid_desc_m          descriptor of scale
//   dscale_dbias_grid_desc_m   descriptor shared by the dscale and dbias outputs
//   mean_var_grid_desc_m       descriptor shared by savedMean and savedInvVar
//   get_reduce_count_per_thread
//       Called with the thread's K cluster id; the result becomes the Welford max_count_.
//   reduce_size                multiplies dy in the Step 3 dx formula
//   num_k_block_tile_iteration number of K tiles visited by each pass over x / dy
//   epsilon                    added to the variance before the square root
//   p_x, p_dy, p_scale         input pointers
//   haveSavedMeanInvVar        true: load mean / inv-variance from p_savedMean and
//                              p_savedInvVar; false: recompute them from x with Welford
//   dy_elementwise_op          applied in place to every loaded dy element before use
//   p_dx, p_dscale, p_dbias    output pointers
//
// NOTE(review): this listing omits some source lines (the numeric prefixes skip), for
// example buffer declarations and loop headers. The visible lines are kept as they are.
164 __device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
165 const XYGridDesc_M_K dy_grid_desc_m_k,
166 const XYGridDesc_M_K dx_grid_desc_m_k,
167 const ScaleBiasGridDesc_M scale_grid_desc_m,
168 const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
169 const MeanVarGridDesc_M mean_var_grid_desc_m,
170 const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
171 long_index_t reduce_size,
172 index_t num_k_block_tile_iteration,
173 AccDataType epsilon,
174 const XDataType* const __restrict__ p_x,
175 const DyDataType* const __restrict__ p_dy,
176 const ScaleDataType* const __restrict__ p_scale,
177 bool haveSavedMeanInvVar,
178 const MeanVarDataType* const __restrict__ p_savedMean,
179 const MeanVarDataType* const __restrict__ p_savedInvVar,
180 const DyElementwiseOp dy_elementwise_op,
181 DxDataType* const __restrict__ p_dx,
182 DscaleDbiasDataType* const __restrict__ p_dscale,
183 DscaleDbiasDataType* const __restrict__ p_dbias)
184 {
185 using ck::math::sqrt;
186
// LDS scratch with one AccDataType slot per thread, handed to the blockwise reductions.
187 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
188
189 auto reduce_work_buf =
190 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
191
193 x_thread_buf;
194
196 dy_thread_buf;
197
199 dx_thread_buf;
200
201 // buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
203 tmp1_thread_buf;
204
206
// inv_var_thread_buf is bound to var_thread_buf: the variance is overwritten by its
// inverse square root at the end of Step 1.
210 inv_var_thread_buf = var_thread_buf;
211
214
215 const index_t thread_local_id = get_thread_local_1d_id();
216 const index_t block_global_id = get_block_1d_id();
217
// Split the flat thread id into (M, K) coordinates inside the thread cluster.
218 const auto thread_cluster_idx =
219 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
220
221 const auto thread_m_cluster_id = thread_cluster_idx[I0];
222 const auto thread_k_cluster_id = thread_cluster_idx[I1];
223
224 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
225 using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
226 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
228 constexpr auto thread_buffer_desc_m =
230
// All 2-D transfers start at this thread's slice of the first K tile:
// row = block * M_BlockTileSize + m_cluster * MThreadSliceSize, col = k_cluster * KThreadSliceSize.
231 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
232 AccDataType,
233 XYGridDesc_M_K,
234 decltype(thread_buffer_desc_m_k),
235 ThreadBufferLengths_M_K,
237 XDyDxVectorDim,
238 XSrcVectorSize,
239 1,
240 true>(
241 x_grid_desc_m_k,
242 make_multi_index(block_global_id * M_BlockTileSize +
243 thread_m_cluster_id * MThreadSliceSize,
244 thread_k_cluster_id * KThreadSliceSize));
245
// dy is loaded with its own vector width (DySrcVectorSize), the one the static_assert
// validates for dy, rather than the width configured for x.
246 auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
247 AccDataType,
248 XYGridDesc_M_K,
249 decltype(thread_buffer_desc_m_k),
250 ThreadBufferLengths_M_K,
252 XDyDxVectorDim,
253 DySrcVectorSize,
254 1,
255 true>(
256 dy_grid_desc_m_k,
257 make_multi_index(block_global_id * M_BlockTileSize +
258 thread_m_cluster_id * MThreadSliceSize,
259 thread_k_cluster_id * KThreadSliceSize));
260
261 auto threadwise_dx_store =
263 DxDataType,
264 decltype(thread_buffer_desc_m_k),
265 XYGridDesc_M_K,
267 ThreadBufferLengths_M_K,
269 XDyDxVectorDim,
270 DxDstVectorSize,
272 1,
273 true>(
274 dx_grid_desc_m_k,
275 make_multi_index(block_global_id * M_BlockTileSize +
276 thread_m_cluster_id * MThreadSliceSize,
277 thread_k_cluster_id * KThreadSliceSize),
278 PassThroughOp{});
279
280 auto threadwise_scale_load =
282 AccDataType,
283 ScaleBiasGridDesc_M,
284 decltype(thread_buffer_desc_m),
285 ThreadBufferLengths_M,
287 0,
288 ScaleSrcVectorSize,
289 1,
290 true>(
291 scale_grid_desc_m,
292 make_multi_index(block_global_id * M_BlockTileSize +
293 thread_m_cluster_id * MThreadSliceSize));
294
295 auto threadwise_dscale_dbias_store =
297 DscaleDbiasDataType,
298 decltype(thread_buffer_desc_m),
299 ScaleBiasGridDesc_M,
301 ThreadBufferLengths_M,
303 0,
304 DscaleDbiasDstVectorSize,
306 1,
307 true>(
308 dscale_dbias_grid_desc_m,
309 make_multi_index(block_global_id * M_BlockTileSize +
310 thread_m_cluster_id * MThreadSliceSize),
311 PassThroughOp{});
312
// Window steps of one block tile along K, forward and backward.
313 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
314 constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
315
316 const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
317 p_x, x_grid_desc_m_k.GetElementSpaceSize());
318
319 const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
320 p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
321
323 p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
324
325 const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
326 p_scale, scale_grid_desc_m.GetElementSpaceSize());
327
328 auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
329 p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
330
331 auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
332 p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
333
334 // clang-format off
335 // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
336 // clang-format on
337
338 if(haveSavedMeanInvVar)
339 {
340 const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
341 p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
342
343 const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
344 p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
345
// One loader serves both saved tensors because they share mean_var_grid_desc_m.
346 auto threadwise_mean_inv_var_load =
347 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
348 AccDataType,
349 MeanVarGridDesc_M,
350 decltype(thread_buffer_desc_m),
351 ThreadBufferLengths_M,
353 0,
354 MeanVarSrcVectorSize,
355 1,
356 true>(
357 mean_var_grid_desc_m,
358 make_multi_index(block_global_id * M_BlockTileSize +
359 thread_m_cluster_id * MThreadSliceSize));
360
361 threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
362 mean_global_buf,
363 thread_buffer_desc_m,
364 make_tuple(I0),
365 mean_thread_buf);
366
367 threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
368 inv_var_global_buf,
369 thread_buffer_desc_m,
370 make_tuple(I0),
371 inv_var_thread_buf);
372 }
373 else
374 {
375 auto threadwise_welford = ThreadwiseWelford();
376 threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);
377
379 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
380 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
381 });
382
// Pass over all K tiles of x, accumulating per-thread Welford statistics.
383 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
384 {
385
386 threadwise_x_load.Run(x_grid_desc_m_k,
387 x_global_buf,
388 thread_buffer_desc_m_k,
389 make_tuple(I0, I0),
390 x_thread_buf);
391
392 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
393 threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
394 }
395
397 if constexpr(I > 0)
399
// BlockwiseWelford::Run takes the count by reference, so each row gets a fresh copy.
400 int count = threadwise_welford.cur_count_;
401 BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
402 });
403
404 // calculate inv-variance as 1/sqrt(epsilon+variance)
406 inv_var_thread_buf(I) =
407 type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon);
408 });
409
// Rewind the x loader to the first K tile so Step 2 can read x again from the start.
410 threadwise_x_load.SetSrcSliceOrigin(
411 x_grid_desc_m_k,
412 make_multi_index(block_global_id * M_BlockTileSize +
413 thread_m_cluster_id * MThreadSliceSize,
414 thread_k_cluster_id * KThreadSliceSize));
415 };
416
417 // clang-format off
418 // Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance)
419 // clang-format on
420
422 dscale_thread_buf(I) = type_convert<AccDataType>(0);
423 dbias_thread_buf(I) = type_convert<AccDataType>(0);
424 });
425
426 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
427 {
428 threadwise_x_load.Run(x_grid_desc_m_k,
429 x_global_buf,
430 thread_buffer_desc_m_k,
431 make_tuple(I0, I0),
432 x_thread_buf);
433
// dy must be read through dy's own descriptor. The x/dy/dx descriptors share one type,
// so passing the dx descriptor here would compile but address dy with dx's layout.
434 threadwise_dy_load.Run(dy_grid_desc_m_k,
435 dy_global_buf,
436 thread_buffer_desc_m_k,
437 make_tuple(I0, I0),
438 dy_thread_buf);
439
442 constexpr auto offset =
443 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
444
445 dy_elementwise_op(dy_thread_buf(Number<offset>{}),
446 dy_thread_buf[Number<offset>{}]);
447
448 AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
449 inv_var_thread_buf[iM];
450
451 tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
452 });
453 });
454
// Fold this tile into the per-thread partial sums: dscale += dy * norm_x, dbias += dy.
455 ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
456 ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);
457
458 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
459 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
460 };
461
463 if constexpr(I > 0)
465 BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
467 BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
468 });
469
// Only the threads with K cluster id 0 write the reduced dscale / dbias for their rows.
470 if(thread_k_cluster_id == 0)
471 {
472 threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
473 make_tuple(I0),
474 dscale_thread_buf,
475 dscale_dbias_grid_desc_m,
476 dscale_global_buf);
477
478 threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
479 make_tuple(I0),
480 dbias_thread_buf,
481 dscale_dbias_grid_desc_m,
482 dbias_global_buf);
483 };
484
485 // clang-format off
486 // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) elementwise
487 // clang-format on
488
489 threadwise_scale_load.Run(scale_grid_desc_m,
490 scale_global_buf,
491 thread_buffer_desc_m,
492 make_tuple(I0),
493 scale_thread_buf);
494
// Step 3 walks the K tiles from last to first. After Step 2 the x / dy loaders sit one
// tile past the end, so one backward step puts them on the last tile; the dx store is
// still on the first tile and is moved forward by (num_k_block_tile_iteration - 1) tiles.
495 auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
496
497 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
498 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
499 threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
500
501 AccDataType inv_reduce_size =
503
504 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
505 {
506 threadwise_x_load.Run(x_grid_desc_m_k,
507 x_global_buf,
508 thread_buffer_desc_m_k,
509 make_tuple(I0, I0),
510 x_thread_buf);
511
512 threadwise_dy_load.Run(dy_grid_desc_m_k,
513 dy_global_buf,
514 thread_buffer_desc_m_k,
515 make_tuple(I0, I0),
516 dy_thread_buf);
517
// Per-row factor of the dx formula, computed once per row instead of once per element.
519 AccDataType multiplier =
520 inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];
521
523 constexpr auto offset =
524 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
525
526 dy_elementwise_op(dy_thread_buf(Number<offset>{}),
527 dy_thread_buf[Number<offset>{}]);
528
529 AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
530 inv_var_thread_buf[iM];
531
532 AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
533
534 dx_thread_buf(Number<offset>{}) =
535 multiplier *
536 (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
537 dbias_thread_buf[iM] - tmpVal);
538 });
539 });
540
541 threadwise_dx_store.Run(thread_buffer_desc_m_k,
542 make_tuple(I0, I0),
543 dx_thread_buf,
544 dx_grid_desc_m_k,
545 dx_global_buf);
546
547 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
548 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
549 threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
550 }
551 }
552};
553
554} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_batchnorm_backward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, long_index_t reduce_size, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:31
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:100
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:111
PartitionedBlockwiseReduction< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ck::reduce::Add, false > BlockwiseReduce
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:135
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ck::reduce::Add, false > ThreadwiseReduce
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:142
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:122
static __device__ void Run(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, long_index_t reduce_size, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:164
static constexpr index_t K_BlockTileSize
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:154
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:116
static constexpr auto thread_cluster_desc
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:119
ThreadwiseWelford< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:127
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder > BlockwiseWelford
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:130
static constexpr bool reorder_thread_cluster
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:109
static constexpr index_t M_BlockTileSize
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:153
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:148
static constexpr auto I1
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:151
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:113
static constexpr auto I0
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:150
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:124
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:173
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340