device_batched_gemm_e_permute.hpp Source File

device_batched_gemm_e_permute.hpp Source File

Composable Kernel: device_batched_gemm_e_permute.hpp Source File
device_batched_gemm_e_permute.hpp
Go to the documentation of this file.
1#pragma once
2#include <iostream>
3#include <vector>
4
5#include "device_base.hpp"
6
7namespace ck {
8namespace tensor_operation {
9namespace device {
10
16
// Abstract interface template: both member functions declared in the body are
// pure virtual (`= 0`), so a concrete implementation must override them.
// Template parameters: three layout tags (A, B, DE), three data types (A, B, E)
// and three element-wise operation types (A, B, CDE).
// NOTE(review): the listing jumps from line 25 to line 27 -- the struct/class
// header that this template parameter list belongs to is missing from this
// extract; restore it from the original header before compiling.
// NOTE(review): std::unique_ptr is used below, but <memory> is not among the
// visible includes; presumably it arrives transitively via device_base.hpp --
// confirm, or include <memory> directly.
17template <typename ALayout,
18 typename BLayout,
19 typename DELayout,
20 typename ADataType,
21 typename BDataType,
22 typename EDataType,
23 typename AElementwiseOperation,
24 typename BElementwiseOperation,
25 typename CDEElementwiseOperation>
27{
// Factory for the argument object, handed back as std::unique_ptr<BaseArgument>.
// p_a and p_b are read-only untyped pointers and p_e is a writable untyped
// pointer -- presumably the A/B input buffers and the E output buffer; confirm.
// M, N, K, stride_* and batch_stride_* are index_t values -- presumably the GEMM
// problem sizes and the row/batch strides of A and B; verify against an
// implementation.
// batched_gemm_e_permute_desc is taken by value; its type is declared in a part
// of the header not shown in this listing -- presumably it describes the
// permuted E layout; TODO confirm.
// The three element-wise operation objects are also taken by value.
28 virtual std::unique_ptr<BaseArgument>
29 MakeArgumentPointer(const void* p_a,
30 const void* p_b,
31 void* p_e,
32 index_t M,
33 index_t N,
34 index_t K,
35 index_t stride_A,
36 index_t stride_B,
37 index_t batch_stride_A,
38 index_t batch_stride_B,
39 BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
40 index_t BatchCount,
41 AElementwiseOperation a_element_op,
42 BElementwiseOperation b_element_op,
43 CDEElementwiseOperation cde_element_op) = 0;
44
// Factory for the invoker object, handed back as std::unique_ptr<BaseInvoker>.
// Takes no arguments.
45 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
46};
47
48} // namespace device
49} // namespace tensor_operation
50} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_gemm_e_permute.hpp:12
ck::index_t N_
Definition device_batched_gemm_e_permute.hpp:13
ck::index_t stride_N_
Definition device_batched_gemm_e_permute.hpp:14
ck::index_t G1_
Definition device_batched_gemm_e_permute.hpp:13
ck::index_t stride_G1_
Definition device_batched_gemm_e_permute.hpp:14
ck::index_t M_
Definition device_batched_gemm_e_permute.hpp:13
ck::index_t stride_M_
Definition device_batched_gemm_e_permute.hpp:14
ck::index_t G0_
Definition device_batched_gemm_e_permute.hpp:13
ck::index_t stride_G0_
Definition device_batched_gemm_e_permute.hpp:14
Definition device_batched_gemm_e_permute.hpp:27
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0