device_normalization_fwd_splitk_impl.hpp File Reference
```cpp
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
```

Go to the source code of this file.
Namespaces | |
| namespace | ck |
| namespace | ck::tensor_operation |
| namespace | ck::tensor_operation::device |
Functions | |
| template<typename GridwiseWelford, typename XDataType, typename WorkspaceMeanVarDataType, typename ComputeDataType, typename XGridDesc_M_K, typename MeanVarGridDesc_M_KBlock> | |
| __global__ void | ck::kernel_normalizationSplitK1st (const XGridDesc_M_K x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x_global, WorkspaceMeanVarDataType *const __restrict__ p_welford_mean, WorkspaceMeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count) |
| template<typename GridwiseWelfordNormalization, typename WorkspaceMeanVarDataType, typename XDataType, typename GammaDataType, typename BetaDataType, typename YDataType, typename SaveMeanInvStdDataType, typename ComputeDataType, typename YElementwiseOperation, typename MeanVarGridDesc_M_KBlock, typename CountGridDesc_M_KBlock, typename XYGammaBetaGridDesc_M_K, typename SaveMeanInvStdGridDesc_M> | |
| __global__ void | ck::kernel_normalizationSplitK2nd (const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock, const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, ComputeDataType epsilon, const WorkspaceMeanVarDataType *const p_mean_global, const WorkspaceMeanVarDataType *const p_variance_global, const int32_t *const p_welford_count_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) |