Introduction to JAX AutoDiff

JAX brings together XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra, and Autograd, a library for automatically computing derivatives of Python and NumPy code. PyTorch's autograd is currently a very popular automatic differentiation framework, but JAX is steadily gaining popularity. This post is an introduction to JAX through simple examples.

A rational function

Assume you have two polynomials $P, Q$ of degree at most $d = 5$. Denote their coefficients by $\vec{\alpha}, \vec{\beta}$. Rather than $P/Q$, which is singular at the roots of $Q$, we study the function

$$ f(x) = \dfrac{P(x)}{1 + Q(x)^2} = \dfrac{\alpha_0 + \alpha_1x + \alpha_2x^2 + \alpha_3 x^3 + \alpha_4x^4 + \alpha_5x^5}{1 + (\beta_0 + \beta_1x + \beta_2x^2 + \beta_3 x^3 + \beta_4x^4 + \beta_5x^5)^2} $$

The function is computed in JAX code using the following snippet:

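```python
from jax import jit


@jit
def rational_function(x, num_par, denom_par):
    # Accumulate P(x) and Q(x) term by term (coefficients stored lowest degree first).
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        # Integer power x ** i instead of jnp.float_power(x, i): its derivative stays
        # finite at x = 0, where float_power's x ** 0 term can produce a NaN gradient.
        acc_num += par * x ** i

    for j, par in enumerate(denom_par):
        acc_denom += par * x ** j

    # 1 + Q(x)^2 >= 1, so the division never blows up.
    return acc_num / (1.0 + acc_denom**2)
```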
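The loop computes each power of $x$ separately. Below is a hedged alternative of our own (not the post's code) that evaluates $P$ and $Q$ with `jnp.polyval`, which applies Horner's scheme. `jnp.polyval` expects coefficients from the highest degree down, so the stored arrays (lowest degree first) are reversed. The coefficient values are hypothetical because the post does not list the ones behind its figures.

```python
import jax.numpy as jnp
from jax import jit


@jit
def rational_function(x, num_par, denom_par):
    # Compact restatement of the post's loop (integer powers x ** i).
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


@jit
def rational_function_horner(x, num_par, denom_par):
    # Hypothetical variant: jnp.polyval wants the highest-degree coefficient first,
    # while num_par = (alpha_0, ..., alpha_5) stores the lowest first, hence [::-1].
    p = jnp.polyval(num_par[::-1], x)
    q = jnp.polyval(denom_par[::-1], x)
    return p / (1.0 + q ** 2)


# Hypothetical coefficients (degree 5, six entries each).
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.05])
beta = jnp.array([0.5, 1.0, -0.3, 0.1, 0.02, -0.01])

# Both versions broadcast over an array of x values.
xs = jnp.linspace(-10.0, 10.0, 1001)
diff = rational_function(xs, alpha, beta) - rational_function_horner(xs, alpha, beta)
print(jnp.max(jnp.abs(diff)))
```

Horner's scheme needs one multiply-add per coefficient instead of recomputing every power. For higher degrees this is usually the cheaper option.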

We can plot this function on the interval $I_{10} = (-10, +10)$, which gives the figure

Figure: the rational function $f$ on $I_{10}$.
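Producing such a plot means evaluating $f$ on a uniform grid. The sketch below uses `jax.vmap` to map `rational_function` over the grid points while the coefficient vectors stay fixed (`in_axes=(0, None, None)`). The coefficients and the 1001-point grid are hypothetical choices. The post's function also broadcasts over an array `x` directly, but the `vmap` pattern is the one that still works once `grad` is involved.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


# Hypothetical coefficients; the post does not give the ones it plotted.
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.05])
beta = jnp.array([0.5, 1.0, -0.3, 0.1, 0.02, -0.01])

# Uniform grid over I_10 (endpoints included for simplicity).
xs = jnp.linspace(-10.0, 10.0, 1001)

# Map over x (axis 0) only; every grid point shares the same coefficient vectors.
f_on_grid = jit(vmap(rational_function, in_axes=(0, None, None)))
ys = f_on_grid(xs, alpha, beta)

# xs and ys can then go to a plotting library, e.g. matplotlib's plt.plot(xs, ys).
```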

Gradient and Hessian basics

Gradient

By default, JAX's grad takes a function func(arg1, arg2, arg3, ...) and returns a function with the same arguments. The value of that new function is the gradient with respect to the first argument, arg1. Note that grad only accepts functions with a scalar output, so rational_function is differentiated at a scalar $x$. Let's try this by calling grad(rational_function) and plotting the result on our $I_{10}$ grid:

Figure: $f$ and its derivative $f'$ with respect to $x$ on $I_{10}$.
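Because `grad` needs a scalar output, one way to get $f'$ on the whole grid is to differentiate the scalar function first and then `vmap` the result. The sketch below does this with hypothetical coefficients. As an extra check of our own, it compares the output with the quotient rule
$$ f'(x) = \dfrac{P'(x)\,(1 + Q(x)^2) - 2\,P(x)\,Q(x)\,Q'(x)}{(1 + Q(x)^2)^2}, $$
with $P, Q, P', Q'$ evaluated through `jnp.polyval` and `jnp.polyder`.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


def analytic_derivative(x, num_par, denom_par):
    # Quotient rule; polyval/polyder expect the highest-degree coefficient first.
    p_hi, q_hi = num_par[::-1], denom_par[::-1]
    p, dp = jnp.polyval(p_hi, x), jnp.polyval(jnp.polyder(p_hi), x)
    q, dq = jnp.polyval(q_hi, x), jnp.polyval(jnp.polyder(q_hi), x)
    return (dp * (1.0 + q ** 2) - 2.0 * p * q * dq) / (1.0 + q ** 2) ** 2


# Hypothetical coefficients.
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.05])
beta = jnp.array([0.5, 1.0, -0.3, 0.1, 0.02, -0.01])
xs = jnp.linspace(-10.0, 10.0, 1001)

# grad differentiates w.r.t. the first argument x; vmap maps it over the grid.
df_on_grid = jit(vmap(grad(rational_function), in_axes=(0, None, None)))
dys = df_on_grid(xs, alpha, beta)

print(jnp.max(jnp.abs(dys - analytic_derivative(xs, alpha, beta))))
```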

We obtain $f'$ as expected. Notice that the points $x$ where $f'(x)$ vanishes are the critical points of $f$, which here correspond to its local extrema. We count 5 of them in total if we ignore the flat regions where $f(x) \approx 0$.
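These zeros can also be counted numerically. The sketch below uses hypothetical coefficients chosen so the answer is known in advance: $P(x) = x$ and $Q(x) = x$, hence $f(x) = x/(1+x^2)$, whose critical points are $x = \pm 1$. It detects sign changes of $f'$ between neighboring grid points and then refines each bracket by linear interpolation. With real coefficients, sign flips in the flat regions where $f \approx 0$ may come from rounding noise, which is another reason to set those regions aside.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


# Hypothetical coefficients with a known answer: P(x) = x and Q(x) = x, so
# f'(x) = (1 - x^2) / (1 + x^2)^2 vanishes exactly at x = -1 and x = 1.
alpha = jnp.array([0.0, 1.0])
beta = jnp.array([0.0, 1.0])

# 1000 points, so that x = -1 and x = 1 do not fall exactly on a grid node.
xs = jnp.linspace(-10.0, 10.0, 1000)
dys = jit(vmap(grad(rational_function), in_axes=(0, None, None)))(xs, alpha, beta)

# f' has a zero between neighboring grid points where its sign flips.
idx = jnp.nonzero(dys[:-1] * dys[1:] < 0)[0]

# Refine each bracket [xs[k], xs[k + 1]] by linear interpolation of f'.
x0, x1 = xs[idx], xs[idx + 1]
d0, d1 = dys[idx], dys[idx + 1]
critical_points = x0 - d0 * (x1 - x0) / (d1 - d0)
print(critical_points)
```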

Hessian

Since grad(...) returns a function with the same arguments as $f$, we can call grad again on its output:

```python
y_g = grad(rational_function)
y_gg = grad(y_g)
```

Remember to wrap both gradient functions in jit (y_g = jit(y_g), and likewise for y_gg). They are then compiled once and reused, instead of the differentiation work being redone at every evaluation. We can then plot $f, f', f''$ on $I_{10}$ after normalizing the $y$-axis, because $f''$ varies over a much wider range than $f$ and $f'$.

Figure: normalized $f$, $f'$ and $f''$ with respect to $x$ on $I_{10}$.
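The steps above can be sketched end to end. To have a closed form to compare against, the coefficients are the hypothetical $P(x) = x$, $Q(x) = x$, for which $f''(x) = (2x^3 - 6x)/(1+x^2)^3$. The sketch also checks that, for a scalar input, `jax.hessian` gives the same values as `grad(grad(...))`. The normalization used here (divide each curve by its largest absolute value) is an assumption, since the post does not say which one it used.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


# Hypothetical coefficients: P(x) = x and Q(x) = x, so f(x) = x / (1 + x^2).
alpha = jnp.array([0.0, 1.0])
beta = jnp.array([0.0, 1.0])

# As in the post: differentiate twice, then compile both derivative functions.
y_g = grad(rational_function)
y_gg = grad(y_g)
y_g, y_gg = jit(y_g), jit(y_gg)


def on_grid(fn):
    # Map a scalar function of x over the grid, sharing the coefficient vectors.
    return jit(vmap(fn, in_axes=(0, None, None)))


xs = jnp.linspace(-10.0, 10.0, 1001)
f0 = on_grid(rational_function)(xs, alpha, beta)
f1 = on_grid(y_g)(xs, alpha, beta)
f2 = on_grid(y_gg)(xs, alpha, beta)

# For a scalar x, jax.hessian returns a 0-d array equal to f''(x).
f2_hessian = on_grid(jax.hessian(rational_function))(xs, alpha, beta)

# Assumed normalization: scale each curve into [-1, 1] by its largest absolute value.
f0_n, f1_n, f2_n = (c / jnp.max(jnp.abs(c)) for c in (f0, f1, f2))
```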

Gradient and arguments

As mentioned earlier, calling grad(...) without specifying an argument computes the gradient with respect to the first argument. To change this behavior, use the argnums parameter. For example, the snippet

```python
num_par_grad = jit(grad(rational_function, argnums=1))
```

computes the gradient $\partial f /\partial \vec{\alpha}$. At each $x$ this gradient is a vector with 6 entries (one per coefficient $\alpha_i$), so it cannot be drawn as a single curve over $I_{10}$. To simplify the problem, we set $\deg(P) = 2$ and plot the three components of $\nabla_{\vec{\alpha}}f$ on our grid. The result can be seen below:

Figure: components of $\nabla_{\vec{\alpha}} f$ (numerator coefficients) on $I_{10}$.
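Since $f$ is linear in $\vec{\alpha}$, this gradient has the closed form $\partial f/\partial \alpha_i = x^i/(1 + Q(x)^2)$, which gives a simple check. The sketch below uses hypothetical coefficients and maps `num_par_grad` over the grid with `vmap`. The result has one row per $x$ and one column per $\alpha_i$, and is compared with the formula. It also shows that keeping only three numerator coefficients, as in the $\deg(P) = 2$ case, leaves three curves.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


# Hypothetical coefficients.
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.05])
beta = jnp.array([0.5, 1.0, -0.3, 0.1, 0.02, -0.01])
xs = jnp.linspace(-10.0, 10.0, 1001)

# argnums=1 selects num_par (alpha); the gradient has the same shape as alpha.
num_par_grad = jit(grad(rational_function, argnums=1))
grads_alpha = vmap(num_par_grad, in_axes=(0, None, None))(xs, alpha, beta)  # (1001, 6)

# Closed form: d f / d alpha_i = x**i / (1 + Q(x)**2), built column by column.
q = jnp.polyval(beta[::-1], xs)
powers = jnp.stack([xs ** i for i in range(6)], axis=1)
expected = powers / (1.0 + q ** 2)[:, None]

# With deg(P) = 2, only alpha_0, alpha_1, alpha_2 remain: three curves to plot.
grads_alpha_deg2 = vmap(num_par_grad, in_axes=(0, None, None))(xs, alpha[:3], beta)
```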

The same simplifications and computations give $\partial f /\partial \vec{\beta}$ (by setting argnums=2), which yields the plot

Figure: components of $\nabla_{\vec{\beta}} f$ (denominator coefficients) on $I_{10}$.
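Differentiating $P/(1+Q^2)$ with respect to $\beta_j$ gives $\partial f/\partial \beta_j = -2\,P(x)\,Q(x)\,x^j/(1 + Q(x)^2)^2$. The sketch below checks this against `argnums=2` on the grid, using hypothetical coefficients. It also shows that `argnums` accepts a tuple, in which case `grad` returns one gradient per selected argument.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


# Hypothetical coefficients.
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.05])
beta = jnp.array([0.5, 1.0, -0.3, 0.1, 0.02, -0.01])
xs = jnp.linspace(-10.0, 10.0, 1001)

# argnums=2 selects denom_par (beta).
denom_par_grad = jit(grad(rational_function, argnums=2))
grads_beta = vmap(denom_par_grad, in_axes=(0, None, None))(xs, alpha, beta)  # (1001, 6)

# Closed form: d f / d beta_j = -2 P(x) Q(x) x**j / (1 + Q(x)**2)**2.
p = jnp.polyval(alpha[::-1], xs)
q = jnp.polyval(beta[::-1], xs)
powers = jnp.stack([xs ** j for j in range(6)], axis=1)
expected = (-2.0 * p * q / (1.0 + q ** 2) ** 2)[:, None] * powers

# A tuple of argnums returns the gradients w.r.t. alpha and beta as a pair.
both_grads = jit(grad(rational_function, argnums=(1, 2)))
g_alpha, g_beta = both_grads(2.0, alpha, beta)
```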

Contours

The curves above were obtained by plotting the gradients with respect to $\vec{\alpha}$ and $\vec{\beta}$ as $x$ varies over the uniform grid $I_{10}$, with both $\vec{\alpha}$ and $\vec{\beta}$ fixed. We can proceed differently: fix $x = x_0$ and vary the coefficients $\vec{\alpha}, \vec{\beta}$ on a uniform grid. (TO DO).
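Here is a minimal sketch of that idea, under assumptions of our own. A contour plot can only show two varying coefficients, so we pick $\alpha_0$ and $\beta_0$ and keep the other (hypothetical) coefficients fixed. The values $x_0 = 1.5$ and the range $[-3, 3]$ are arbitrary, and `f_of_a0_b0` is a hypothetical helper. Two nested `vmap` calls evaluate $f$ on the whole coefficient grid at once.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def rational_function(x, num_par, denom_par):
    # f(x) = P(x) / (1 + Q(x)^2), coefficients stored lowest degree first.
    acc_num = sum(par * x ** i for i, par in enumerate(num_par))
    acc_denom = sum(par * x ** j for j, par in enumerate(denom_par))
    return acc_num / (1.0 + acc_denom ** 2)


# Hypothetical fixed coefficients and evaluation point.
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.05])
beta = jnp.array([0.5, 1.0, -0.3, 0.1, 0.02, -0.01])
x0 = 1.5


def f_of_a0_b0(a0, b0):
    # Hypothetical helper: f(x0) as a function of alpha_0 and beta_0 only.
    return rational_function(x0, alpha.at[0].set(a0), beta.at[0].set(b0))


a0s = jnp.linspace(-3.0, 3.0, 101)
b0s = jnp.linspace(-3.0, 3.0, 101)

# Inner vmap runs over alpha_0 (columns), outer vmap over beta_0 (rows),
# so values[i, j] = f_of_a0_b0(a0s[j], b0s[i]) and values.shape == (101, 101).
values = jit(vmap(vmap(f_of_a0_b0, in_axes=(0, None)), in_axes=(None, 0)))(a0s, b0s)
```

This row/column layout is the one matplotlib's `plt.contour(a0s, b0s, values)` expects, with the first axis of `values` following `b0s`.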

Tags

#autodiff #autograd #jax #python #code #math #applied #data #plots #gradient #hessian