// Grouped GEMM kernel (multiple D tensors, DL pipeline): a single launch covers all
// groups. Each workgroup finds the group whose [BlockStart_, BlockEnd_) range contains
// its block id, then runs GridwiseGemm::Run with that group's pointers, DL-blocked
// descriptors and block-to-E-tile map.
26template <
typename GridwiseGemm,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CDEElementwiseOperation,
// Compile-time K-loop shape shared by every group (the host launcher throws if the
// groups disagree on it).
31 bool HasMainKBlockLoop,
32 bool HasDoubleTailKBlockLoop>
34#if CK_USE_LAUNCH_BOUNDS
39 const AElementwiseOperation a_element_op,
40 const BElementwiseOperation b_element_op,
41 const CDEElementwiseOperation cde_element_op)
// The device body is only compiled for the gfx targets listed below (the #else branch
// is not visible in this excerpt).
43#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
44 defined(__gfx11__) || defined(__gfx94__) || defined(__gfx12__))
// Workgroup-shared scratch buffer, sized by the gridwise GEMM.
45 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// gemm_desc_ptr: per-group kernel-argument array (its initializer is on lines not
// shown; presumably the device workspace uploaded by Invoker::Run).
// The loop below searches for the owning group: it runs until block_id lies inside
// gemm_desc_ptr[group_id]'s [BlockStart_, BlockEnd_) range, moving group_id to the
// midpoint of left/right each step. Groups are laid out in increasing BlockStart_
// order on the host, so this behaves as a binary search (left/right bookkeeping is on
// lines not shown).
49 const auto gemm_desc_ptr =
55 while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
56 block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
59 if(block_id < gemm_desc_ptr[group_id].BlockStart_)
67 group_id =
index_t((left + right) / 2);
// Run the owning group's GEMM with its blocked descriptors and tile map; the K-loop
// shape is forwarded as compile-time constants.
70 GridwiseGemm::Run(gemm_desc_ptr[group_id].a_ptr_,
71 gemm_desc_ptr[group_id].b_ptr_,
72 gemm_desc_ptr[group_id].ds_ptr_,
73 gemm_desc_ptr[group_id].e_ptr_,
78 gemm_desc_ptr[group_id].a_grid_desc_k0_m0_m1_k1_,
79 gemm_desc_ptr[group_id].b_grid_desc_k0_n0_n1_k1_,
80 gemm_desc_ptr[group_id].ds_grid_desc_m0_m10_m11_n0_n10_n11_,
81 gemm_desc_ptr[group_id].e_grid_desc_m0_m10_m11_n0_n10_n11_,
82 gemm_desc_ptr[group_id].block_2_etile_map_,
83 integral_constant<bool, HasMainKBlockLoop>{},
84 integral_constant<bool, HasDoubleTailKBlockLoop>{});
// Device-level grouped GEMM with auxiliary D tensors, built on the DL gridwise pipeline.
// Presumably E = cde_element_op(A * B, D0, D1, ...) per group — inferred from the
// A/B/CDE/Ds/E naming, confirm against GridwiseGemmDlMultipleD_km_kn_mn.
// Template parameters select layouts, data types, tile sizes and block-transfer
// configuration (several parameters are on lines not shown).
94template <
typename ALayout,
100 typename AccDataType,
103 typename AElementwiseOperation,
104 typename BElementwiseOperation,
105 typename CDEElementwiseOperation,
115 typename M1N1ThreadClusterM1Xs,
116 typename M1N1ThreadClusterN1Xs,
117 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
118 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
119 typename ABlockTransferThreadClusterArrangeOrder,
120 typename ABlockTransferSrcAccessOrder,
121 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
122 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
123 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
124 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
125 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
126 typename BBlockTransferThreadClusterArrangeOrder,
127 typename BBlockTransferSrcAccessOrder,
128 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
129 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
130 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
131 typename CThreadTransferSrcDstAccessOrder,
132 index_t CThreadTransferSrcDstVectorDim,
133 index_t CThreadTransferDstScalarPerVector>
// Trailing template arguments of the base interface (the class head is on lines not
// shown).
146 AElementwiseOperation,
147 BElementwiseOperation,
148 CDEElementwiseOperation>
// Self alias, used to qualify the static descriptor helpers.
150 using DeviceOp = DeviceGroupedGemmMultipleD_Dl;
// Number of auxiliary D tensors, taken from the DsDataType tuple size.
151 static constexpr index_t NumDTensor = DsDataType::Size();
// Compile-time dimension indices for GetLength() and descriptor helpers.
153 static constexpr auto I0 = Number<0>{};
154 static constexpr auto I1 = Number<1>{};
155 static constexpr auto I2 = Number<2>{};
156 static constexpr auto I3 = Number<3>{};
157 static constexpr auto I4 = Number<4>{};
158 static constexpr auto I5 = Number<5>{};
// K1 as a compile-time Number for descriptor construction.
160 static constexpr auto K1Number = Number<K1>{};
// Builds the A grid descriptor in K0 x M x K1 form for an M x K matrix with leading
// stride StrideA, interpreting storage as row- or column-major according to ALayout.
// When GemmSpec == MNPadding, M is right-padded to a multiple of MPerBlock.
// (The K -> K0 x K1 split is on lines not shown.)
162 static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
168 const auto a_grid_desc_m_k = [&]() {
169 if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
173 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
179 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// Smallest pad that makes M divisible by MPerBlock (0 when it already is).
181 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
// Builds the B grid descriptor in K0 x N x K1 form for a K x N matrix with leading
// stride StrideB, interpreting storage as row- or column-major according to BLayout.
// When GemmSpec == MNPadding, N is right-padded to a multiple of NPerBlock.
// (The K -> K0 x K1 split is on lines not shown.)
201 static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
207 const auto b_grid_desc_k_n = [&]() {
208 if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
212 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
218 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// Smallest pad that makes N divisible by NPerBlock (0 when it already is).
220 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// Builds an M x N grid descriptor with stride StrideE for layout ELay (row- or
// column-major). When GemmSpec == MNPadding, M and N are right-padded to multiples of
// MPerBlock / NPerBlock. Used for E and, with DLayout, for every D tensor.
240 template <
typename ELay>
241 static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
243 const auto c_grid_desc_m_n = [&]() {
244 if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
248 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
254 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
// Smallest pads that make M and N divisible by the block tile sizes.
256 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
257 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
// Builds one M x N descriptor per D tensor (NumDTensor of them), each using the i-th
// raw M, N and stride with that tensor's DLayout (DLayout is selected on lines not
// shown; presumably the elements are collected with generate_tuple).
276 static auto MakeDsGridDescriptor_M_N(
const std::array<index_t, NumDTensor>& MRaws,
277 const std::array<index_t, NumDTensor>& NRaws,
278 const std::array<index_t, NumDTensor>& DsStride)
284 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
286 Number<NumDTensor>{});
289 using AGridDesc_K0_M_K1 =
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
290 using BGridDesc_K0_N_K1 =
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
291 using DsGridDesc_M_N =
decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
292 using EGridDesc_M_N =
decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// Gridwise DL GEMM type shared by every group; E is written with
// InMemoryDataOperationEnum::Set. (The alias head — presumably `using GridwiseGemm =` —
// and several template arguments are on lines not shown.)
296 GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
301 AElementwiseOperation,
302 BElementwiseOperation,
303 CDEElementwiseOperation,
304 InMemoryDataOperationEnum::Set,
315 M1N1ThreadClusterM1Xs,
316 M1N1ThreadClusterN1Xs,
317 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
318 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
319 ABlockTransferThreadClusterArrangeOrder,
320 ABlockTransferSrcAccessOrder,
321 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
322 ABlockTransferSrcVectorTensorContiguousDimOrder,
323 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
324 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
325 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
326 BBlockTransferThreadClusterArrangeOrder,
327 BBlockTransferSrcAccessOrder,
328 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
329 BBlockTransferSrcVectorTensorContiguousDimOrder,
330 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
331 CThreadTransferSrcDstAccessOrder,
332 CThreadTransferSrcDstVectorDim,
333 CThreadTransferDstScalarPerVector>;
335 using AGridDesc_K0_M0_M1_K1 =
336 decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
337 using BGridDesc_K0_N0_N1_K1 =
338 decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
339 using DsGridDesc_M0_M10_M11_N0_N10_N11 =
340 decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
341 using EGridDesc_M0_M10_M11_N0_N10_N11 =
342 decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
// Wraps the gridwise default block-to-E-tile map and records BlockStart_, the first
// global workgroup id assigned to this group, so that per-group tile coordinates can
// (presumably) be derived from global block ids.
344 struct GroupedGemmBlock2ETileMap
346 using Block2ETileMap =
347 remove_cvref_t<
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}))>;
// Default constructor: map built from a value-initialized E descriptor.
349 GroupedGemmBlock2ETileMap()
351 block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{});
// Builds the map for one group's E descriptor. BlockStart is the group's first global
// workgroup id (Argument passes the running grid size here).
355 GroupedGemmBlock2ETileMap(
const EGridDesc_M_N& e_grid_desc_m_n,
ck::index_t BlockStart)
357 block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n);
358 BlockStart_ = BlockStart;
// Delegates to the wrapped map. NOTE(review): the argument expression is on lines not
// shown — presumably idx_top offset by BlockStart_; confirm.
361 template <
typename TopIdx>
362 __host__ __device__
constexpr auto CalculateBottomIndex(
const TopIdx& idx_top)
const
364 return block_2_etile_map_.CalculateBottomIndex(
// Forwards the tile-index validity check to the wrapped map.
369 template <
typename CTileIdx,
typename CTileDim>
370 __host__ __device__
bool ValidCTileIndex(
const CTileIdx& c_tile_idx,
371 const CTileDim& c_tile_dim)
const
373 return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
// Host-side check that the wrapped map is consistent with the given E descriptor.
376 __host__
bool CheckValidity(
const EGridDesc_M_N& e_grid_desc_m_n)
const
378 return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
// Underlying default block-to-E-tile map.
381 Block2ETileMap block_2_etile_map_;
// Per-group kernel argument. Invoker::Run copies the array of these into the device
// workspace. It holds tensor pointers, raw and DL-blocked grid descriptors, and the
// group's block-to-E-tile map (the struct head and some members are on lines not shown).
388 const ADataType* a_ptr_;
389 const BDataType* b_ptr_;
390 typename GridwiseGemm::DsGridPointer ds_ptr_;
// Raw descriptors; used on the host for CheckValidity and debug printing.
394 AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
395 BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
396 DsGridDesc_M_N ds_grid_desc_m_n_;
397 EGridDesc_M_N e_grid_desc_m_n_;
// DL-blocked descriptors passed to GridwiseGemm::Run by the kernel.
400 AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
401 BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
402 DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
403 EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
// Block-to-E-tile map carrying this group's BlockStart_.
406 GroupedGemmBlock2ETileMap block_2_etile_map_;
// Host-side argument. For each input group it records raw sizes, skips some groups,
// builds descriptors, and — when GridwiseGemm::CheckValidity passes — appends a
// GemmKernelArg owning the next contiguous range of global workgroups.
411 struct Argument :
public BaseArgument
413 Argument(std::vector<const void*>& p_As,
414 std::vector<const void*>& p_Bs,
415 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
416 std::vector<void*>& p_Es,
417 std::vector<GemmDesc>& gemm_descs,
418 AElementwiseOperation a_element_op,
419 BElementwiseOperation b_element_op,
420 CDEElementwiseOperation cde_element_op)
421 : a_element_op_{a_element_op},
422 b_element_op_{b_element_op},
423 cde_element_op_{cde_element_op},
424 gemm_kernel_host_args_{nullptr}
// Rejects mismatched group counts between descriptors and pointer vectors (the
// comparison is on lines not shown).
434 throw std::runtime_error(
"wrong! group_count_ != p_As/b/c.size");
437 gemm_desc_kernel_arg_.reserve(group_count_);
439 skipped_group_count_ = 0;
441 for(std::size_t i = 0; i < gemm_descs.size(); i++)
443 const index_t M = gemm_descs[i].M_;
444 const index_t N = gemm_descs[i].N_;
445 const index_t K = gemm_descs[i].K_;
// Raw sizes are recorded for every group, before the skip decision below.
447 a_mtx_mraw_kraw_.emplace_back(M, K);
448 b_mtx_nraw_kraw_.emplace_back(N, K);
// Skipped groups are only counted; the skip condition is on lines not shown
// (presumably degenerate/zero-sized problems — confirm).
452 skipped_group_count_++;
456 const index_t StrideA = gemm_descs[i].stride_A_;
457 const index_t StrideB = gemm_descs[i].stride_B_;
458 const index_t StrideE = gemm_descs[i].stride_C_;
// Collect this group's D pointers and M x N descriptors, one per D tensor.
460 typename GridwiseGemm::DsGridPointer p_ds_grid{};
461 DsGridDesc_M_N ds_grid_desc_m_n;
463 static_for<0, NumDTensor, 1>{}([&](
auto j) {
467 p_ds_grid(j) =
static_cast<const DDataType*
>(p_Ds[i][j]);
468 ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
469 M, N, gemm_descs[i].stride_Ds_[j]);
// Raw A/B/E descriptors for this group.
473 const auto a_grid_desc_k0_m_k1 =
474 DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
475 const auto b_grid_desc_k0_n_k1 =
476 DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
477 const auto e_grid_desc_m_n =
478 DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
// Only groups passing the gridwise validity check receive workgroups and a kernel arg.
480 if(GridwiseGemm::CheckValidity(
481 a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n))
485 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
486 .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
// This group owns global workgroups [BlockStart, BlockEnd); grid_size_ accumulates the
// total, so ranges are contiguous and increasing across groups.
488 const index_t BlockStart = grid_size_;
489 const index_t BlockEnd = grid_size_ + grid_size_grp;
491 grid_size_ += grid_size_grp;
494 const auto block_2_etile_map =
495 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
// DL-blocked descriptors consumed by the kernel.
498 const auto a_grid_desc_k0_m0_m1_k1 =
499 GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
500 const auto b_grid_desc_k0_n0_n1_k1 =
501 GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
502 const auto ds_grid_desc_m0_m10_m11_n0_n10_n11 =
503 GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n);
504 const auto e_grid_desc_m0_m10_m11_n0_n10_n11 =
505 GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n);
// Record this group's kernel argument.
507 gemm_desc_kernel_arg_.push_back(
508 GemmKernelArg{
static_cast<const ADataType*
>(p_As[i]),
509 static_cast<const BDataType*
>(p_Bs[i]),
511 static_cast<EDataType*
>(p_Es[i]),
516 a_grid_desc_k0_m0_m1_k1,
517 b_grid_desc_k0_n0_n1_k1,
518 ds_grid_desc_m0_m10_m11_n0_n10_n11,
519 e_grid_desc_m0_m10_m11_n0_n10_n11,
// Elementwise operators forwarded to the kernel.
533 AElementwiseOperation a_element_op_;
534 BElementwiseOperation b_element_op_;
535 CDEElementwiseOperation cde_element_op_;
// One entry per launched (non-skipped, valid) group.
537 std::vector<GemmKernelArg> gemm_desc_kernel_arg_;
// Raw (M, K) and (N, K) sizes for every input group, including skipped ones.
538 std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
539 std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
// Optional caller-owned host staging buffer for kernel args; set by
// SetHostKernelArgsPointer, nullptr otherwise.
542 void* gemm_kernel_host_args_;
// Launches the grouped kernel. It validates every group, checks that all groups share
// one K-loop shape, uploads the GemmKernelArg array to the device workspace, and then
// dispatches the kernel instantiation matching that K-loop shape.
546 struct Invoker :
public BaseInvoker
548 using Argument = DeviceOp::Argument;
// cpy_stream / cpy_event are optional. When both are set, kernel args are uploaded
// from the host staging buffer on cpy_stream and waited on via cpy_event before
// launch. Otherwise they are copied from gemm_desc_kernel_arg_ on the compute stream.
550 float Run(
const Argument& arg,
551 const StreamConfig& stream_config = StreamConfig{},
552 hipStream_t cpy_stream =
nullptr,
553 hipEvent_t cpy_event =
nullptr)
// NOTE(review): group 0 is read unconditionally. If every group was skipped or failed
// CheckValidity, gemm_desc_kernel_arg_ is empty and this is out of bounds — consider
// rejecting empty arguments (or checking IsSupportedArgument) first.
555 auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0);
556 bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
557 bool all_has_double_tail_k_block_loop =
558 GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
560 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
// Debug dump of per-group descriptor lengths (its guard is on lines not shown;
// presumably an environment-controlled logging switch).
564 std::cout <<
"group: " << i <<
" arg.a_grid_desc_k0_m_k1_{"
565 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
567 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1)
569 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2)
572 std::cout <<
", arg.b_grid_desc_k0_n_k1_{"
573 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0)
575 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1)
577 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2)
580 std::cout <<
", arg.e_grid_desc_m_n_{ "
581 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) <<
", "
582 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) <<
"}"
// Each group must still pass the gridwise validity check.
586 if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
587 arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
588 arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
590 throw std::runtime_error(
591 "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
// The kernel is instantiated once for all groups, so every group's K0 must yield the
// same main-loop / double-tail flags as group 0. Any mismatch (xor) is an error.
594 K0 = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
595 bool not_all_has_main_k_block_loop_same =
596 all_has_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(K0);
597 bool not_all_has_double_tail_k_block_loop_same =
598 all_has_double_tail_k_block_loop xor
599 GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
601 if(not_all_has_main_k_block_loop_same or not_all_has_double_tail_k_block_loop_same)
603 std::ostringstream err;
604 err <<
"Not all gemms have same value for [main|double_tail]_k_block_loop! in "
605 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__;
606 throw std::runtime_error(err.str());
// Upload kernel args into arg.p_workspace_.
613 if(cpy_stream && cpy_event)
615 if(arg.gemm_kernel_host_args_ ==
nullptr)
617 std::ostringstream err;
618 err <<
"No memory has been allocated for gemm kernel host args "
619 <<
"when providing the copy stream and copy event! In " << __FILE__ <<
":"
620 << __LINE__ <<
", in function: " << __func__;
621 throw std::runtime_error(err.str());
// NOTE(review): hipGetErrorString() only converts a hipError_t to a message, and the
// result is discarded here. Failures of the copy, the event record or the event
// synchronize therefore go unnoticed, whereas the other upload path below has an
// extra wrapping call. A status-checking helper is likely intended.
// NOTE(review): this copies group_count_ entries, but SetHostKernelArgsPointer stages
// only gemm_desc_kernel_arg_.size() entries. When groups are skipped, the trailing
// bytes are unwritten (presumably unused by the kernel, which receives
// gemm_desc_kernel_arg_.size()).
623 hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
624 arg.gemm_kernel_host_args_,
625 arg.group_count_ *
sizeof(GemmKernelArg),
626 hipMemcpyHostToDevice,
628 hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
629 hipGetErrorString(hipEventSynchronize(cpy_event));
634 hipMemcpyAsync(arg.p_workspace_,
635 arg.gemm_desc_kernel_arg_.data(),
636 arg.gemm_desc_kernel_arg_.size() *
sizeof(GemmKernelArg),
637 hipMemcpyHostToDevice,
638 stream_config.stream_id_));
642 auto has_double_tail_k_block_loop) {
643 constexpr bool has_main_loop = has_main_k_block_loop.value;
644 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
// Instantiate the kernel for this K-loop shape and launch arg.grid_size_ workgroups
// covering all groups.
646 const auto kernel = kernel_grouped_gemm_multiple_d_dl<GridwiseGemm,
648 AElementwiseOperation,
649 BElementwiseOperation,
650 CDEElementwiseOperation,
657 dim3(arg.grid_size_),
661 arg.gemm_desc_kernel_arg_.size(),
664 arg.cde_element_op_);
// Map the runtime K-loop flags onto compile-time integral_constant arguments.
667 if(all_has_main_k_block_loop && all_has_double_tail_k_block_loop)
670 integral_constant<bool, true>{});
672 else if(all_has_main_k_block_loop && !all_has_double_tail_k_block_loop)
675 integral_constant<bool, false>{});
677 else if(!all_has_main_k_block_loop && all_has_double_tail_k_block_loop)
680 integral_constant<bool, true>{});
685 integral_constant<bool, false>{});
// Type-erased entry point; forwards to the typed Run without a copy stream/event.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
690 float Run(
const BaseArgument* p_arg,
691 const StreamConfig& stream_config = StreamConfig{})
override
693 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
// Returns whether arg can be launched. It compares a count that involves
// skipped_group_count_ against group_count_ (the left operand is on lines not shown;
// presumably gemm_desc_kernel_arg_.size() + skipped_group_count_). It then requires
// every kernel argument to pass GridwiseGemm::CheckValidity. Further checks may sit on
// lines not shown.
697 static bool IsSupportedArgument(
const Argument& arg)
700 arg.skipped_group_count_) != arg.group_count_)
// Every recorded group must satisfy the gridwise GEMM's shape constraints.
708 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
710 if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
711 arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
712 arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
// Type-erased overload; forwards to the static check above.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
726 bool IsSupportedArgument(
const BaseArgument* p_arg)
override
728 return IsSupportedArgument(*
dynamic_cast<const Argument*
>(p_arg));
// Constructs an Argument by value from per-group A/B/Ds/E pointers, GEMM descriptions
// and elementwise operators. gemm_descs is taken by value here.
731 static auto MakeArgument(std::vector<const void*>& p_As,
732 std::vector<const void*>& p_Bs,
733 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
734 std::vector<void*>& p_Es,
735 std::vector<GemmDesc> gemm_descs,
736 AElementwiseOperation a_element_op,
737 BElementwiseOperation b_element_op,
738 CDEElementwiseOperation cde_element_op)
741 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op};
744 static auto MakeInvoker() {
return Invoker{}; }
// Heap-allocates an Argument behind the type-erased BaseArgument interface. The
// parameters mirror MakeArgument, except gemm_descs is taken by reference.
747 std::unique_ptr<BaseArgument>
748 MakeArgumentPointer(std::vector<const void*>& p_As,
749 std::vector<const void*>& p_Bs,
750 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
751 std::vector<void*>& p_Es,
752 std::vector<GemmDesc>& gemm_descs,
753 AElementwiseOperation a_element_op,
754 BElementwiseOperation b_element_op,
755 CDEElementwiseOperation cde_element_op)
override
757 return std::make_unique<Argument>(
758 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op);
// Heap-allocates an Invoker behind the type-erased BaseInvoker interface.
762 std::unique_ptr<BaseInvoker> MakeInvokerPointer()
override
764 return std::make_unique<Invoker>(Invoker{});
// Human-readable identifier: the op name followed by tuning parameters such as
// K0PerBlock, M1PerThread and N1PerThread (the full list is on lines not shown).
768 std::string GetTypeString()
const override
770 auto str = std::stringstream();
773 str <<
"DeviceGroupedGemmMultipleD_Dl"
778 << K0PerBlock <<
", "
780 << M1PerThread <<
", "
781 << N1PerThread <<
", "
790 size_t GetWorkSpaceSize(
const BaseArgument* p_arg)
const override
792 return dynamic_cast<const Argument*
>(p_arg)->group_count_ *
sizeof(GemmKernelArg);
// Size in bytes of the device kernel-argument buffer; identical to the workspace size
// (GetWorkSpaceSize).
795 size_t GetDeviceKernelArgSize(
const BaseArgument* p_arg)
const override
797 return GetWorkSpaceSize(p_arg);
800 size_t GetHostKernelArgSize(
const BaseArgument* p_arg)
const {
return GetWorkSpaceSize(p_arg); }
// Installs a caller-provided device buffer as the workspace (via the base-class
// SetWorkSpacePointer). Invoker::Run uploads the kernel args into arg.p_workspace_,
// which this presumably sets.
802 void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args)
const override
804 return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
// Registers a caller-owned host staging buffer and copies the current kernel args into
// it. Invoker::Run uploads from this buffer when given both a copy stream and a copy
// event. The buffer must hold at least gemm_desc_kernel_arg_.size() GemmKernelArg
// objects; GetHostKernelArgSize reports group_count_ * sizeof(GemmKernelArg).
// NOTE(review): std::copy assigns into raw storage, which assumes GemmKernelArg is
// trivially copyable — confirm.
815 void SetHostKernelArgsPointer(BaseArgument* p_arg,
void* p_host_kernel_args)
const
817 Argument* pArg_ =
dynamic_cast<Argument*
>(p_arg);
// Reject arguments that do not belong to this device op.
820 throw std::runtime_error(
"Failed to cast argument pointer!");
823 pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
824 std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
825 pArg_->gemm_desc_kernel_arg_.end(),
826 static_cast<GemmKernelArg*
>(pArg_->gemm_kernel_host_args_));
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition device_grouped_gemm.hpp:99
Definition device_grouped_gemm.hpp:80
#define CK_ENV(name)
Definition utility/env.hpp:129