reference_moe_gemm.hpp Source File

reference_moe_gemm.hpp Source File

Composable Kernel: reference_moe_gemm.hpp Source File
reference_moe_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdlib>
7#include <thread>
8
9#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
// Reference (non-optimized) MoE GEMM device kernel: one thread produces one output element.
//
// The flat thread index is split into (row, col) with problem_N columns per row, where
// problem_N is N / 2 when MoeGemmKind == 1 and N otherwise. `row` indexes the sorted-token
// list; `col` indexes the output column. Threads whose row is not below p_max_token_id_[0],
// whose decoded token id is not below Num_tokens, or whose row is not below M do nothing.
//
// MoeGemmKind selects the variant (see the template-parameter comment below):
//   0 / 1 : the result is passed through ActivationOp before being stored
//           (kind 1 additionally accumulates a second "up" column at col + problem_N);
//   2     : the result is multiplied by expert_weight_ptr[row] and accumulated into C
//           with atomic_add on a two-element vector.
//
// NOTE(review): this listing is an extract in which several original listing lines are
// absent (32-35, 116-117, 218, 226-227). They are flagged where they occur; the text
// below is therefore not compilable as shown.
14template <typename ADataType,
15 typename BDataType,
16 typename AccDataType,
17 typename CDataType,
18 typename LayoutA,
19 typename LayoutB,
20 typename LayoutC,
21 int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
22 typename ActivationOp = identity>
23__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
24 const ck_tile::index_t* p_sorted_expert_ids_,
25 const ck_tile::index_t* p_max_token_id_,
26 const ADataType* A,
27 const BDataType* B,
28 CDataType* C,
29 const AccDataType* expert_weight_ptr,
30 ck_tile::index_t Num_tokens,
31 ck_tile::index_t TokensPerBlock,
// NOTE(review): listing lines 32-35 are missing here. The kernel signature reproduced in
// the symbol list at the end of this page shows `TopK`, `M`, `N`, `K` (all
// ck_tile::index_t) between TokensPerBlock and strideA, and the body uses all four.
36 ck_tile::index_t strideA,
37 ck_tile::index_t strideB,
38 ck_tile::index_t strideC,
39 index_t scale_granularity_m,
40 index_t scale_granularity_n,
41 index_t scale_granularity_k,
42 float* scale_A_ptr,
43 float* scale_B_ptr,
44 float* expert_bias_ptr)
45{
// Flat 1-D thread index, decomposed into an output (row, col) pair.
46 int idx = blockIdx.x * blockDim.x + threadIdx.x;
47 int problem_N = MoeGemmKind == 1 ? N / 2 : N;
48 int row = idx / problem_N; // Compute row index
49 int col = idx % problem_N; // Compute column index
50
// gather_token_id selects the row of A that is read; scatter_token_id selects the row of C
// that is written; expert_id selects the slice of B, the B scales and the bias.
51 index_t gather_token_id = 0;
52 index_t scatter_token_id = 0;
53 index_t expert_id = 0;
54
55 if(row < p_max_token_id_[0])
56 {
// One expert id is stored per group of TokensPerBlock consecutive sorted rows.
57 expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
// The low 24 bits of a sorted-token entry are the token id; the bits above (>> 24) are
// combined with TopK below — presumably the top-k slot of that token (confirm with the
// code that builds p_sorted_token_ids_).
58 gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
59 scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
60 if(gather_token_id >= Num_tokens)
61 {
62 return;
63 }
// Kind 2 reads from the expanded index (token * TopK + slot) and writes to the plain
// token row; kinds 0/1 read from the plain token row and write to the expanded index.
64 if(MoeGemmKind == 2)
65 {
66 gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
67 }
68 else
69 {
70 scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
71 }
72 }
73 else
74 {
75 return;
76 }
77
78 if(row < M)
79 {
// acc / acc_up hold the scaled totals; acc_temp / acc_up_temp hold the unscaled partial
// sums of the current K-group (scale_granularity_k consecutive k values).
80 AccDataType acc = 0.0;
81 AccDataType acc_up = 0.0;
82
83 AccDataType acc_temp = 0.0;
84 AccDataType acc_up_temp = 0.0;
85
86 float scale_A = 0;
87 float scale_B = 0;
88 float scale_B_up = 0;
89
// Strides used to index the scale arrays: ceil(M / granularity_m) entries per K-group for
// A, ceil(N / granularity_n) entries per K-group for B, and one expert's worth of B scales.
90 index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
91 index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
92 index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
93
94 for(int k = 0; k < K; ++k)
95 {
// At the start of every K-group: fold the previous group's partial sums into the totals
// using that group's scales, then load the scales for the new group. At k == 0 this
// folds zero partial sums with zero scales, so it contributes nothing.
96 if(k % scale_granularity_k == 0)
97 {
98 // update acc
99 acc += acc_temp * scale_A * scale_B;
100 acc_up += acc_up_temp * scale_A * scale_B_up;
101 // reset acc temp
102 acc_temp = 0.0;
103 acc_up_temp = 0.0;
104 // update scale factors
105 scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
106 (k / scale_granularity_k) * scale_A_stride];
107 scale_B =
108 scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
109 (k / scale_granularity_k) * scale_B_stride];
110 if constexpr(MoeGemmKind == 1)
111 scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
112 (col + problem_N) / scale_granularity_n +
113 (k / scale_granularity_k) * scale_B_stride];
114 }
115
// NOTE(review): listing lines 116-117 are missing here. `packed_size_a` and
// `packed_size_b`, used below, are not declared anywhere in the visible text —
// presumably they were declared on those lines (confirm against the real header).
118 // Adjust indexing based on matrix layout
119 int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
120 ? gather_token_id * strideA + k
121 : k * strideA + gather_token_id;
122
// B is addressed as a per-expert slice of N * K elements followed by the in-slice offset.
123 long b_index =
124 long(expert_id) * N * K +
125 ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
126 : k * strideB + col);
// The "up" column of kind 1 lives problem_N columns after the gate column.
127 long b_index_up;
128 if constexpr(MoeGemmKind == 1)
129 b_index_up = long(expert_id) * N * K +
130 ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
131 ? (col + problem_N) * strideB + k
132 : k * strideB + col + problem_N);
133
// Load A and B operands as AccDataType. For the packed 4-bit types one storage element
// is converted to a pair and the parity of k picks .hi (odd k) or .lo (even k).
134 AccDataType v_a;
135 AccDataType v_b;
136 AccDataType v_b_up;
137 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
138 {
139 const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
140 if(k % 2 == 1)
141 v_a = fp32_val.hi;
142 else
143 v_a = fp32_val.lo;
144 }
145 else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
146 {
// NOTE(review): called with one argument here but with an explicit 1.0f scale for B below.
147 const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
148 if(k % 2 == 1)
149 v_a = fp32_val.hi;
150 else
151 v_a = fp32_val.lo;
152 }
153 else
154 {
155 v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
156 }
157 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
158 {
159 const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
160 if(k % 2 == 1)
161 v_b = fp32_val.hi;
162 else
163 v_b = fp32_val.lo;
164 if constexpr(MoeGemmKind == 1)
165 {
166 const fp32x2_t fp32_val_up =
167 pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
168 if(k % 2 == 1)
169 v_b_up = fp32_val_up.hi;
170 else
171 v_b_up = fp32_val_up.lo;
172 }
173 }
174 else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
175 {
176 const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
177 if(k % 2 == 1)
178 v_b = fp32_val.hi;
179 else
180 v_b = fp32_val.lo;
181 if constexpr(MoeGemmKind == 1)
182 {
183 const fp32x2_t fp32_val_up =
184 pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
185 if(k % 2 == 1)
186 v_b_up = fp32_val_up.hi;
187 else
188 v_b_up = fp32_val_up.lo;
189 }
190 }
191 else
192 {
193 v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
194 if constexpr(MoeGemmKind == 1)
195 v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
196 }
197 acc_temp += v_a * v_b;
198 if constexpr(MoeGemmKind == 1)
199 acc_up_temp += v_a * v_b_up;
200 }
201
// Fold the final K-group, which the in-loop update never reaches.
202 acc += acc_temp * scale_A * scale_B;
203 acc_up += acc_up_temp * scale_A * scale_B_up;
204
// Optional per-expert bias, indexed as expert_id * N + column; skipped when the pointer
// is null, leaving both bias terms at zero.
205 float bias = 0.f, bias_up = 0.f;
206 if(expert_bias_ptr != nullptr)
207 {
208 bias = expert_bias_ptr[expert_id * N + col];
209 if constexpr(MoeGemmKind == 1)
210 bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
211 }
212
213 int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
214 ? scatter_token_id * strideC + col
215 : col * strideC + scatter_token_id;
216 if constexpr(MoeGemmKind < 2)
217 {
// NOTE(review): listing line 218 is missing; the line below is only the tail of a
// statement. What is visible: ActivationOp is invoked with (acc + bias) and, as second
// argument, (acc_up + bias_up) for kind 1 or the constant 1 otherwise. The destination
// of the result is not visible here — presumably a store to C[c_index]; confirm.
219 ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
220 }
221 else
222 {
223 // moe gemm2 don't use activation.
224 CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
// NOTE(review): listing lines 226-227 are missing, so the alias below is truncated after
// its condition; the two alternative types are not visible in this extract.
225 using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
// The scalar result is placed in one lane of a zero-initialised two-element vector,
// chosen by the parity of c_index, so that adding the vector leaves the neighbouring
// element unchanged.
228 ResV2Type add_v{0, 0};
229 if(c_index % 2)
230 {
231 // result is the second value of fp16 pair.
232 add_v.y = res;
233 }
234 else
235 {
236 // result is the first value of fp16 pair.
237 add_v.x = res;
238 }
239 // mask last bit to make sure atomicAdd pointer is aligned of DWORD.
240 atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
241 }
242 }
243}
244
245template <typename ADataType,
246 typename BDataType,
247 typename AccDataType,
248 typename CDataType,
249 typename LayoutA,
250 typename LayoutB,
251 typename LayoutC,
252 int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
253 typename ActivationOp = identity>
254void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
255 const index_t* p_sorted_expert_ids_,
256 const index_t* p_max_token_id_,
257 const ADataType* a_ptr,
258 const BDataType* b_ptr,
259 CDataType* c_ptr,
260 const AccDataType* expert_weight_ptr,
261 index_t Num_tokens,
262 index_t TokensPerBlock,
263 index_t TopK,
264 index_t M,
265 index_t N,
266 index_t K,
267 index_t stride_a,
268 index_t stride_b,
269 index_t stride_c,
270 index_t scale_granularity_m,
271 index_t scale_granularity_n,
272 index_t scale_granularity_k,
273 float* scale_A_ptr,
274 float* scale_B_ptr,
275 float* exp_bias = nullptr)
276{
277 int problem_N = MoeGemmKind == 1 ? N / 2 : N;
278 int totalElements = M * problem_N;
279 int numThreadsPerBlock = 256; // Common choice for threads per block
280 int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
281
282 moe_gemm_kernel<ADataType,
283 BDataType,
284 AccDataType,
285 CDataType,
286 LayoutA,
287 LayoutB,
288 LayoutC,
289 MoeGemmKind,
290 ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
291 p_sorted_expert_ids_,
292 p_max_token_id_,
293 a_ptr,
294 b_ptr,
295 c_ptr,
296 expert_weight_ptr,
297 Num_tokens,
298 TokensPerBlock,
299 TopK,
300 M,
301 N,
302 K,
303 stride_a,
304 stride_b,
305 stride_c,
306 scale_granularity_m,
307 scale_granularity_n,
308 scale_granularity_k,
309 scale_A_ptr,
310 scale_B_ptr,
311 exp_bias);
312
313 return;
314}
315
316} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition pk_int4.hpp:105
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:350
__global__ void moe_gemm_kernel(const ck_tile::index_t *p_sorted_token_ids_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const ADataType *A, const BDataType *B, CDataType *C, const AccDataType *expert_weight_ptr, ck_tile::index_t Num_tokens, ck_tile::index_t TokensPerBlock, ck_tile::index_t TopK, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *expert_bias_ptr)
Definition reference_moe_gemm.hpp:23
void reference_moe_gemm_gpu(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const AccDataType *expert_weight_ptr, index_t Num_tokens, index_t TokensPerBlock, index_t TopK, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t scale_granularity_m, index_t scale_granularity_n, index_t scale_granularity_k, float *scale_A_ptr, float *scale_B_ptr, float *exp_bias=nullptr)
Definition reference_moe_gemm.hpp:254
float fp32x2_t
Definition pk_fp4.hpp:22
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/core/utility/functional.hpp:86
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82