DeviceSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings > Struct Template Reference
Classes |
Public Types |
Public Member Functions |
Static Public Member Functions |
List of all members
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings > Struct Template Reference
#include <device_sparse_embeddings_forward_layernorm.hpp>
Inheritance diagram for ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >:
Classes | |
| struct | Argument |
| struct | Invoker |
Public Types | |
| using | GridwiseSparseEmbedding |
Public Member Functions | |
| std::unique_ptr< BaseArgument > | MakeArgumentPointer (void *p_out, const ck::Array< EmbType *, NumEmbeddings > &p_embs, const ck::Array< IndexType *, NumEmbeddings > &p_indexs, const void *p_gamma, const void *p_beta, ck::index_t EmbeddingDim, ck::index_t IndexLength, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op) |
| bool | IsSupportedArgument (const BaseArgument *p_arg) override |
| virtual std::unique_ptr< BaseInvoker > | MakeInvokerPointer () |
| std::string | GetTypeString () const override |
| Public Member Functions inherited from ck::tensor_operation::device::BaseOperator | |
| BaseOperator ()=default | |
| BaseOperator (const BaseOperator &)=default | |
| BaseOperator & | operator= (const BaseOperator &)=default |
| virtual std::string | GetInstanceString () const |
| virtual std::string | GetTypeIdName () const |
| virtual std::optional< std::string > | GetObjectName () const |
| virtual std::optional< std::string > | GetTemplateInfo () const |
| virtual std::string | GetTypeIdHashCode () const |
| virtual size_t | GetWorkSpaceSize (const BaseArgument *) const |
| virtual void | SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const |
| virtual | ~BaseOperator () |
Static Public Member Functions | |
| static auto | MakeOutputDescriptor (const index_t index_length, const index_t rows) |
| static bool | IsSupportedArgument (const Argument *p_arg) |
Member Typedef Documentation
◆ GridwiseSparseEmbedding
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
| using ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >::GridwiseSparseEmbedding |
Initial value:
= GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(MakeOutputDescriptor(1, 1)),
EmbElementwiseOperation,
BlockSize,
DimClusterSize,
RowClusterSize,
DimPerBlock,
RowPerBlock,
DimThreadSize,
RowVectorSize,
NumEmbeddings>
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:57
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
Definition device_sparse_embeddings_forward_layernorm.hpp:42
Member Function Documentation
◆ GetTypeString()
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
|
inline override virtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsSupportedArgument() [1/2]
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
|
inline static |
◆ IsSupportedArgument() [2/2]
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
|
inline override virtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ MakeArgumentPointer()
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
|
inline |
◆ MakeInvokerPointer()
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
|
inline virtual |
◆ MakeOutputDescriptor()
template<typename EmbType, typename IndexType, typename GammaDataType, typename BetaDataType, typename AccDataType, typename OutType, typename EmbElementwiseOperation, ck::index_t BlockSize, ck::index_t DimClusterSize, ck::index_t RowClusterSize, ck::index_t DimPerBlock, ck::index_t RowPerBlock, ck::index_t DimThreadSize, ck::index_t RowVectorSize, ck::index_t NumEmbeddings>
|
inline static |
The documentation for this struct was generated from the following file: device_sparse_embeddings_forward_layernorm.hpp