|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ |
|
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ |
|
|
|
#include <cstdint> |
|
#include <vector> |
|
|
|
|
|
#include "sparse_matmul/compute/ar_inputs.h" |
|
#include "sparse_matmul/compute/gru_gates_arm.h" |
|
#include "sparse_matmul/compute/gru_gates_avx_fixed.h" |
|
#include "sparse_matmul/compute/gru_gates_generic.h" |
|
#include "sparse_matmul/compute/matmul.h" |
|
#include "sparse_matmul/numerics/fixed_types.h" |
|
#include "sparse_matmul/numerics/type_utils.h" |
|
#include "sparse_matmul/vector/cache_aligned_vector.h" |
|
|
|
|
|
namespace csrblocksparse { |
|
|
|
|
|
|
|
template <typename GRUStateType, typename InputType, typename SampleType = void> |
|
class GruGates : public MatmulBase { |
|
public: |
|
using SampleWeightType = float; |
|
static constexpr int kSIMDWidth = kGenericSIMDWidth; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs, |
|
bool kSplitGates = false> |
|
void GruWithARInput(int start, int end, int state_size, |
|
const InputType* gru_recurrent_ptr, |
|
const InputType* input_ptr, GRUStateType* gru_state_ptr, |
|
const SampleType* ar_sample0 = nullptr, |
|
const SampleType* ar_sample1 = nullptr, |
|
const SampleWeightType* ar_01_weights = nullptr, |
|
int num_replicas = 1, int replica_stride = 0, |
|
const SampleType* ar_sample2 = nullptr, |
|
const SampleWeightType* ar_2_weights = nullptr, |
|
const InputType* gru_recurrent_other_ptr = nullptr) { |
|
CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; |
|
GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType, |
|
kInputsMode, kSplitGates>( |
|
start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, |
|
input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0, |
|
ar_sample1, ar_sample2); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
void PlainGru(int start, int end, int state_size, |
|
const InputType* gru_recurrent_ptr, const InputType* input_ptr, |
|
GRUStateType* gru_state_ptr) { |
|
GruWithARInput<ARInputsMode::k0ARInputs>( |
|
start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr); |
|
} |
|
}; |
|
|
|
#if defined __ARM_NEON || defined __aarch64__ |
|
|
|
template <> |
|
class GruGates<float, float, float> : public MatmulBase { |
|
public: |
|
static constexpr int kSIMDWidth = kNeonSIMDWidth; |
|
|
|
|
|
|
|
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs, |
|
bool kSplitGates = false> |
|
void GruWithARInput(int start, int end, int state_size, |
|
const float* gru_recurrent_data, const float* input_data, |
|
float* gru_state_data, const float* ar_sample0 = nullptr, |
|
const float* ar_sample1 = nullptr, |
|
const float* ar_01_weights = nullptr, |
|
int num_replicas = 1, int replica_stride = 0, |
|
const float* ar_sample2 = nullptr, |
|
const float* ar_2_weights = nullptr, |
|
const float* gru_recurrent_other_data = nullptr) { |
|
DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; |
|
GoThroughGatesFloat<kInputsMode, kSplitGates>( |
|
start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, |
|
input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, |
|
ar_sample1, ar_sample2); |
|
} |
|
}; |
|
#endif |
|
|
|
|
|
|
|
template <int kGRUStateBits, int kInputBits, int kSampleBits> |
|
class GruGates<fixed16<kGRUStateBits>, fixed32<kInputBits>, |
|
fixed16<kSampleBits>> : public MatmulBase { |
|
public: |
|
#if defined __ARM_NEON || defined __aarch64__ |
|
static constexpr int kSIMDWidth = kNeonSIMDWidth; |
|
#elif defined __AVX2__ |
|
static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2; |
|
#else |
|
static constexpr int kSIMDWidth = kGenericSIMDWidth; |
|
#endif |
|
|
|
using GRUStateType = fixed16<kGRUStateBits>; |
|
using InputType = fixed32<kInputBits>; |
|
using SampleType = fixed16<kSampleBits>; |
|
using SampleWeightType = float; |
|
static constexpr int kInputMantissaBits = InputType::kMantissaBits; |
|
static constexpr int kSampleMantissaBits = SampleType::kMantissaBits; |
|
static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits; |
|
|
|
|
|
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs, |
|
bool kSplitGates = false> |
|
void GruWithARInput(int start, int end, int state_size, |
|
const InputType* gru_recurrent_data, |
|
const InputType* input_data, GRUStateType* gru_state_data, |
|
const SampleType* ar_sample0 = nullptr, |
|
const SampleType* ar_sample1 = nullptr, |
|
const SampleWeightType* ar_01_weights = nullptr, |
|
int num_replicas = 1, int replica_stride = 0, |
|
const SampleType* ar_sample2 = nullptr, |
|
const SampleWeightType* ar_2_weights = nullptr, |
|
const InputType* gru_recurrent_other_data = nullptr) { |
|
#if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__ |
|
const int32_t* gru_recurrent_ptr = |
|
reinterpret_cast<const int32_t*>(gru_recurrent_data); |
|
const int32_t* gru_recurrent_other_ptr = |
|
reinterpret_cast<const int32_t*>(gru_recurrent_other_data); |
|
const int32_t* input_ptr = reinterpret_cast<const int32_t*>(input_data); |
|
int16_t* gru_state_ptr = reinterpret_cast<int16_t*>(gru_state_data); |
|
#if defined __AVX2__ |
|
|
|
|
|
|
|
const float sample_factor = static_cast<float>(1 << kInputMantissaBits); |
|
#else |
|
const float sample_factor = 1.0f; |
|
#endif |
|
|
|
|
|
std::pair<float, float> ar_sample01; |
|
float ar_sample2_float = 0.0f; |
|
if (kInputsMode == ARInputsMode::k2ARInputs || |
|
kInputsMode == ARInputsMode::k3ARInputs) { |
|
ar_sample01 = {static_cast<float>(*ar_sample0) * sample_factor, |
|
static_cast<float>(*ar_sample1) * sample_factor}; |
|
if (kInputsMode == ARInputsMode::k3ARInputs) { |
|
ar_sample2_float = static_cast<float>(*ar_sample2) * sample_factor; |
|
} |
|
} |
|
#if defined __AVX2__ |
|
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; |
|
GruGatesAVXFixed<kInputMantissaBits, kStateMantissaBits, kInputsMode, |
|
kSplitGates>( |
|
start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01, |
|
ar_01_weights, num_replicas, replica_stride, &ar_sample2_float, |
|
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); |
|
#else |
|
DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; |
|
GoThroughGatesFixed<GRUStateType, InputType, kInputsMode, kSplitGates>( |
|
start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, |
|
input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01, |
|
&ar_sample2_float); |
|
#endif |
|
#else |
|
CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; |
|
GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType, |
|
kInputsMode, kSplitGates>( |
|
start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, |
|
input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, |
|
ar_sample1, ar_sample2); |
|
#endif |
|
} |
|
}; |
|
|
|
} |
|
|
|
#endif |
|
|