|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ |
|
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ |
|
|
|
#include <cstdint> |
|
|
|
#include "absl/memory/memory.h" |
|
#include "glog/logging.h" |
|
#include "sparse_matmul/layers/csr_blocksparse_matrix.h" |
|
#include "sparse_matmul/layers/masked_sparse_matrix.h" |
|
#include "sparse_matmul/numerics/type_utils.h" |
|
#include "sparse_matmul/os/coop_threads.h" |
|
#include "sparse_matmul/vector/cache_aligned_vector.h" |
|
|
|
namespace csrblocksparse { |
|
|
|
// A sparse linear layer y = W*x + b: a block-sparse weight matrix
// (CsrBlockSparseMatrix) paired with a cache-aligned bias vector. Supports
// multi-threaded execution and an optional "split" mode, in which each
// thread's row-slice of the matrix is further partitioned by columns into a
// |self| part and an |other| part, synchronized through a ProducerConsumer.
template <typename WeightType, typename RhsType,
          typename BiasType = typename TypeOfProduct<WeightType, RhsType>::type,
          typename DeltaType = int16_t>
class SparseLinearLayer {
 public:
  SparseLinearLayer() {}

  // Takes ownership of |sparse_matrix| and |bias|. |bias| must have exactly
  // one entry per matrix row (CHECK-enforced).
  SparseLinearLayer(CsrBlockSparseMatrix<WeightType, RhsType>&& sparse_matrix,
                    CacheAlignedVector<BiasType>&& bias)
      : sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) {
    CHECK_EQ(sparse_matrix_.rows(), full_bias_.size());
    // |bias_| is a copy of the bias pre-scaled by 1/4; |full_bias_| keeps the
    // unscaled values (used by the MatVec paths below).
    // NOTE(review): presumably the SpMM kernels accumulate the bias four
    // times, hence the pre-division -- confirm against the kernel code.
    bias_ = full_bias_;
    for (int i = 0; i < bias_.size(); ++i) {
      bias_[i] = static_cast<BiasType>(.25f * static_cast<float>(bias_[i]));
    }
  }
  SparseLinearLayer(
      const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
    *this = src;
  }
  // Copy-assignment. ProducerConsumer is not copied directly; when |src| has
  // one, a fresh instance with the same producer/consumer counts is created.
  SparseLinearLayer& operator=(
      const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
    sparse_matrix_ = src.sparse_matrix_;
    bias_ = src.bias_;
    full_bias_ = src.full_bias_;
    mid_output_ = src.mid_output_;
    thread_layers_ = src.thread_layers_;
    num_threads_ = src.num_threads_;
    if (src.split_pc_) {
      split_pc_ = absl::make_unique<ProducerConsumer>(
          src.split_pc_->num_producers(), src.split_pc_->num_consumers());
    }
    return *this;
  }

  // Multiplies |rhs| by the matrix, adds the (1/4-scaled) bias and applies an
  // optional relu, delegating to the matrix's SpMM kernel. |tid| and
  // |barrier| support multi-threaded operation.
  template <typename RhsClassType, typename OutType>
  void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false,
                 int tid = 0, SpinBarrier* barrier = nullptr) const {
    static_assert(
        std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
    sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier);
  }

  // As SpMM_bias, but additionally samples an index from the result using
  // |temperature|, |gen| and |scratch|, returning the sampled index.
  // NOTE(review): the sampling semantics live in
  // CsrBlockSparseMatrix::SpMM_bias_Sample -- see that implementation.
  template <typename RhsClassType, typename OutType>
  int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature,
                       int tid, SpinBarrier* barrier, std::minstd_rand* gen,
                       CacheAlignedVector<float>* scratch) const {
    static_assert(
        std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
    return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid,
                                           barrier, gen, scratch);
  }
  // Fused matrix*vector + bias (+ optional relu). Writes |replicas| copies of
  // the output, each |output_stride| elements apart (AVX2 fast path only; the
  // generic SpMM fallback requires replicas == 1, DCHECK-enforced below).
  template <typename RhsClassType, typename OutType>
  void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas,
              int output_stride, OutType* output,
              SpinBarrier* barrier = nullptr) {
    static_assert(
        std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
#ifdef __AVX2__
    // Fast path: the AVX2 MatVec kernel covers 4x4 and 8x4 blocks with
    // non-custom float weight types.
    if (block_width() == 4 && (block_height() == 4 || block_height() == 8) &&
        !IsCustomFloatType<WeightType>::value) {
      if (!IsSplit()) {
        // Uses the full (unscaled) bias here, unlike the SpMM path.
        sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu,
                              tid, replicas, output_stride, output->data());
        if (barrier != nullptr) barrier->barrier();
        return;
      }
      // Split mode: each thread multiplies its |self| columns into its slice
      // of |mid_output_| first, then after the ProducerConsumer hand-off
      // multiplies the |other| columns, using its |mid_output_| slice as the
      // bias so the two partial products accumulate into the final output.
      // NOTE(review): the exact produce/consume ordering contract lives in
      // ProducerConsumer (coop_threads.h) -- confirm semantics there.
      split_pc_->produce();
      PartLinearLayer& thread_part = thread_layers_[tid];
      auto offset_output =
          sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid);
      auto mid_output =
          sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid);
      auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput(
          mid_output_.cast_data(), tid);
      // First pass: |self| columns with the thread's full bias. Relu must be
      // deferred (passed false) until the complete sum is formed.
      thread_part.self_matrix.MatVec(
          rhs.cast_data(), thread_part.full_bias.cast_data(), false,
          0, 1, output_stride, mid_output);
      // Second pass: |other| columns, biased by the first pass's partial
      // result; relu (if requested) is applied here on the full sum.
      split_pc_->consume();
      thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu,
                                      0, replicas, output_stride,
                                      offset_output);
      return;
    }
