How to Reduce Your Neural Network Size by ~32x and Increase Speed ~58x
Disclaimer: this post is mostly for people who already have a basic understanding of neural networks.
What the heck are binary neural networks?
Binary neural networks are a way of quantizing (reducing the number of bits needed to represent) the values used in a neural network's computations. Normally, we would compute everything - the forward pass, back propagation, gradients - using 32-bit floating point values. Thirty-two zeros and ones can represent about 4.3e9 (2^32) different values, which is far more precision than is usually needed.
A lot of smart people kept pushing networks down to lower precision and got to 8-bit numbers without accuracy loss (Dally), which allows a 4x decrease in size and computation cost. Going even lower, to 1-bit representations, lets you replace computationally expensive multiply-accumulate operations with cheap, quick bitwise operations, leading to up to a 58x increase in forward propagation speed and a 32x decrease in memory usage compared to a full-precision representation (Rastegari).
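To make the bitwise-operation point concrete, here's a toy sketch (my own illustration, not code from the cited papers) of how a dot product between two {-1, +1} vectors reduces to an XOR plus a popcount once the values are packed into machine words:

```python
import numpy as np

# Mapping +1 -> bit 1 and -1 -> bit 0 gives:
#   dot(a, b) = n - 2 * popcount(a_bits XOR b_bits)
# because matching positions contribute +1 and differing positions contribute -1.

def binary_dot(a_bits: int, b_bits: int, n: int) -> int:
    differing = bin(a_bits ^ b_bits).count("1")  # positions where a and b disagree
    return n - 2 * differing

n = 8
a = np.random.choice([-1, 1], n)
b = np.random.choice([-1, 1], n)

# pack the +1/-1 vectors into plain integers (bit i set when the value is +1)
a_bits = sum(1 << i for i, v in enumerate(a) if v == 1)
b_bits = sum(1 << i for i, v in enumerate(b) if v == 1)

print(binary_dot(a_bits, b_bits, n), int(a @ b))  # both values match
```

Real implementations pack 32 or 64 weights into each machine word, which is how a single XNOR-plus-popcount pair can stand in for dozens of multiply-accumulates.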
Why use BNNs?
Cell phones. Self-driving cars. Augmented reality glasses. Basically, any electronics that aren't tethered to a wall plug can't properly join in on the deep learning hype, because running a network eats up so much RAM (which you could be using for video games) and battery life (which could also be used for video games). For reference, Apple actually limits app downloads over cellular to 100 MB, which means a full-precision model alone can use up your entire allotted space before you've even built the rest of the app.
We can't really stick a GPU on an iPhone either, so what most apps do is send your data to a server for computation and then send the result back to your phone (yes, that is as slow and as expensive to maintain as it sounds). So who wouldn't appreciate a massive computational speedup (except Google, who's trying to sell you their cloud platform)?
Binary neural networks have also been shown to withstand adversarial attacks surprisingly well. This is believed to be because the noise injected by the quantization process acts as a form of regularization.
Wait, what about back propagation?
Yeah, you can't really back propagate through just 0's and 1's: the binarization (sign) function has a gradient of zero almost everywhere. So we use a common guesstimate to get around this, called the "straight-through estimator" (STE) (Hinton, Bengio), which simply passes the gradient through the binarization function as if it weren't there. (Apparently just passing the gradient through worked better than using the actual derivative even for the sigmoid function.)
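As a rough illustration, here's a minimal PyTorch sketch of an STE-style binarization op (my own example, assuming the common variant that zeroes the gradient outside [-1, 1]):

```python
import torch

class BinarizeSTE(torch.autograd.Function):
    """Sign binarization with a straight-through gradient estimator."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)  # forward: hard -1/0/+1 values (sign(0) left as-is for this sketch)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # STE: pass the gradient straight through, but clip it outside [-1, 1]
        # so the real-valued weights kept during training don't drift arbitrarily far.
        return grad_output * (x.abs() <= 1).float()

# usage: binarize activations (or weights) inside a layer
x = torch.randn(4, requires_grad=True)
y = BinarizeSTE.apply(x)
y.sum().backward()
print(x.grad)
```

The trick is that the hard sign is treated as (clipped) identity in the backward pass, so useful gradient signal still reaches the full-precision weights that are updated during training.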
Ok so what’s the catch?
It takes longer to train. Oftentimes, a BNN will learn from an already-trained full precision network, the network structure has to be modified a bit, and more parameters might be added, all of which leads to longer lead times.
It's not always that accurate. However, there's a lot of research addressing this. My friend Marianne and I recently implemented a paper (Lin) and built a BNN that uses multiple binary "bases" (which allow more information through, so the binary weights better approximate the true values) to approach the accuracy of a full precision network while still using significantly less memory and the bitwise operations to speed up calculations. The original paper got only a 4.3% top-1 accuracy drop with ResNet18 on ImageNet while still achieving a 6.4x memory reduction.
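To give a flavor of the multiple-bases idea, here's a simplified greedy sketch (my own illustration; the paper solves for the scales more carefully rather than fitting one base at a time): a real-valued weight tensor is approximated as a weighted sum of {-1, +1} tensors, and each extra base soaks up more of the remaining error.

```python
import torch

def binary_bases(w, num_bases=3):
    """Approximate w as sum_i alpha_i * B_i with each B_i in {-1, +1},
    by greedily fitting each base to the current residual."""
    residual = w.clone()
    bases, alphas = [], []
    for _ in range(num_bases):
        b = torch.sign(residual)
        b[b == 0] = 1                    # treat sign(0) as +1
        alpha = (residual * b).mean()    # least-squares optimal scale for this base
        bases.append(b)
        alphas.append(alpha)
        residual = residual - alpha * b  # fit the next base to what's left
    approx = sum(a * b for a, b in zip(alphas, bases))
    return approx, bases, alphas

w = torch.randn(64, 3, 3, 3)             # e.g. a conv filter bank
approx, _, _ = binary_bases(w, num_bases=3)
print((w - approx).abs().mean())          # error shrinks as num_bases grows
```

Each base is still just bits (plus one scalar), so the convolutions can still be done with bitwise operations, but stacking a few of them recovers much of the lost precision.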
Check out our implementation on GitHub, the paper write-up, and my review of the field of quantizing neural networks.
Sources
William Dally. High-Performance Hardware for Machine Learning.