Writing automated tests for neural networks

Introduction

Meet Eddie. He has spent the last couple of days working on a new neural network architecture for his classification problem, while simultaneously modifying the training pipeline to support his new dataset and loss function. Eddie starts the training process with great expectations, but soon notices that the model isn’t learning anything – the loss doesn’t decrease at all. He tries tweaking the hyperparameters but that doesn’t help. He then starts digging into the code to see what’s wrong, but isn’t able to find any bugs. Eventually, after a few days of testing, Eddie decides to rewrite all the parts one by one, and finally manages to get the model working without ever knowing exactly what the original issue was.

What can be done to avoid this scenario?

Having targeted, automated tests for your neural network and training pipeline is very helpful for tracking down issues like this. A good principle is to start with the simplest possible model and training pipeline that does something, and verify with tests that it works at least to some degree. Then you can modify one part at a time and rerun the automated tests after each step to make sure that nothing breaks.

Unfortunately, writing automated tests for neural network architectures can be pretty tricky. It’s certainly more complicated than writing unit tests for regular functions: because of the complexity of neural networks, you usually don’t know in detail what the output should be for a given input. An approach I’ve found useful is to focus on integration and property-based tests instead of traditional unit tests. Below are a few examples of such tests that I’ve used.

General notes regarding all tests

Even though these tests are not strictly unit tests, it still makes sense to use a unit test framework to write them. This makes the tests easy to run repeatedly. For Python, you can use, for instance, unittest or pytest.

For all of these tests you can use a much smaller dataset than your actual one (usually a few tens of samples is enough). The tests do not aim to verify the model trained on the full dataset, since running such trainings may take hours or days. Instead, the goal is to verify that the model at least has the potential to work when trained on the actual full dataset, i.e. that there’s no obvious error in the model or the training pipeline.

It makes sense to fix the random seed for each test to keep the results the same between runs. Nothing is more annoying than a flaky test.
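
With pytest, for instance, you can fix the seeds in an autouse fixture so that every test starts from the same random state. This is just a minimal sketch assuming PyTorch and NumPy; adapt the seeding calls to whatever framework you actually use.

    import random

    import numpy as np
    import pytest
    import torch


    @pytest.fixture(autouse=True)
    def fix_seeds():
        # Fix all relevant random seeds before each test so results stay reproducible.
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)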

Now, let’s move on to the examples of different test cases.

The network overfits to a small dataset

Select a small subset of your training dataset, small here meaning just a few samples. If you’re doing classification, make sure you select samples from at least two different classes. Then create a test that trains your neural network until it fully overfits this subset (i.e. the training loss is zero or near zero, and classification accuracy is 100%). You can, for instance, pass the entire subset as one minibatch to the neural network, and repeat this until convergence. To speed up the test run you can often use a much smaller version of the network than you would normally use.

A network being able to overfit a tiny dataset indicates that data passes through the network in a sensible way, that the loss signal propagates backwards, and that the weight updates actually improve the network.
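
As a sketch of what such a test can look like, here is a pytest-style example assuming PyTorch. The tiny model, the synthetic two-class data, the number of steps and the loss/accuracy thresholds are all placeholders; substitute a scaled-down version of your own model and a handful of real samples, and tune the thresholds to what convergence looks like for you.

    import torch
    from torch import nn


    def test_overfits_tiny_dataset():
        torch.manual_seed(0)

        # A handful of synthetic samples from two classes (placeholder data).
        x = torch.randn(8, 16)
        y = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])

        # A deliberately small stand-in for the real model.
        model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        loss_fn = nn.CrossEntropyLoss()

        # Pass the whole subset as a single batch, repeatedly, until it overfits.
        for _ in range(500):
            optimizer.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()

        accuracy = (logits.argmax(dim=1) == y).float().mean().item()
        assert loss.item() < 1e-2
        assert accuracy == 1.0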

Issues that this test may catch
  • Data points being shuffled during the forward pass. This may happen e.g. if the batch creation is invalid, or if some data grouping or result gathering is done incorrectly.
  • Data being mangled during the forward pass. One example is if you use a masking function incorrectly.
  • The network not being expressive enough to solve the task at hand. This can happen for instance if you forget to add non-linearities and stack just linear layers on top of each other. Then the entire network is just a linear function, and it cannot overfit if the target function or decision boundary is non-linear.
  • Invalid weight initialization. If you for some reason need to write your own weight initialization you might end up with a network that’s not able to learn at all.

Gradients are non-zero and loss decreases

In some cases it might not be feasible to overfit to just a few samples. You may then instead (or in addition) run training for a few batches and verify that the gradients are non-zero and that the loss decreases during training.

The gradients being non-zero and the loss decreasing show that training works at least on some level: the data passes through the forward pass, and at least some loss signal propagates backwards.
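
A sketch of such a test, again assuming PyTorch; the toy model, synthetic data and hyperparameters are placeholders for your own.

    import torch
    from torch import nn


    def test_gradients_nonzero_and_loss_decreases():
        torch.manual_seed(0)

        x = torch.randn(32, 16)
        y = torch.randint(0, 2, (32,))

        model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        loss_fn = nn.CrossEntropyLoss()

        losses = []
        for _ in range(20):
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()

            # Every parameter should receive some gradient signal.
            for name, param in model.named_parameters():
                assert param.grad is not None, f"no gradient for {name}"
                assert param.grad.abs().sum().item() > 0, f"zero gradient for {name}"

            optimizer.step()
            losses.append(loss.item())

        assert losses[-1] < losses[0]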

Issues that this test may catch
  • Similar issues to the overfitting test, e.g. regarding invalid batching or data handling during the forward pass. It might not find issues relating to network expressivity.

All network layers change after each optimizer step

Run training for your neural network for a few batches (you can again use a small subset of your actual training data). After each batch, assert that the weights of all layers have changed. This indicates that the entire network is actually being used during training. Depending on your data and hyperparameters, you may sometimes have to run a few batches between the asserts.
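
A sketch of this test, assuming PyTorch; the snapshot-and-compare logic is the part to reuse, while the toy model and data are placeholders.

    import torch
    from torch import nn


    def test_all_layers_change_after_optimizer_step():
        torch.manual_seed(0)

        x = torch.randn(16, 16)
        y = torch.randint(0, 2, (16,))

        model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        loss_fn = nn.CrossEntropyLoss()

        for _ in range(3):  # run a few batches; add more between asserts if updates are tiny
            # Snapshot every parameter tensor before the update.
            before = {name: param.detach().clone() for name, param in model.named_parameters()}

            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()

            # Every parameter tensor should have moved at least a little.
            for name, param in model.named_parameters():
                assert not torch.equal(before[name], param), f"{name} did not change"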

Issues that this test may catch
  • Some layer not being used at all, or the layer not having gradients enabled. This can happen for instance if you add a new layer but forget to actually pass any data through it during the forward pass.
  • Issues where all the weight updates go to only one or a few layers. This might be due to the network architecture or due to poorly chosen hyperparameters.

Order of samples in batch does not affect output

Take a minibatch of samples (at least two, preferably more). Pass the samples as one batch to the neural network in inference mode (i.e. don’t run any training). Then pass each sample separately (as a one-sample minibatch) to the neural network. Verify that the output for each sample is the same regardless of whether it was calculated as part of a minibatch or by itself.

Additionally, you can shuffle the order of the samples in the minibatch and redo the forward pass. Again, the output should be the same for each individual sample.
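
A sketch of this test, assuming PyTorch; the toy model and random inputs again stand in for your own model and a few real samples. The tolerance accounts for small floating-point differences between batched and single-sample computation.

    import torch
    from torch import nn


    def test_sample_outputs_independent_of_batching():
        torch.manual_seed(0)

        model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
        model.eval()  # inference mode: no dropout, no batch norm statistics updates

        x = torch.randn(4, 16)

        with torch.no_grad():
            batched = model(x)
            # Each sample passed on its own as a batch of one.
            single = torch.cat([model(x[i:i + 1]) for i in range(x.shape[0])])
            # The same batch in a shuffled order.
            perm = torch.randperm(x.shape[0])
            shuffled = model(x[perm])

        assert torch.allclose(batched, single, atol=1e-6)
        assert torch.allclose(batched[perm], shuffled, atol=1e-6)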

Issues that this test may catch
  • Inconsistencies in how the model’s forward pass handles batched vs. individual samples.
  • Non-determinism in batch handling.

Conclusion

Don’t be like Eddie. Spending days or even weeks debugging an issue in your model and/or training pipeline is not particularly enjoyable. Investing a few hours in writing some basic automated tests will significantly reduce the risk of having to do that.