#endif
    DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API";
    if (IsSplit()) {
      // Keep the producer/consumer protocol in lock-step even though this
      // generic path does not overlap the two column halves.
      split_pc_->produce();
      split_pc_->consume();
    }
    if (block_height() == 8) {
      // 8x4 blocks only have a (slow) generic MatVec outside of AVX2.
      LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!";
      sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid,
                            replicas, output_stride, output->data());
      if (barrier != nullptr) barrier->barrier();
    } else {
      sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier);
    }
  }

  int rows() const { return sparse_matrix_.rows(); }
  int cols() const { return sparse_matrix_.cols(); }
  float sparsity() const { return sparse_matrix_.sparsity(); }
  int block_width() const { return sparse_matrix_.block_width(); }
  int block_height() const { return sparse_matrix_.block_height(); }
  int num_threads() const { return sparse_matrix_.num_threads(); }
  // Returns the 1/4-scaled bias (see the constructor).
  const CacheAlignedVector<BiasType>& bias() const { return bias_; }
  const std::vector<int>& split_points() const {
    return sparse_matrix_.split_points();
  }
  // True when SliceForThreads and PrepareForThreads(>1) have both set up
  // split-mode state.
  bool IsSplit() const {
    return !thread_layers_.empty() && split_pc_ != nullptr;
  }

  std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); }
  // Debug dump of matrix and bias contents to stdout.
  void Print() const {
    printf("Matrix\n");
    sparse_matrix_.Print();
    printf("Bias\n");
    bias_.Print();
  }

  // Doubles the block height of the underlying matrix (4x4 -> 8x4).
  void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); }

  // Configures the layer (and underlying matrix) for |num_threads| threads,
  // creating the split-mode ProducerConsumer when more than one thread is
  // used. Returns the value from the matrix's own PrepareForThreads.
  int PrepareForThreads(int num_threads, int cache_line_size = -1) {
    num_threads_ = num_threads;
    if (num_threads_ > 1) {
      split_pc_ =
          absl::make_unique<ProducerConsumer>(num_threads_, num_threads_);
    } else {
      split_pc_.reset(nullptr);
    }
    return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size);
  }

  // Builds the per-thread PartLinearLayer partitions used by split-mode
  // MatVec. |split_points| is in units of blocks (scaled by block_height
  // below) and must have num_threads_ + 1 entries.
  void SliceForThreads(const std::vector<int>& split_points) {
    thread_layers_.clear();
    thread_layers_.reserve(num_threads_);
    LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for "
              << num_threads_ << " threads";
    for (int tid = 0; tid < num_threads_; ++tid) {
      thread_layers_.emplace_back(
          sparse_matrix_, full_bias_, bias_, tid,
          split_points[tid] * sparse_matrix_.block_height(),
          split_points[tid + 1] * sparse_matrix_.block_height());
    }
    mid_output_ =
        std::move(csrblocksparse::CacheAlignedVector<BiasType>(rows()));
    mid_output_.FillZero();
  }

  // Splits this layer by columns into two layers of half the input size.
  // |part1| receives the full bias; |part2| gets a zero bias, so summing the
  // two partial products reproduces the original result.
  void SplitInputs(
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
    CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
        sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2));
    CsrBlockSparseMatrix<WeightType, RhsType> matrix2(
        sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2,
                                     sparse_matrix_.cols()));
    *part1 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix1),
            std::move(CacheAlignedVector<BiasType>(full_bias_))));
    CacheAlignedVector<BiasType> bias2(sparse_matrix_.rows());
    bias2.FillZero();
    *part2 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix2), std::move(bias2)));
  }

  // Splits this layer by rows into two layers of half the output size, with
  // the bias split to match.
  void SplitOutputs(
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
    LOG(INFO) << "input rows=" << sparse_matrix_.rows()
              << ", cols=" << sparse_matrix_.cols();
    CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
        sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2));
    CsrBlockSparseMatrix<WeightType, RhsType> matrix2(sparse_matrix_.SplitByRow(
        sparse_matrix_.rows() / 2, sparse_matrix_.rows()));
    CacheAlignedVector<BiasType> bias1(full_bias_, 0, full_bias_.size() / 2);
    *part1 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix1), std::move(bias1)));
    CacheAlignedVector<BiasType> bias2(full_bias_, full_bias_.size() / 2,
                                       full_bias_.size());
    *part2 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix2), std::move(bias2)));
  }

 private:
  // One thread's share of the layer in split mode: its row range, cut into a
  // |self| column span [start_col, end_col) and the complementary |other|
  // columns, each with a matching slice of the bias.
  struct PartLinearLayer {
    // |matrix|: the full layer matrix; |bias|: unscaled bias; |bias_4|: the
    // 1/4-scaled bias; |tid| selects the row range from the matrix's
    // split_points; [start_col, end_col) is this thread's |self| column span
    // (in elements, already scaled by block height by the caller).
    PartLinearLayer(const CsrBlockSparseMatrix<WeightType, RhsType>& matrix,
                    const CacheAlignedVector<BiasType>& bias,
                    const CacheAlignedVector<BiasType>& bias_4, int tid,
                    int start_col, int end_col) {
      int block_height = matrix.block_height();
      // Row range owned by this thread, from the matrix's thread split
      // points (stored in block units, hence the scaling).
      int start_row = matrix.split_points()[tid] * block_height;
      int end_row = matrix.split_points()[tid + 1] * block_height;
      LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows ["
                << start_row << "," << end_row << ")";
      CsrBlockSparseMatrix<WeightType, RhsType> row_sub_matrix =
          matrix.SplitByRow(start_row, end_row);
      // |self_matrix| keeps only columns [start_col, end_col) of this
      // thread's rows (third argument keeps the original rhs indexing).
      self_matrix = std::move(row_sub_matrix.SplitByColumn(
          start_col, end_col, true));
      self_matrix.PrepareForThreads(1);
      // NOTE(review): the reversed (end_col, start_col) arguments presumably
      // select the complementary columns -- confirm in SplitByColumn.
      other_matrix = std::move(row_sub_matrix.SplitByColumn(
          end_col, start_col, true));
      other_matrix.PrepareForThreads(1);
      full_bias =
          std::move(CacheAlignedVector<BiasType>(bias, start_row, end_row));
      // Slice of the 1/4-scaled bias for this thread's rows.
      quarter_bias =
          std::move(CacheAlignedVector<BiasType>(bias_4, start_row, end_row));
    }

    CsrBlockSparseMatrix<WeightType, RhsType> self_matrix;
    CacheAlignedVector<BiasType> full_bias;
    CacheAlignedVector<BiasType> quarter_bias;
    // Columns outside [start_col, end_col) for the second split-mode pass.
    CsrBlockSparseMatrix<WeightType, RhsType> other_matrix;
  };
  CsrBlockSparseMatrix<WeightType, RhsType, DeltaType> sparse_matrix_;
  // Bias pre-scaled by 1/4 (see constructor); used by the SpMM paths.
  CacheAlignedVector<BiasType> bias_;
  // Unscaled bias; used by the MatVec paths and for re-splitting.
  CacheAlignedVector<BiasType> full_bias_;
  // Intermediate per-row results of the split-mode first pass.
  CacheAlignedVector<BiasType> mid_output_;
  // Per-thread partitions built by SliceForThreads (split mode only).
  std::vector<PartLinearLayer> thread_layers_;
  // Synchronizes the two split-mode passes across threads; created by
  // PrepareForThreads when num_threads_ > 1.
  std::unique_ptr<ProducerConsumer> split_pc_;
  int num_threads_ = 0;
};
|
|
|
template <typename WeightType, typename RhsType> |
|
SparseLinearLayer<WeightType, RhsType> CreateRandomLayer(int rows, int cols, |
|
float sparsity, |
|
int block_height = 1, |
|
int block_width = 1) { |
|
typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType; |
|
CacheAlignedVector<BiasType> bias(rows); |
|
bias.FillRandom(); |
|
|
|
auto masked_matrix = MaskedSparseMatrix<float>(rows, cols, sparsity, |
|
block_height, block_width); |
|
auto sparse_matrix = CsrBlockSparseMatrix<WeightType, RhsType>(masked_matrix); |
|
|
|
return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix), |
|
std::move(bias)); |
|
} |
|
|
|
template <typename WeightType, typename RhsType> |
|
SparseLinearLayer<WeightType, RhsType> CreateConstantLayer( |
|
int rows, int cols, float sparsity, float constant = 1.f) { |
|
typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType; |
|
CacheAlignedVector<BiasType> bias(rows); |
|
bias.FillOnes(); |
|
|
|
MaskedSparseMatrix<float> masked_matrix(rows, cols, sparsity, |
|
1, 1, |
|
constant, false); |
|
CsrBlockSparseMatrix<WeightType, RhsType> sparse_matrix(masked_matrix); |
|
|
|
return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix), |
|
std::move(bias)); |
|
} |
|
|
|
} |
|
|
|
#endif |
|
|