Network quantization

Discussion of chess software programming and technical issues.

Moderator: Ras

alvinypeng
Posts: 36
Joined: Thu Mar 03, 2022 7:29 am
Full name: Alvin Peng

Network quantization

Post by alvinypeng »

I've been experimenting with different ways to do NNUEs. Recently, I trained a net with P features (color, piece type, square) using only one perspective.

The net is input->hidden->relu->output, which is about as simple as a network can get. However, for this net in particular, the outputs of the quantized version are sometimes extremely different from the outputs of the unquantized version.
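For reference, the float model is just this (a minimal sketch; the feature count and hidden size here are placeholders, not the real ones):

Code: Select all

import torch
import torch.nn as nn

# Placeholder sizes for illustration only.
NUM_FEATURES = 768  # e.g. 2 colors * 6 piece types * 64 squares
HIDDEN_SIZE = 256

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(NUM_FEATURES, HIDDEN_SIZE)
        self.output = nn.Linear(HIDDEN_SIZE, 1)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))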

The pseudocode for my quantization process is something like this:

Code: Select all

import torch

# Hypothetical values for illustration; the real scales and sizes
# depend on the engine.  The weight/bias tensors below are the
# float32 parameters taken from the trained net.
HIDDEN_QUANTIZE_SCALE = 1 << 6
OUTPUT_QUANTIZE_SCALE = 1 << 4
HIDDEN_SIZE = 256
MAX_INT16 = 2**15 - 1
MAX_INT32 = 2**31 - 1

# Hidden layer: weights and biases share one scale, so the int16
# accumulator holds HIDDEN_QUANTIZE_SCALE times the float value.
hidden_weights = torch.round(HIDDEN_QUANTIZE_SCALE * hidden_weights)
hidden_biases = torch.round(HIDDEN_QUANTIZE_SCALE * hidden_biases)

# Worst case for the int16 accumulator: every active feature
# contributes the largest weight magnitude, plus the largest bias.
MAX_FEATURES = 32
max_hidden_weight = torch.max(torch.abs(hidden_weights)).item()
max_hidden_bias = torch.max(torch.abs(hidden_biases)).item()
max_hidden_output = MAX_FEATURES * max_hidden_weight + max_hidden_bias
assert max_hidden_weight < MAX_INT16
assert max_hidden_bias < MAX_INT16
assert max_hidden_output < MAX_INT16

# Output layer: the bias is added to products of hidden activations
# (already carrying HIDDEN_QUANTIZE_SCALE) and output weights, so it
# must carry the product of both scales.
output_weights = torch.round(OUTPUT_QUANTIZE_SCALE * output_weights)
output_biases = torch.round(HIDDEN_QUANTIZE_SCALE * OUTPUT_QUANTIZE_SCALE * output_biases)

# Worst case for the int32 accumulator in the output layer.
max_output_weight = torch.max(torch.abs(output_weights)).item()
max_output_bias = torch.max(torch.abs(output_biases)).item()
max_output = HIDDEN_SIZE * max_hidden_output * max_output_weight + max_output_bias
assert max_output_weight < MAX_INT16
assert max_output_bias < MAX_INT32
assert max_output < MAX_INT32
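For reference, the quantized inference is meant to work roughly like this (a simplified sketch, ignoring SIMD and incremental accumulator updates; it assumes hidden_weights is stored feature-major, shape (num_features, hidden_size), and any integer overflow here wraps silently):

Code: Select all

import torch

def forward_quantized(active_features, hidden_weights, hidden_biases,
                      output_weights, output_biases):
    # int16 accumulator: start from the biases and add the weight rows
    # of the active features.  Overflow wraps around.
    acc = hidden_biases.to(torch.int16)
    for f in active_features:
        acc = acc + hidden_weights[f].to(torch.int16)
    hidden = torch.clamp(acc, min=0).to(torch.int32)  # ReLU
    # int32 dot product with the output weights, plus the output bias.
    out = torch.sum(hidden * output_weights.to(torch.int32),
                    dtype=torch.int32) + output_biases.to(torch.int32)
    # Divide out both scales to get a float-comparable evaluation.
    return out.item() / (HIDDEN_QUANTIZE_SCALE * OUTPUT_QUANTIZE_SCALE)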
Am I missing something? Is there a better way to do network quantization?
Joost Buijs
Posts: 1632
Joined: Thu Jul 16, 2009 10:47 am
Location: Almere, The Netherlands

Re: Network quantization

Post by Joost Buijs »

I also use 16-bit quantization in my NNUE networks; usually I see only very small differences in output between the float32 and the int16 network.
The quantization scales I use lie between (1 << 9) and (1 << 12); when I make them larger, overflows start to appear, and when I make them smaller, the rounding error becomes too large.
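A quick sanity check is to sweep the scale and keep the largest one whose worst-case accumulator still fits in int16, along the lines of your asserts (a sketch; 32 active features assumed):

Code: Select all

import torch

def max_safe_hidden_scale(hidden_weights, hidden_biases, max_features=32):
    # Try power-of-two scales from large to small and return the first
    # one whose worst-case int16 accumulator cannot overflow.
    for shift in range(12, 5, -1):
        scale = 1 << shift
        w = torch.round(scale * hidden_weights).abs().max().item()
        b = torch.round(scale * hidden_biases).abs().max().item()
        if max_features * w + b < 2**15:
            return scale
    return None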

When you see very large differences between the output of the quantized and the non-quantized network, it is probably an overflow.
I assume the tensors in your quantization pseudocode are all 32-bit float; that looks OK, so maybe the code you use for inference is the culprit.
When you use saturated arithmetic in your inference code, an occasional overflow is not a big problem, but unfortunately this is not always possible.
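To illustrate, the saturating SIMD adds (e.g. _mm256_adds_epi16 on AVX2) clamp at the int16 limits instead of wrapping; in Python you can emulate the behavior like this (a sketch, not engine code):

Code: Select all

import torch

INT16_MIN, INT16_MAX = -2**15, 2**15 - 1

def adds_epi16(a, b):
    # Saturating int16 add: compute in a wider type, then clamp to the
    # int16 range, the way the SIMD instruction behaves.
    s = a.to(torch.int32) + b.to(torch.int32)
    return torch.clamp(s, INT16_MIN, INT16_MAX).to(torch.int16)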