Don't understand NNUE



Gerd Isenberg

Re: Don't understand NNUE

Post by Gerd Isenberg »

lucasart wrote: Tue Aug 18, 2020 6:30 am Presumably this is because they use the smallest int possible (16 bits then 8 bits), to reduce the number of chunks of SIMD dot product operations. So one needs to carefully handle overflows.
Overflows are already handled by saturated arithmetic. So I guess ReLU has the advantages mentioned on Wikipedia.
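
For illustration, here is a minimal stand-alone snippet (my own, not from the Stockfish sources) showing that saturation: with the worst-case byte values 255 and 127, the pair sum 255*127 + 255*127 = 64770 exceeds the int16 range and is clamped to 32767 instead of wrapping around.

Code: Select all

// Minimal demo of the saturation in vpmaddubsw (illustration only, not Stockfish code).
// Build with AVX2 enabled, e.g. -mavx2.
#include <cstdint>
#include <cstdio>
#include <immintrin.h>

int main() {
  const __m256i a = _mm256_set1_epi8((char)0xFF); // unsigned inputs, all 255
  const __m256i b = _mm256_set1_epi8(127);        // signed weights, all +127
  const __m256i p = _mm256_maddubs_epi16(a, b);   // adjacent pair sums, saturated to int16
  std::printf("%d\n", (int)(std::int16_t)_mm256_extract_epi16(p, 0)); // prints 32767
  return 0;
}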

Considering AVX2 in AffineTransform only: they use _mm256_maddubs_epi16 (vpmaddubsw) for intermediate saturated 16-bit results, each the sum of two adjacent byte multiplications (unsigned input times signed weight), followed by _mm256_madd_epi16 (vpmaddwd) as a multiply by 1 with a horizontal add of consecutive 16-bit values to 32-bit ints, accumulated via _mm256_add_epi32 into sum. After a further horizontal add and some shuffling, the 32-bit sums plus biases are written as 32-bit ints to output.

Code: Select all

// Affine transformation layer
template <typename PreviousLayer, IndexType OutputDimensions>
class AffineTransform {
 public:
  using OutputType = std::int32_t;
  ...
  
  // AVX2 
  const OutputType* Propagate( const TransformedFeatureType* transformed_features, char* buffer) const
  {
    const auto input = previous_layer_.Propagate(transformed_features, buffer + kSelfBufferSize);
    const auto output = reinterpret_cast<OutputType*>(buffer);

    constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
    const __m256i kOnes = _mm256_set1_epi16(1);
    const auto input_vector = reinterpret_cast<const __m256i*>(input);

    for (IndexType i = 0; i < kOutputDimensions; ++i) {
      const IndexType offset = i * kPaddedInputDimensions;
      __m256i sum = _mm256_setzero_si256();
      const auto row = reinterpret_cast<const __m256i*>(&weights_[offset]);

      for (IndexType j = 0; j < kNumChunks; ++j) {
        __m256i product = _mm256_maddubs_epi16(_mm256_loadA_si256(&input_vector[j]), _mm256_load_si256(&row[j]));
        product = _mm256_madd_epi16(product, kOnes);
        sum = _mm256_add_epi32(sum, product);
      }
      __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
      sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
      sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
      output[i] = _mm_cvtsi128_si32(sum128) + biases_[i];
    }
    return output;
  }
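
To make the arithmetic explicit, here is a scalar model of one 32-bit lane of the vpmaddubsw/vpmaddwd pair (hypothetical helpers for illustration, not taken from the Stockfish sources): four consecutive unsigned-input times signed-weight byte products are reduced to a single 32-bit partial sum, with saturation only at the 16-bit intermediate stage.

Code: Select all

// Scalar sketch of the vpmaddubsw + vpmaddwd(kOnes) pair (illustration only).
#include <algorithm>
#include <cstdint>

// vpmaddubsw: two adjacent byte products (unsigned input * signed weight),
// summed and saturated to a signed 16-bit intermediate.
inline std::int16_t maddubs_pair(std::uint8_t a0, std::int8_t w0,
                                 std::uint8_t a1, std::int8_t w1) {
  const int p = (int)a0 * (int)w0 + (int)a1 * (int)w1;
  return (std::int16_t)std::clamp(p, -32768, 32767);
}

// vpmaddwd with a vector of ones: adds two adjacent 16-bit intermediates,
// widening to 32 bits (no overflow possible here).
inline std::int32_t madd_with_ones(std::int16_t h0, std::int16_t h1) {
  return (std::int32_t)h0 + (std::int32_t)h1;
}
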
In ClippedReLU, the 32-bit results are packed to 16-bit integers using signed saturation (_mm256_packs_epi32 aka vpackssdw) and arithmetically shifted right by kWeightScaleBits = 6 (a division by 64). Finally, _mm256_packs_epi16 (vpacksswb) packs the 16-bit words0 and words1 to 8-bit integers using signed saturation, before _mm256_max_epi8 implements the 8-bit ReLU by clamping negative values to zero. Since the AVX2 pack instructions operate within 128-bit lanes, _mm256_permutevar8x32_epi32 with kOffsets restores the element order afterwards.

Code: Select all

// Clipped ReLU
template <typename PreviousLayer>
class ClippedReLU {
 public:
  // Input/output type
  using InputType = typename PreviousLayer::OutputType;
  using OutputType = std::uint8_t;
  static_assert(std::is_same<InputType, std::int32_t>::value, "");
 
  ...
// AVX2
  const OutputType* Propagate( const TransformedFeatureType* transformed_features, char* buffer) const
  {
    const auto input = previous_layer_.Propagate( transformed_features, buffer + kSelfBufferSize);
    const auto output = reinterpret_cast<OutputType*>(buffer);

    constexpr IndexType kNumChunks = kInputDimensions / kSimdWidth;
    const __m256i kZero = _mm256_setzero_si256();
    const __m256i kOffsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    const auto in = reinterpret_cast<const __m256i*>(input);
    const auto out = reinterpret_cast<__m256i*>(output);
    for (IndexType i = 0; i < kNumChunks; ++i) {
      const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
          _mm256_loadA_si256(&in[i * 4 + 0]),
          _mm256_loadA_si256(&in[i * 4 + 1])), kWeightScaleBits);
      const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
          _mm256_loadA_si256(&in[i * 4 + 2]),
          _mm256_loadA_si256(&in[i * 4 + 3])), kWeightScaleBits);
      _mm256_storeA_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8(
          _mm256_packs_epi16(words0, words1), kZero), kOffsets));
    }
...
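
The per-element effect of that pack/shift/max sequence can be written as a scalar sketch (my own, assuming kWeightScaleBits = 6 as in the quoted code): saturate the 32-bit sum to 16 bits, shift arithmetically right by 6, then clamp to the [0, 127] range of a non-negative signed byte.

Code: Select all

// Scalar sketch of one ClippedReLU output element (illustration only,
// assuming kWeightScaleBits = 6 as in the quoted code).
#include <algorithm>
#include <cstdint>

inline std::uint8_t clipped_relu(std::int32_t x) {
  int v = std::clamp(x, -32768, 32767);       // vpackssdw: saturate to signed 16 bits
  v >>= 6;                                    // vpsraw: arithmetic shift by kWeightScaleBits
  return (std::uint8_t)std::clamp(v, 0, 127); // vpacksswb + vpmaxsb: clamp to [0, 127]
}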