22 template <
typename ADataType,
29 typename AElementwiseOperation,
30 typename BElementwiseOperation,
31 typename CElementwiseOperation,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
56 bool BBlockLdsAddExtraN,
67 AElementwiseOperation,
68 BElementwiseOperation,
69 CElementwiseOperation>
80 AElementwiseOperation,
81 BElementwiseOperation,
82 CElementwiseOperation,
93 ABlockTransferThreadClusterLengths_K0_M_K1,
94 ABlockTransferThreadClusterArrangeOrder,
95 ABlockTransferSrcAccessOrder,
96 ABlockTransferSrcVectorDim,
97 ABlockTransferSrcScalarPerVector,
98 ABlockTransferDstScalarPerVector_K1,
101 BBlockTransferThreadClusterLengths_K0_N_K1,
102 BBlockTransferThreadClusterArrangeOrder,
103 BBlockTransferSrcAccessOrder,
104 BBlockTransferSrcVectorDim,
105 BBlockTransferSrcScalarPerVector,
106 BBlockTransferDstScalarPerVector_K1,
110 CThreadTransferSrcDstVectorDim,
111 CThreadTransferDstScalarPerVector,
122 if(stream_config.log_level_ > 0)
129 throw std::runtime_error(
130 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
142 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
149 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
159 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
185 const BDataType* p_b,
193 AElementwiseOperation,
194 BElementwiseOperation,
195 CElementwiseOperation)
197 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
212 AElementwiseOperation,
213 BElementwiseOperation,
214 CElementwiseOperation)
override
216 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
217 static_cast<const BDataType*
>(p_b),
218 static_cast<CDataType*
>(p_c),
230 return std::make_unique<Invoker>(Invoker{});
236 auto str = std::stringstream();
238 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
242 str <<
"DeviceGemmDpp"
252 << MDppPerWave <<
", "
253 << NDppPerWave <<
", "
254 << ABlockTransferSrcScalarPerVector <<
", "
255 << ABlockTransferDstScalarPerVector_K1 <<
", "
256 << BBlockTransferSrcScalarPerVector <<
", "
257 << BBlockTransferDstScalarPerVector_K1
260 << NumPrefetch <<
", "
261 <<
"PipelineVersion: "
262 << PipelineVersionToString[PipelineVer];
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__global__ void kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_dpp.hpp:29
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dpp.hpp:96
ck::GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, KPerBlock, MPerDpp, NPerDpp, AK1, BK1, MDppPerWave, NDppPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 1, 3, 5 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, PipelineVer >::CalculateHasMainKBlockLoop static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_dpp.hpp:349
ck::GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, KPerBlock, MPerDpp, NPerDpp, AK1, BK1, MDppPerWave, NDppPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 1, 3, 5 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, PipelineVer >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dpp.hpp:115
ck::GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, KPerBlock, MPerDpp, NPerDpp, AK1, BK1, MDppPerWave, NDppPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 1, 3, 5 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, PipelineVer >::CheckValidity static __host__ constexpr bool CheckValidity(const Problem &problem)
Definition gridwise_gemm_dpp.hpp:256
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_gemm_dpp.hpp:119
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_dpp.hpp:156
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_dpp.hpp:120
Definition device_gemm_dpp.hpp:70
GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, KPerBlock, MPerDpp, NPerDpp, AK1, BK1, MDppPerWave, NDppPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 1, 3, 5 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, PipelineVer > GridwiseGemm
Definition device_gemm_dpp.hpp:71
typename GridwiseGemm::Argument Argument
Definition device_gemm_dpp.hpp:115
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_dpp.hpp:179
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_dpp.hpp:163
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_dpp.hpp:184
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_dpp.hpp:169
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_dpp.hpp:203
static auto MakeInvoker()
Definition device_gemm_dpp.hpp:200
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_dpp.hpp:228
std::string GetTypeString() const override
Definition device_gemm_dpp.hpp:234
Definition device_gemm.hpp:22