device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp Source File#
device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
Go to the documentation of this file.
418 using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm64::DefaultBlock2CTileMap>;
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(const void CK_CONSTANT_ADDRESS_SPACE *group_kernel_args, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:37
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
Definition ck/stream_config.hpp:10
Gridwise gemm + softmax + gemm fusion.
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:87
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:293
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:231
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:319
Definition block_to_ctile_map.hpp:872
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition transform_contraction_to_gemm.hpp:121
static auto MakeB0GridDescriptor_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:198
static auto MakeAGridDescriptor_G_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:154
static auto MakeB0GridDescriptor_G_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:193
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm.hpp:168
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm.hpp:248
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:159
static constexpr auto matrix_padder
Definition transform_contraction_to_gemm.hpp:139
static auto MakeCGridDescriptor_G_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:274
static auto MakeB1GridDescriptor_G_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:233
static auto MakeB1GridDescriptor_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:238
static auto MakeCGridDescriptor_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:279
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm.hpp:208
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
BaseArgument()=default
BaseInvoker()=default
Definition masking_specialization.hpp:57
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:313
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const BGridDesc_G_N_K &b_grid_desc_g_n_k, const B1GridDesc_G_N_K &b1_grid_desc_g_n_k, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:314
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:325
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:330
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:335
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:340
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:466
std::vector< GroupDeviceArg > group_device_args_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:593
AccElementwiseOperation acc_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:600
ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle::Argument::grid_size_
index_t grid_size_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:596
std::size_t group_count_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:595
CElementwiseOperation c_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:602
Argument(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:467
B1ElementwiseOperation b1_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:601
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:598
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:599
std::vector< GroupKernelArg > group_kernel_args_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:592
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:449
std::vector< index_t > b_nz_kz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:455
std::vector< index_t > a_mz_kz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:454
std::vector< index_t > c_mz_gemm1nz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:457
std::vector< index_t > raw_lengths_mz_nz_kz_gemm1nz_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:451
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:460
std::vector< index_t > b1_nz_kz_strides_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:456
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:421
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:431
index_t block_start_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:445
C0MatrixMask c0_matrix_mask_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:440
CDataType * p_c_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:426
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:429
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:430
index_t num_blocks_per_batch_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:436
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:437
const ADataType * p_a_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:423
GridwiseGemm64::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:433
const BDataType * p_b_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:424
index_t block_end_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:445
const B1DataType * p_b1_grid_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:425
Block2CTileMap block_2_ctile_map_
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:443
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:607
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:703
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:683
DeviceOp::Argument Argument
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:608
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:611
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:207
decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})) AGridDesc_AK0_M_AK1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:290
static auto MakeArgument(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:833
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:294
std::string GetTypeString() const override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:898
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:828
static constexpr index_t NumAcc1Bias
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:217
static constexpr index_t NumAcc0Bias
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:216
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:710
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:716
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:864
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:892
static constexpr auto I1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:252
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:264
static constexpr auto I0
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:251
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:310
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_K
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:296
static auto MakeInvoker()
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:860
decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})) BGridDesc_BK0_N_BK1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:291
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:297
static constexpr auto I2
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:253
OffsettedBlockToCTileMap< typename GridwiseGemm64::DefaultBlock2CTileMap > Block2CTileMap
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:418
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:415
static constexpr auto make_MaskOutPredicate()
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:299
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle DeviceOp
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:232
TransformBatchedContractionContractionToBatchedGemmGemm< Sequence< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO >, Sequence< MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock >, GemmSpec, ASpec, BSpec, B1Spec, CSpec > Transform
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:255
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:272
ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle::B1GridDesc_BK0_N_BK1
decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})) B1GridDesc_BK0_N_BK1
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:292
typename DeviceGroupedGemmSoftmaxGemmPermute< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, BDataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, MaskingSpec >::ProblemDesc ProblemDesc
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:233
static constexpr auto MXdlPerWave64
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:208
static auto MakeB1GridDescriptor_BK0_N_BK1(const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_lengths_vec, const std::vector< index_t > &b1_gs_gemm1ns_gemm1ks_strides_vec)
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:281
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:935
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:416
static constexpr auto MXdlPerWave32
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:210
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:293
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) BGridDesc_G_N_K
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:295
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle > GridwiseGemmBase
Definition device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp:354
Definition device_grouped_gemm_softmax_gemm_permute.hpp:34
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43