JAX for Machine Learning: how it works and why learn it

JAX is the new kid in Machine Learning (ML) town, and it promises to make ML programming more intuitive, structured, and clean. It could plausibly replace the likes of TensorFlow and PyTorch, even though it is very different at its core.

As a friend of mine said, we had all sorts of Aces, Kings, and Queens. Now we have JAX.

In this article, we will explore what is JAX and why one should use it over all the other libraries. We will make our points using code snippets that capture the power of JAX and we will present some good-to-know features of it.

If that sounds interesting, hop in.

What is JAX?

JAX is a Python library designed for high-performance ML research. At its heart, JAX is a numerical computing library, just like NumPy, but with some key improvements. It was developed by Google and is used internally by both Google and DeepMind teams.

Source: JAX documentation

Install JAX

Before we discuss the main advantages of JAX, I suggest you install JAX in your Python environment or in a Google Colab so that you can follow along and run the code yourself. Of course, I will leave a link to the full code at the end of the article.

To install JAX, we can simply use pip from our command line:

$ pip install --upgrade jax jaxlib

Note that this supports execution on CPU only. If you also want GPU support, you first need CUDA and cuDNN, and then run the following command (make sure to match the jaxlib version to your CUDA version):

$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

For troubleshooting, check the official GitHub instructions.

Now let’s import JAX alongside NumPy. We will use NumPy to compare different use cases.

import jax
import jax.numpy as jnp
import numpy as np

JAX basics

Let’s start with the basics. As already mentioned, JAX’s main and only purpose is to perform numeric operations in an expressive and high-performance way. This means that the syntax is almost identical to NumPy’s. For example, if we want to create an array of zeros, we’d have:

x = np.zeros(10)
y = jnp.zeros(10)
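To see how closely the two APIs track each other, here is a quick sketch comparing a few common operations; the array names are made up for illustration, but the calls are standard `numpy`/`jax.numpy` API:

```python
import numpy as np
import jax.numpy as jnp

a_np = np.arange(6.0).reshape(2, 3)
a_jnp = jnp.arange(6.0).reshape(2, 3)

# Same call signatures, same results.
print(np.sum(a_np))    # 15.0
print(jnp.sum(a_jnp))  # 15.0

# Matrix products match too (JAX defaults to float32, hence allclose).
print(np.allclose(a_np.T @ a_np, a_jnp.T @ a_jnp))  # True
```

In most cases, swapping `np` for `jnp` is all it takes to port a NumPy program to JAX.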

The difference lies behind the scenes.

The DeviceArray

You see, one of JAX’s main advantages is that we can run the same program, without any change, on hardware accelerators like GPUs and TPUs.

This is accomplished by an underlying structure called DeviceArray, which essentially replaces NumPy’s standard array.
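A minimal sketch of what this looks like in practice: a `jnp` array behaves like a NumPy array, lives on whatever device JAX selected, and can be pulled back to the host explicitly.

```python
import jax
import jax.numpy as jnp
import numpy as np

y = jnp.zeros(10)

# List the devices JAX can place arrays on (CPU here; GPU/TPU if available).
print(jax.devices())

# A DeviceArray prints and indexes like a NumPy array...
print(y)

# ...and can be pulled back to the host as a plain NumPy array.
host = np.asarray(y)
print(type(host).__name__)  # ndarray
```

Note that the concrete array class has been renamed across JAX versions, so it is safer to rely on the NumPy-like interface than on the type name itself.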

DeviceArrays are lazy, which means that they keep the values in the accelerator and pull them only when...