DeepMind's JAX ecosystem provides deep learning practitioners with an appealing alternative to TensorFlow and PyTorch. Among its strengths are great functionalities such as native TPU support, as well as easy vectorization and parallelization. Nevertheless, taking your first steps in JAX can feel complicated given some of its idiosyncrasies. This talk helps new users get started in this promising ecosystem by sharing practical tips and best practices.
DeepMind's JAX ecosystem provides deep learning practitioners with an appealing alternative to TensorFlow and PyTorch. Among its strengths are great functionalities such as native TPU support, as well as easy vectorization and parallelization, which make JAX and its ecosystem an attractive option for your deep learning projects. Nevertheless, taking your first steps can feel complicated. From pure functions and the resulting differences in coding style, to avoiding recompilation, JAX comes with its own set of restrictions and design decisions to be made by the user.

This talk aims to help new and prospective users in their JAX learning journey by providing guidance on practical problems they are likely to encounter when transitioning into the JAX ecosystem. Having recently switched to JAX and Flax for my daily work, I share some of the insights I gained and hope to help others avoid some of the mistakes I made early on. The talk takes a systematic look at selected situations in which JAX presents users with choices, examining how the options differ and which is the best fit under different circumstances.

The talk covers:
- Why bother switching to JAX?
- A brief introduction to JAX, including a list of JAX's idiosyncrasies
- Pure functions and the resulting architectural decisions
- To JIT or not to JIT
- A speed and memory comparison of the different iteration options
- Memory management and profiling
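As a taste of the themes above, here is a minimal sketch (my own illustrative example, not material from the talk) of two of the features mentioned: `jax.jit` compilation of a pure function, and `jax.vmap` for vectorization without hand-written batch loops. The function and parameter names are invented for illustration.

```python
import jax
import jax.numpy as jnp

# A pure function: its output depends only on its inputs, with no
# side effects or hidden state. This is the style JAX expects.
def predict(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# jax.jit compiles the function with XLA. Note that a jitted function
# is recompiled whenever input shapes or dtypes change, which is why
# avoiding recompilation is a recurring JAX concern.
predict_jit = jax.jit(predict)

# jax.vmap maps the function over a batch axis automatically:
# params are shared (None), x is batched along axis 0.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

params = (jnp.ones((3, 2)), jnp.zeros(2))
x_single = jnp.ones(3)        # one example with 3 features
x_batch = jnp.ones((5, 3))    # a batch of 5 such examples

print(predict_jit(params, x_single).shape)    # single output: (2,)
print(batched_predict(params, x_batch).shape) # batched output: (5, 2)
```

The same `predict` function serves both the single-example and the batched case; composing `jit` and `vmap` (e.g. `jax.jit(batched_predict)`) is also common.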
Speakers: Simon Pressler