reference_moe_sorting.hpp Source File

reference_moe_sorting.hpp Source File

Composable Kernel: reference_moe_sorting.hpp Source File
reference_moe_sorting.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
// Packs a token index and a top-k slot index into a single 32-bit id:
//   bits [23:0]  = low 24 bits of token_id_
//   bits [31:24] = low 8 bits of topk_id_
// Both operands are converted to uint32_t before masking and shifting so that a
// topk_id_ of 128 or more never shifts a set bit into the sign bit of a signed int.
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_)            \
    ((static_cast<uint32_t>(token_id_) & 0x00ffffffu) |     \
     ((static_cast<uint32_t>(topk_id_) & 0xffu) << 24))
13
14template <typename WeightType, typename IndexType = index_t>
16 const HostTensor<WeightType>& weights,
17 const HostTensor<IndexType>& local_expert_mask,
18 HostTensor<IndexType>& p_sorted_token_ids,
19 HostTensor<WeightType>& sorted_weight,
20 HostTensor<IndexType>& sorted_expert_ids,
21 index_t& unit_cnt,
22 const index_t experts,
23 const index_t unit_size,
24 const index_t tokens,
25 bool local_expert_masking,
26 bool skip_experts_with_zero_token = true)
27{
28 // note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], indicating local_token case
29 const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
30 const index_t topk = topk_ids.mDesc.get_lengths()[1];
31 // allocate a temp buffer, and fill the value with [number_token|topk]
32 std::vector<std::vector<IndexType>> expert_tokens(
33 experts,
35 std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
36#else
37 std::vector<IndexType>(unit_size, num_token));
38#endif
39 std::vector<std::vector<WeightType>> expert_token_weights(
40 experts, std::vector<WeightType>(unit_size, 0));
41 // count number of unit-size slices in this expert
42 std::vector<IndexType> expert_slices(experts, 1);
43 // count the tokens used in this expert
44 std::vector<IndexType> expert_slice_idxs(experts, 0);
45 // TODO: above 2 buffer seems duplicated
46
47 for(index_t t = 0; t < num_token; t++)
48 {
49 for(index_t k = 0; k < topk; k++)
50 {
51 IndexType e = topk_ids(t, k);
52 WeightType w = weights(t, k);
53 index_t idx = expert_slice_idxs[e];
54 if(idx > expert_slices[e] * unit_size - 1)
55 {
56 expert_slices[e]++;
57 index_t new_size = expert_slices[e] * unit_size;
58 expert_tokens[e].resize(new_size);
59 expert_token_weights[e].resize(new_size);
60 for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
61 {
62#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
63 expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
64#else
65 expert_tokens[e][i] = num_token;
66#endif
67 expert_token_weights[e][i] = 0;
68 }
69 }
70#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
71 expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
72#else
73 expert_tokens[e][idx] = t;
74#endif
75 expert_token_weights[e][idx] = w;
76 expert_slice_idxs[e]++;
77 }
78 }
79
80 IndexType* out_tokens = p_sorted_token_ids.data();
81 WeightType* out_weights = sorted_weight.data();
82 IndexType* out_expert_id = sorted_expert_ids.data();
83 int curr_expert_id = 0;
84 for(index_t e = 0; e < experts; e++)
85 {
86 if(local_expert_masking)
87 {
88 if(local_expert_mask(e) == 0)
89 continue;
90 }
91 if(skip_experts_with_zero_token)
92 {
93 if(expert_slice_idxs[e] == 0)
94 {
95 curr_expert_id++;
96 continue;
97 }
98 }
99
100 memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
101 out_tokens += expert_slices[e] * unit_size;
102 memcpy(out_weights,
103 expert_token_weights[e].data(),
104 sizeof(WeightType) * expert_slices[e] * unit_size);
105 out_weights += expert_slices[e] * unit_size;
106
107 for(index_t s = 0; s < expert_slices[e]; s++)
108 {
109 out_expert_id[s] = curr_expert_id;
110 unit_cnt++;
111 }
112 out_expert_id += expert_slices[e];
113 curr_expert_id++;
114 }
115 unit_cnt *= unit_size;
116 return;
117}
118
119#undef MOE_SORTING_MOCK_ID
120
121} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
Definition config.hpp:251
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST void reference_moe_sorting(const HostTensor< IndexType > &topk_ids, const HostTensor< WeightType > &weights, const HostTensor< IndexType > &local_expert_mask, HostTensor< IndexType > &p_sorted_token_ids, HostTensor< WeightType > &sorted_weight, HostTensor< IndexType > &sorted_expert_ids, index_t &unit_cnt, const index_t experts, const index_t unit_size, const index_t tokens, bool local_expert_masking, bool skip_experts_with_zero_token=true)
Definition reference_moe_sorting.hpp:15
int32_t index_t
Definition integer.hpp:9
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_)
Definition reference_moe_sorting.hpp:11
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800
Data::pointer data()
Definition tile/host/host_tensor.hpp:591