JAX and Flax: Function Transformations and Neural Networks

EuroSciPy 2022

Modern accelerators such as graphics processing units (GPUs) and tensor processing units (TPUs) enable high-performance computing at massive scale. JAX traces computations in Python programs written against the familiar NumPy API and uses XLA to compile them into programs that run efficiently on these accelerators. A set of composable function transformations makes it possible to express versatile scientific computations with an elegant syntax. Flax provides abstractions on top of JAX that make it easy to handle the weights and other state required to solve problems with neural networks. This talk first presents the basic JAX API for computing gradients, compiling functions, and vectorizing computations. It then covers other parts of the JAX ecosystem commonly used for neural network programming, such as basic building blocks and optimizers.
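
As a rough illustration of the three transformations mentioned in the abstract (computing gradients, compiling functions, vectorizing computations), here is a minimal sketch using only the core `jax` package: `jax.grad`, `jax.jit` and `jax.vmap`. The linear-regression loss, the learning rate and the synthetic data are hypothetical choices made for this example and are not taken from the talk; Flax modules and optimizer libraries are not shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model y ≈ x @ w + b.
def loss(params, x, y):
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

# grad: returns a new function computing d(loss)/d(params),
# with the same tuple structure as params.
grad_loss = jax.grad(loss)

# vmap composed with grad gives per-example gradients. in_axes=(None, 0, 0)
# shares params across the batch and maps over the leading axis of x and y.
per_example_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

LEARNING_RATE = 0.1  # assumed value for this toy problem

# jit: trace the step once, compile it with XLA, reuse it on later calls.
@jax.jit
def sgd_step(params, x, y):
    grads = grad_loss(params, x, y)
    # Plain gradient descent applied leaf by leaf over the params pytree.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3  # synthetic, noise-free targets
params = (jnp.zeros(3), jnp.array(0.0))

initial_loss = loss(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y)
final_loss = loss(params, x, y)

pw, pb = per_example_grads(params, x, y)
print(initial_loss, final_loss)
print(pw.shape, pb.shape)  # (8, 3) (8,)
```

Because each transformation takes a Python function and returns another one, they compose freely: here `vmap` wraps `grad`, and `jit` wraps both, without changing the original `loss` function.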

Speakers: Andreas Steiner