Binarized Neural Networks (BNNs) are neural networks whose weights and activations are binary (+1 or −1) at run time. During training, the same binarized weights and activations are used to compute the parameter gradients, and in the forward pass they are used to compute the network's output. The power of BNNs comes from the fact that, in the forward pass, most arithmetic operations are replaced by bit-wise operations. This substantially improves energy efficiency and also reduces memory size and memory accesses. BNNs achieved nearly state-of-the-art results on the MNIST, CIFAR-10 and SVHN datasets.
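The following minimal Python sketch makes the bit-wise claim concrete. It assumes an encoding in which +1 is stored as bit 1 and −1 as bit 0. This encoding and the helper names `pack_bits` and `xnor_popcount_dot` are illustrative and do not come from the BNN paper. Under that encoding, a dot product of two ±1 vectors reduces to an XNOR followed by a popcount:

```python
def pack_bits(values):
    """Encode a list of +1/-1 values as an int: bit i is 1 for +1 and 0 for -1."""
    word = 0
    for i, v in enumerate(values):
        if v == 1:
            word |= 1 << i
    return word


def xnor_popcount_dot(a_bits, b_bits, n):
    """Dot product of two length-n {-1, +1} vectors computed from their packed bits."""
    mask = (1 << n) - 1
    matches = ~(a_bits ^ b_bits) & mask   # XNOR: bit is 1 where the signs agree
    agree = bin(matches).count("1")       # popcount
    # Each agreement contributes +1 and each disagreement -1.
    return 2 * agree - n


a = [+1, -1, -1, +1, +1, -1, +1, -1]
b = [+1, +1, -1, -1, +1, -1, -1, -1]
direct = sum(x * y for x, y in zip(a, b))
packed = xnor_popcount_dot(pack_bits(a), pack_bits(b), len(a))
print(direct, packed)  # 2 2
```

Both results equal 2 here. On real hardware, the packed words would be fixed-width machine integers and the popcount would be a single instruction. Python integers are used only for clarity.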
Converting Neural Network weights to binary (Binarization function)
Deterministic Binarization:
$$x^b = \mathrm{Sign}(x) = \begin{cases} -1 & \text{if } x < 0, \\ +1 & \text{if } x \ge 0, \end{cases}$$
where $x^b$ is the binarized variable (weight or activation) and $x$ is the real-valued variable.
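Here is a minimal Python sketch of the deterministic function. The name `sign_binarize` is hypothetical. It maps 0 to +1, as the definition above requires. A plain sign function such as `numpy.sign` would return 0 there instead.

```python
def sign_binarize(x):
    """Deterministic binarization: +1 if x >= 0, else -1 (so 0 maps to +1)."""
    return 1 if x >= 0 else -1


weights = [0.7, -0.2, 0.0, -1.3]
print([sign_binarize(w) for w in weights])  # [1, -1, 1, -1]
```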
Stochastic Binarization Function:
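$$x^b = \begin{cases} +1 & \text{with probability } p = \sigma(x), \\ -1 & \text{with probability } 1 - p, \end{cases}$$
where $\sigma$ is the "hard sigmoid" function:
$$\sigma(x) = \mathrm{clip}\left(\frac{x+1}{2}, 0, 1\right) = \max\left(0, \min\left(1, \frac{x+1}{2}\right)\right)$$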
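Below is a minimal sketch of stochastic binarization with the hard sigmoid. It uses Python's `random` module as the source of random bits. The helper names `hard_sigmoid` and `stochastic_binarize` are hypothetical.

```python
import random


def hard_sigmoid(x):
    """sigma(x) = clip((x + 1) / 2, 0, 1)."""
    return max(0.0, min(1.0, (x + 1) / 2))


def stochastic_binarize(x, rng=random):
    """Return +1 with probability hard_sigmoid(x), otherwise -1."""
    return 1 if rng.random() < hard_sigmoid(x) else -1


rng = random.Random(0)  # seeded only to make the run repeatable
samples = [stochastic_binarize(0.5, rng) for _ in range(10000)]
print(sum(1 for s in samples if s == 1) / len(samples))  # close to 0.75
```

With $x = 0.5$, $\sigma(x) = 0.75$, so roughly three quarters of the samples should be +1. Inputs with $x \ge 1$ always give +1, and inputs with $x \le -1$ always give −1.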
Stochastic binarization is harder to implement in hardware because it requires generating random bits when quantizing. Therefore, deterministic binarization is used in most cases, with some exceptions at train time depending on the dataset.
Gradient Computation and Accumulation
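Even though binary weights and activations are used during training, the gradients are still accumulated in high-precision floating-point variables. Without gradients in full- or half-precision floating point (FP), stochastic gradient descent (SGD) would not work at all. SGD moves through parameter space in small, noisy steps, and that noise is only averaged out when the individual gradient contributions are accumulated in each weight.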
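The toy example below is a sketch with a hypothetical constant gradient and learning rate. It illustrates why the high-precision accumulator matters. Small updates to a real-valued weight add up until its sign flips. A weight stored only in binary form absorbs and loses each update instead.

```python
def sign_binarize(x):
    """Deterministic binarization: +1 if x >= 0, else -1."""
    return 1 if x >= 0 else -1


lr = 0.01
grad = 1.0  # hypothetical constant gradient that keeps pushing the weight down

# Real-valued (float) weight: each small update is accumulated.
w_real = 0.045
history = []
for _ in range(8):
    w_real -= lr * grad
    history.append(sign_binarize(w_real))

# Weight stored only in binary form: each update is rounded away immediately.
w_binary_only = 1
for _ in range(8):
    w_binary_only = sign_binarize(w_binary_only - lr * grad)

print(history)        # [1, 1, 1, 1, -1, -1, -1, -1]
print(w_binary_only)  # 1
```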
Binarization also adds noise to the weights and activations when computing the parameter gradients. This noise acts as a regularizer that helps the network generalize better, much like a variant of Dropout: instead of randomly setting half of the activations to zero, the activations and weights are binarized. Also see: Variational Weight Noise, DropOut, DropConnect.
Propagating Gradients through Discretization
The derivative of the Sign function is zero almost everywhere, which makes it incompatible with backpropagation. BNNs therefore use a variant of the "straight-through estimator" that takes into account the saturation effect and uses deterministic rather than stochastic sampling of the bit. See Estimating or Propagating Gradients Through Stochastic Neurons.
Given the Sign function activation:
$$q = \mathrm{Sign}(r)$$
and assuming that an estimator $g_q$ of the gradient $\partial C / \partial q$ has been obtained (with the straight-through estimator when needed), the straight-through estimator of $\partial C / \partial r$ is simply:
$$g_r = g_q \, 1_{|r| \le 1}$$
This preserves the gradient information and cancels the gradient when $r$ is too large. Not cancelling it worsens performance. The derivative $1_{|r| \le 1}$ can also be seen as propagating the gradient through the hard tanh, which is the following piecewise linear activation function:
$$\mathrm{Htanh}(x) = \mathrm{Clip}(x, -1, 1) = \max(-1, \min(1, x))$$
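Here is a minimal Python sketch of the forward and backward pass of a single Sign unit under this estimator. The function names are hypothetical. It also prints the finite-difference slope of Htanh, to show that the estimator equals $g_q$ times that slope away from $|r| = 1$, where Htanh is not differentiable.

```python
def sign_forward(r):
    """Forward pass: q = Sign(r), with Sign(0) = +1."""
    return 1 if r >= 0 else -1


def sign_backward(r, g_q):
    """Straight-through estimator: g_r = g_q * 1_{|r| <= 1}."""
    return g_q if abs(r) <= 1 else 0.0


def htanh(x):
    """Htanh(x) = Clip(x, -1, 1)."""
    return max(-1.0, min(1.0, x))


def htanh_slope(x, eps=1e-6):
    """Central finite-difference estimate of dHtanh/dx."""
    return (htanh(x + eps) - htanh(x - eps)) / (2 * eps)


for r in [-2.0, -0.5, 0.3, 1.7]:
    # The STE gradient matches g_q times the slope of Htanh (away from |r| = 1).
    print(r, sign_forward(r), sign_backward(r, g_q=0.8), round(0.8 * htanh_slope(r), 6))
```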
For the hidden units, the Sign function non-linearity is used to obtain binary activations, and for weights the following is done:
- Constrain each real-valued weight $w^r$ to lie between −1 and +1 by projecting it to −1 or +1 whenever a weight update brings it outside $[-1, +1]$, i.e., by clipping the weights during training. Otherwise, the real-valued weights could grow very large without any effect on the binary weights.
- When using a weight $w^r$, we quantize it using $w^b = \mathrm{Sign}(w^r)$.
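Putting the two weight rules together, one update step might look like the sketch below. It loosely follows the structure of Algorithm 1 in the BNN paper. Plain SGD stands in for the optimizer, and the weights and gradients are made-up values.

```python
def sign_binarize(x):
    """w_b = Sign(w_r), with Sign(0) = +1."""
    return 1 if x >= 0 else -1


def clip(x, lo=-1.0, hi=1.0):
    """Project a value back into [lo, hi]."""
    return max(lo, min(hi, x))


def sgd_step_with_clipping(w_real, grads_wrt_wb, lr):
    """Update the real-valued weights with gradients computed using w_b, then clip to [-1, +1]."""
    return [clip(w - lr * g) for w, g in zip(w_real, grads_wrt_wb)]


w_real = [0.95, -0.40, 0.10]   # real-valued weights w_r (made-up values)
grads = [-2.0, 0.5, 0.3]       # hypothetical gradients dC/dw_b from the backward pass
w_real = sgd_step_with_clipping(w_real, grads, lr=0.1)
w_bin = [sign_binarize(w) for w in w_real]  # binary weights for the next forward pass
print(w_real, w_bin)
```

The update would take the first weight to 1.15, so clipping brings it back to 1.0. Its binary value stays +1 either way, which is why letting it grow further would serve no purpose.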
The complete training procedure is given in Algorithms 1 and 2 of the BNN paper.
In addition to the steps above, Shift-based Batch Normalization and Shift-based AdaMax are used instead of their vanilla variants to reduce the number of multiplications. See Algorithms 3 and 4, respectively, in the BNN paper.
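As we read the paper, these variants avoid most multiplications by replacing a multiplier with its approximate power of two, $\mathrm{AP2}(x) = \mathrm{sign}(x)\,2^{\mathrm{round}(\log_2 |x|)}$. A multiplication by a power of two can then be carried out as a binary shift. The sketch below assumes that definition and integer operands. The helper names `ap2` and `shift_multiply` are hypothetical. It illustrates only the shift trick and does not reproduce Algorithm 3 or 4.

```python
import math


def ap2(x):
    """Approximate power of two (assumed definition): sign(x) * 2**round(log2|x|), x != 0."""
    return math.copysign(2.0 ** round(math.log2(abs(x))), x)


def shift_multiply(value, x):
    """Approximate value * x for an integer value, using a bit shift by round(log2|x|)."""
    k = round(math.log2(abs(x)))
    shifted = value << k if k >= 0 else value >> -k   # left shift multiplies, right shift divides
    return -shifted if x < 0 else shifted


print(ap2(0.3), shift_multiply(12, 0.3))   # 0.25 3
print(ap2(-6.0), shift_multiply(7, -6.0))  # -8.0 -56
```

Right shifts round down, so `shift_multiply(13, 0.3)` gives 3 rather than 3.25. In fixed-point hardware, this rounding is part of the approximation.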
In a BNN, all weights and activations are binary, so the inputs of every layer are binary as well, except for the first layer, whose input is the raw data (e.g., image pixels). The first layer's input is therefore represented in 8-bit fixed point, and its weighted sum is computed as:
$$s = x \cdot w^b$$
$$s = \sum_{n=1}^{8} 2^{n-1} \left(x^n \cdot w^b\right)$$
where $x$ is a vector of 1024 8-bit inputs, $x^n$ is the vector formed by the $n$-th bit of every input, $x^8_1$ is the most significant bit of the first input, $w^b$ is a vector of 1024 1-bit weights, and $s$ is the resulting weighted sum. See Algorithm 5 in the BNN paper.
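A small Python sketch can check the bit-plane decomposition, assuming unsigned 8-bit inputs. The paper's 1024-element vectors are shortened to four here, and the name `first_layer_sum` is hypothetical. Each inner product $x^n \cdot w^b$ combines 0/1 bits with ±1 weights, so it needs only additions and subtractions.

```python
def first_layer_sum(x, w_b, bits=8):
    """s = sum_{n=1}^{bits} 2^(n-1) * (x^n . w_b), where x^n holds bit n of every input."""
    s = 0
    for n in range(1, bits + 1):
        x_n = [(xi >> (n - 1)) & 1 for xi in x]          # bit-plane n (n = 8 is the MSB)
        s += 2 ** (n - 1) * sum(b * w for b, w in zip(x_n, w_b))
    return s


x = [200, 3, 17, 255]   # hypothetical unsigned 8-bit inputs, e.g. pixel intensities
w_b = [1, -1, -1, 1]    # binary weights
print(first_layer_sum(x, w_b), sum(xi * wi for xi, wi in zip(x, w_b)))  # 435 435
```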
The main contribution of this work is that it successfully binarizes both weights and activations, in the inference phase as well as the training phase of a deep neural network. A good discussion of previously implemented binary neural networks is presented in Section 5 of the BNN paper.