reduce2d_kernel.hpp Source File

reduce2d_kernel.hpp Source File

Composable Kernel: reduce2d_kernel.hpp Source File
reduce2d_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11// Reduce2d Kernel:
12// =======================================
13// This kernel implements a 2D reduction: the input tensor is viewed as a (kept x reduced)
14// matrix that is reduced along its second dimension in multiple hierarchical stages.
15
16namespace ck_tile {
17
18template <typename Problem_, typename Policy_ = Reduce2dDefaultPolicy>
19struct Reduce
20{
23
27
28 static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
29 CK_TILE_HOST static constexpr auto BlockSize()
30 {
31 return is_wave32() ? kBlockSize / 2 : kBlockSize;
32 }
33
34 private:
35 // Helper function to calculate optimal vector size for input tensor
36 template <typename InputShape, typename ReduceDims>
37 static constexpr index_t CalculateInputVectorSize()
38 {
39 using S = typename Problem::BlockShape;
40 constexpr index_t memory_vector_size = 16 / sizeof(XDataType);
41 constexpr index_t thread_tile_vector_size = S::ThreadTile_N;
42
43 // Check if innermost reduce dimension is the last dimension (stride 1).
44 constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
45 constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1);
46
47 // If innermost reduce dimension is not the last dim (not contiguous), limit vectorization
48 constexpr index_t stride_based_vector_size =
49 is_innermost_contiguous ? ck_tile::min(memory_vector_size, thread_tile_vector_size) : 1;
50
51 return stride_based_vector_size;
52 }
53
54 // Helper function to calculate optimal vector size for output tensor
55 static constexpr index_t CalculateOutputVectorSize()
56 {
57 using S = typename Problem::BlockShape;
58 constexpr index_t memory_vector_size = 16 / sizeof(YDataType);
59 constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
60 constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size);
61
62 return vector_size;
63 }
64
65 public:
/// @brief Device entry point: reduces the input over `reduce_dims` and writes one value per
/// kept-dimension index to `p_y`.
///
/// The input is viewed as a 2D (M x N) tensor by merging the `kept_dim` lengths into M and the
/// `reduce_dims` lengths into N. The workgroup with id get_block_id() handles rows
/// [iM, iM + Block_M) and walks the N axis in steps of Block_N.
///
/// @param p_x           Input data, addressed through input_shape / input_strides.
/// @param p_y           Output data, written as a packed tensor whose lengths are the kept
///                      dimension lengths (strides are products of the following kept lengths).
/// @param input_shape   Per-dimension lengths of the input tensor.
/// @param input_strides Per-dimension strides of the input tensor.
/// @param kept_dim      Compile-time indices of the input dimensions that are kept.
/// @param reduce_dims   Compile-time indices of the input dimensions that are reduced;
///                      kept_dim.size() + reduce_dims.size() must equal the input rank.
66 template <typename InputShape, typename InputStrides, typename KeptDim, typename ReduceDims>
68 YDataType* p_y,
69 InputShape input_shape,
70 InputStrides input_strides,
71 KeptDim kept_dim,
72 ReduceDims reduce_dims) const
73 {
74 using S = typename Problem::BlockShape;
// First row (of the merged kept axis M) handled by this workgroup.
75 const auto iM = get_block_id() * S::Block_M;
76
77 static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(),
78 "Size of kept dimensions + reduced dimensions must equal input tensor rank");
79
80 // Extract lengths based on kept and reduced dimensions
81 const auto kept_lens = [&]() {
82 return generate_tuple([&](auto I) { return input_shape.at(number<kept_dim.at(I)>{}); },
83 number<kept_dim.size()>{});
84 }();
85 const auto reduce_lens = [&]() {
86 return generate_tuple(
87 [&](auto I) { return input_shape.at(number<reduce_dims.at(I)>{}); },
88 number<reduce_dims.size()>{});
89 }();
90
// Merge transforms that collapse the kept dims into M and the reduced dims into N.
91 const auto kept_merge_transform = make_merge_transform(kept_lens);
92 const auto reduce_merge_transform = make_merge_transform(reduce_lens);
93
// The reduction's identity value, converted to XDataType, is used as the buffer's padding
// value so that padded / out-of-bounds elements contribute the identity to the result.
94 auto reduce_func = typename Problem::ReduceOp{};
95 const XDataType custom_padding_value =
96 type_convert<XDataType>(reduce_func.template GetIdentityValue<ComputeDataType>());
97
98 // Calculate optimal vector size for input tensor
99 constexpr auto x_tensor_vector_size = CalculateInputVectorSize<InputShape, ReduceDims>();
100
101 // Create input tensor view with custom padding value
103 input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});
104
105 // Create buffer view with custom padding value
107 p_x, desc.get_element_space_size(), custom_padding_value);
108
109 // Create tensor view with custom padding
110 const auto x_tensor = tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
// 2D (M x N) view of the input: kept dims merged into dim 0, reduced dims into dim 1, then padded.
111 const auto transformed_x_tensor = pad_tensor_view(
112 transform_tensor_view(x_tensor,
113 make_tuple(kept_merge_transform, reduce_merge_transform),
114 make_tuple(kept_dim, reduce_dims),
118
119 // Calculate strides for output tensor based on its own dimensions
120 const auto kept_strides = [&]() {
121 return generate_tuple(
122 [&](auto I) {
123 // Calculate stride for dimension I as product of all following dimensions
124 index_t stride = 1;
125 static_for<I + 1, kept_dim.size(), 1>{}(
126 [&](auto J) { stride *= kept_lens.at(number<J>{}); });
127 return stride;
128 },
129 number<kept_dim.size()>{});
130 }();
131
132 // Calculate optimal vector size for output tensor
133 constexpr auto y_tensor_vector_size = CalculateOutputVectorSize();
134
136 p_y, kept_lens, kept_strides, number<y_tensor_vector_size>{}, number<1>{});
137
138 // Transform output tensor to 1D merged view
139 // This creates a view compatible with the 2D reduction pattern
140 const auto y_merged = transform_tensor_view(
141 y_m,
142 make_tuple(kept_merge_transform),
143 make_tuple(typename arithmetic_sequence_gen<0, kept_dim.size(), 1>::type{}),
145
// Input window starts at (iM, 0) and is advanced along N inside the loop below.
146 auto x_window = make_tile_window(transformed_x_tensor,
148 {iM, 0},
149 Policy::template MakeXBlockTileDistribution<Problem>());
150
// Output window covers Block_M entries of the merged kept axis starting at iM.
151 auto y_window = make_tile_window(y_merged, make_tuple(number<S::Block_M>{}), {iM});
152
// Shared-memory scratch for the cross-warp reduction step; its size comes from the policy.
// NOTE(review): declared as a plain char array - confirm the policy's cross-warp step does not
// need stricter alignment (e.g. if it reinterprets smem as ComputeDataType).
153 __shared__ char smem[Policy::template GetSmemSize<Problem>()];
154
155 // Get the merged dimension size from the transformed tensor
156 const auto merged_reduce_len =
157 transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{});
// ceil(merged_reduce_len / Block_N) tiles, read from the wave's first lane so every lane
// sees the same trip count.
158 index_t num_n_tile_iteration =
159 amd_wave_read_first_lane(integer_divide_ceil(merged_reduce_len, S::Block_N));
160
161 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
162 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
163 auto block_reduce2d_cross_warp_sync =
164 Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
165
// Per-thread accumulator, initialised to the reduction's identity value.
166 using XTensorType = decltype(load_tile(x_window));
167 auto y_compute = block_reduce2d.template MakeYBlockTile<XTensorType>();
168 set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());
169
// Stage 1: fold each Block_M x Block_N input tile into y_compute, moving along N.
170 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
171 {
172 const auto x = load_tile(x_window);
173 block_reduce2d(x, y_compute, reduce_func);
174 move_tile_window(x_window, {0, S::Block_N});
175 }
176
// Later stages: combine partial results with block_reduce2d_sync, then across warps via smem.
177 block_reduce2d_sync(y_compute, reduce_func);
178 block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);
179
// Convert to the output type and store into the output window.
180 store_tile(y_window, cast_tile<YDataType>(y_compute));
181 }
182
198 template <typename InputStrides>
199 CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
200 InputStrides input_strides)
201 {
202 using S = typename Problem::BlockShape;
203
204 if(y_continous_dim % S::ThreadTile_N != 0)
205 {
206 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
207 {
208 CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!");
209 }
210 return false;
211 }
212
213 if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
214 {
215 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
216 {
218 "Input tensor's last stride must be 1 to support correct vector access!");
219 }
220 return false;
221 }
222
223 return true;
224 }
225};
226
227} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T *__restrict__ p, BufferSizeType buffer_size)
Definition buffer_view.hpp:1262
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition reduce2d_kernel.hpp:20
ck_tile::remove_cvref_t< Problem_ > Problem
Definition reduce2d_kernel.hpp:21
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition reduce2d_kernel.hpp:26
ck_tile::remove_cvref_t< Policy_ > Policy
Definition reduce2d_kernel.hpp:22
static constexpr index_t kBlockSize
Definition reduce2d_kernel.hpp:28
static CK_TILE_HOST bool IsSupportedArgument(index_t y_continous_dim, InputStrides input_strides)
Validates if the given arguments are supported by the 2D reduction kernel.
Definition reduce2d_kernel.hpp:199
CK_TILE_DEVICE void operator()(const XDataType *p_x, YDataType *p_y, InputShape input_shape, InputStrides input_strides, KeptDim kept_dim, ReduceDims reduce_dims) const
Definition reduce2d_kernel.hpp:67
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition reduce2d_kernel.hpp:25
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition reduce2d_kernel.hpp:24
static CK_TILE_HOST constexpr auto BlockSize()
Definition reduce2d_kernel.hpp:29
Definition tile/core/container/sequence.hpp:287
Definition buffer_view.hpp:35
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145