host_gemm.hpp Source File

host_gemm.hpp Source File

Composable Kernel: host_gemm.hpp Source File
host_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "host_tensor.hpp"
7
8template <typename AType,
9 typename BType,
10 typename CType,
11 typename AElementwiseOperation,
12 typename BElementwiseOperation,
13 typename CElementwiseOperation>
15 const Tensor<BType>& b_k_n,
16 Tensor<CType>& c_m_n,
17 const AElementwiseOperation& a_element_op,
18 const BElementwiseOperation& b_element_op,
19 const CElementwiseOperation& c_element_op)
20{
21 auto f_mk_kn_mn = [&](auto m, auto n) {
22 const int K = a_m_k.mDesc.GetLengths()[1];
23
24 float v_acc = 0;
25
26 for(int k = 0; k < K; ++k)
27 {
28 float v_a;
29 float v_b;
30
31 a_element_op(v_a, static_cast<const float>(a_m_k(m, k)));
32 b_element_op(v_b, static_cast<const float>(b_k_n(k, n)));
33
34 v_acc += v_a * v_b;
35 }
36
37 float v_c;
38
39 c_element_op(v_c, v_acc);
40
41 c_m_n(m, n) = v_c;
42 };
43
45 c_m_n.mDesc.GetLengths()[0],
46 c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
47}
void host_gemm_mk_kn_mn(const Tensor< AType > &a_m_k, const Tensor< BType > &b_k_n, Tensor< CType > &c_m_n, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition host_gemm.hpp:14
auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition library/utility/host_tensor.hpp:687
const std::vector< std::size_t > & GetLengths() const
Tensor wrapper that performs static and dynamic buffer logic. The tensor is based on a descriptor sto...
Definition library/utility/host_tensor.hpp:694
Descriptor mDesc
Definition library/utility/host_tensor.hpp:1159