When building PyTorch models for custom applications from scratch, there's usually one recurring problem: the model does not learn anything. In a complex project, it can be tricky to identify the cause: Is it the data? A bug in the model? The wrong loss function, chosen at 3 a.m. after an eight-hour coding session?
In this talk, we will build a toolbox to find the culprits in a structured manner. We will focus on simple ways to ensure a training loop is correct, generate synthetic training data to determine whether we have a model bug or problematic real-world data, and leverage pytest to safely refactor PyTorch models.
After this talk, attendees will be well equipped to take the right steps when a model is not learning, to quickly identify the underlying reasons, and to prevent such bugs in the future.
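To give a flavour of the first tool in that box: one simple training-loop check is to verify that the model can overfit a single, fixed batch. The sketch below uses a hypothetical toy classifier and illustrative hyperparameters rather than the exact code shown in the talk; if the loss does not approach zero here, the problem is almost certainly in the loop or the model rather than the data.

```python
import torch
import torch.nn as nn

# Hypothetical toy classifier; in practice, plug in your own model.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# One fixed batch of synthetic data. A correct training loop should be
# able to drive the loss on this single batch close to zero.
inputs = torch.randn(8, 16)
targets = torch.randint(0, 4, (8,))

for step in range(500):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

print(f"final loss: {loss.item():.4f}")
# Threshold is illustrative; a failure hints at a loop or model bug.
assert loss.item() < 0.05, "model cannot overfit a single batch"
```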
PyTorch models for off-the-shelf applications are easy to build and debug. But in real-world ML applications, debugging can become quite tricky, especially when model complexity is high and only noisy real-world data is available.
When our DNN is not learning, many factors can be at fault:
- Is there a bug in the model structure, for example mixed-up channels or timesteps?
- Is our dataset too small or too heterogeneous to learn from? Have we mixed up labels during preprocessing?
- Have we chosen an unsuitable loss, accidentally skipped layers, or picked inappropriate activation functions?
The plethora of potential reasons can be overwhelming to engineers. This talk will introduce a structured approach and valuable tools for efficiently debugging PyTorch models.
We'll start with techniques for verifying that the training loop is correct, such as checking that the model can overfit a single training example. In the second step, we'll investigate how to generate simple synthetic data for arbitrary input and output formats to validate our model. Finally, we'll look at how to avoid model bugs altogether by setting up universal tests that can be run during development and refactoring to prevent breaking PyTorch models.
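As a sketch of what such a universal test could look like (the model, dimensions, and test names below are hypothetical, not the exact code from the talk): feed synthetic random data through the model with pytest and assert the output shape, and check that every parameter receives a gradient, which catches skipped layers and accidentally detached tensors.

```python
import pytest
import torch
import torch.nn as nn


# Hypothetical model under test; in a real project this would be
# imported from your own package instead of defined in the test file.
class TinyClassifier(nn.Module):
    def __init__(self, n_features: int = 16, n_classes: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 32), nn.ReLU(), nn.Linear(32, n_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


@pytest.mark.parametrize("batch_size", [1, 8])
def test_output_shape_and_gradients(batch_size):
    model = TinyClassifier()

    # Synthetic input in the expected format; no real data needed.
    x = torch.randn(batch_size, 16)
    out = model(x)

    # Shape contract: one row of logits per sample.
    assert out.shape == (batch_size, 4)

    # Every parameter should receive a gradient; a skipped layer or a
    # detached tensor in the forward pass would show up here.
    out.sum().backward()
    for name, param in model.named_parameters():
        assert param.grad is not None, f"no gradient for {name}"
```

Run as part of the usual pytest suite, checks like these make refactoring a model architecture much safer.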