## Introduction to JAX AutoDiff

JAX is the fusion of XLA (short for Accelerated Linear Algebra), a compiler for numerical programs, and *Autograd*, a library for the automatic computation of derivatives. The most popular autograd framework today is PyTorch's, but JAX's is gaining popularity every day.
This post is an introduction to JAX through simple examples.

## A rational function

Assume you have two polynomials $P, Q$ of degree at most $d = 5$, and denote their coefficient vectors by $\vec{\alpha}, \vec{\beta}$. In order to avoid singularities at the roots of $Q$, we will instead study the function

$$ f(x) = \dfrac{P(x)}{1 + Q(x)^2} = \dfrac{\alpha_0 + \alpha_1x + \alpha_2x^2 + \alpha_3 x^3 + \alpha_4x^4 + \alpha_5x^5}{1 + (\beta_0 + \beta_1x + \beta_2x^2 + \beta_3 x^3 + \beta_4x^4 + \beta_5x^5)^2} $$

The function is computed in JAX code using the following snippet:

```
import jax.numpy as jnp
from jax import jit

@jit
def rational_function(x, num_par, denom_par):
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    return acc_num / (1.0 + acc_denom**2)
```

We can plot this function on the interval $I_{10} = (-10, +10)$, giving us the figure below.
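As a concrete sketch of that evaluation step (the coefficient values below are arbitrary choices for illustration, and the actual `matplotlib` call is omitted), we can map the function over a uniform grid with `jax.vmap`:

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def rational_function(x, num_par, denom_par):
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    return acc_num / (1.0 + acc_denom**2)

# Arbitrary example coefficients for P and Q
alpha = jnp.array([1.0, -0.5, 0.3, 0.1, -0.2, 0.05])
beta = jnp.array([0.5, 0.2, -0.1, 0.3, 0.0, -0.05])

# Uniform grid on I_10 = (-10, 10); vmap maps over x only,
# broadcasting the two coefficient vectors
xs = jnp.linspace(-10.0, 10.0, 500)
ys = jax.vmap(rational_function, in_axes=(0, None, None))(xs, alpha, beta)
```

The `ys` array can then be handed directly to any plotting library.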

## Gradient and Hessian basics

### Gradient

If one does not specify anything to JAX's `grad` routine, it takes a function `func(arg1, arg2, arg3, ...)` and returns a function with the same arguments, except that the new function now returns the gradient w.r.t. `arg1`. Let's try this functionality by calling `grad(rational_function)` and plotting the result on our $I_{10}$ grid:

We obtain $f'$ as expected. Notice that the points $x$ at which $f'(x)$ vanishes are local extrema of $f$: we can count 5 of them in total if we ignore the regions where $f(x) = 0$.
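As a quick sanity check, here is a minimal sketch (reusing `rational_function` and the same arbitrary example coefficients) that compares the output of `grad` against a central finite difference:

```python
import jax.numpy as jnp
from jax import grad, jit

# Same rational function as above
def rational_function(x, num_par, denom_par):
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    return acc_num / (1.0 + acc_denom**2)

# Arbitrary example coefficients
alpha = jnp.array([1.0, -0.5, 0.3, 0.1, -0.2, 0.05])
beta = jnp.array([0.5, 0.2, -0.1, 0.3, 0.0, -0.05])

# grad with no extra options differentiates w.r.t. the first argument, x
f_prime = jit(grad(rational_function))

# Central finite difference at x0 = 1.5 for comparison
x0, h = 1.5, 1e-3
fd = (rational_function(x0 + h, alpha, beta)
      - rational_function(x0 - h, alpha, beta)) / (2 * h)
```

The two values agree up to finite-difference error, which is the behavior we expect from exact automatic differentiation.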

### Hessian

As the `grad(...)` call returns a function with the same arguments as $f$, we can call `grad` again on the output of the first gradient call:

```
y_g = grad(rational_function)
y_gg = grad(y_g)
```

Do not forget to call `jit` to pre-compile the gradient functions and avoid performing unnecessary computations at each evaluation (`y_g = jit(y_g)`, and the same for `y_gg`).
We can then plot the functions $f, f', f''$ on $I_{10}$ after normalizing the $y$-axis (because $f''$ exhibits high variability).
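Putting the pieces together, one possible sketch of that computation (with the same hypothetical coefficients as before; the plotting call itself is left out) evaluates $f$, $f'$, and $f''$ on the grid and rescales each curve:

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

# Same rational function as above
def rational_function(x, num_par, denom_par):
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    return acc_num / (1.0 + acc_denom**2)

# Arbitrary example coefficients
alpha = jnp.array([1.0, -0.5, 0.3, 0.1, -0.2, 0.05])
beta = jnp.array([0.5, 0.2, -0.1, 0.3, 0.0, -0.05])

y_g = jit(grad(rational_function))           # f'
y_gg = jit(grad(grad(rational_function)))    # f''

xs = jnp.linspace(-10.0, 10.0, 200)
f_vals = vmap(rational_function, in_axes=(0, None, None))(xs, alpha, beta)
fp_vals = vmap(y_g, in_axes=(0, None, None))(xs, alpha, beta)
fpp_vals = vmap(y_gg, in_axes=(0, None, None))(xs, alpha, beta)

# Rescale each curve into [-1, 1], since f'' varies much more than f
def normalize(v):
    return v / jnp.max(jnp.abs(v))

f_n, fp_n, fpp_n = normalize(f_vals), normalize(fp_vals), normalize(fpp_vals)
```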

## Gradient and arguments

As mentioned earlier, calling `grad(...)` without specifying the argument w.r.t. which the gradient is computed defaults to the first argument. To change that behavior, one should use the `argnums` parameter. For example, the snippet

```
num_par_grad = jit(grad(rational_function, argnums=1))
```

computes the gradient $\partial f /\partial \vec{\alpha}$. As this gradient at each $x$ is an array of shape $1\times 6$, it is impractical to visualize on the full grid $I_{10}$. To simplify the problem, we set the degree $\deg(P) = 2$ and visualize $\nabla_{\vec{\alpha}}f$ on our grid. The result can be seen below:
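To make the degree-2 case concrete: since $f$ is linear in $\vec{\alpha}$, each component of this gradient should equal $x^i / (1 + Q(x)^2)$, which gives us an analytic check. A sketch (the coefficient values are arbitrary assumptions):

```python
import jax.numpy as jnp
from jax import grad, jit

# Same rational function as above
def rational_function(x, num_par, denom_par):
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    return acc_num / (1.0 + acc_denom**2)

# Degree-2 numerator: only three alpha coefficients
alpha = jnp.array([1.0, -0.5, 0.3])
beta = jnp.array([0.5, 0.2, -0.1, 0.3, 0.0, -0.05])

# argnums=1 differentiates w.r.t. the second argument, num_par
num_par_grad = jit(grad(rational_function, argnums=1))

x0 = 2.0
g = num_par_grad(x0, alpha, beta)  # shape (3,): one entry per alpha_i
```

Evaluating `g` and comparing it against $[1, x_0, x_0^2] / (1 + Q(x_0)^2)$ confirms the linearity argument.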

The same simplifications and computations can be used to compute $\partial f /\partial \vec{\beta}$ (by setting `argnums=2`), which yields the plot

## Contours

The above curves were the result of plotting the values of $\partial f /\partial \vec{\alpha}$ as $x$ varies over the uniform grid $I_{10}$, with both $\vec{\alpha}$ and $\vec{\beta}$ fixed. We can proceed differently: fix $x = x_0$ and vary the coefficients $\vec{\alpha}, \vec{\beta}$ on a uniform grid. (**TO DO**).
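One way that experiment could be sketched, under the illustrative assumptions that we vary only $\alpha_0$ and $\beta_0$ while freezing all other coefficients (all numbers below are arbitrary), is a nested `vmap` over a coefficient grid:

```python
import jax
import jax.numpy as jnp

# Same rational function as above
def rational_function(x, num_par, denom_par):
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    return acc_num / (1.0 + acc_denom**2)

x0 = 1.0  # fixed evaluation point

# Vary alpha_0 and beta_0 on uniform grids; other coefficients stay fixed
a0s = jnp.linspace(-2.0, 2.0, 50)
b0s = jnp.linspace(-2.0, 2.0, 50)
base_alpha = jnp.array([0.0, -0.5, 0.3, 0.1, -0.2, 0.05])
base_beta = jnp.array([0.0, 0.2, -0.1, 0.3, 0.0, -0.05])

def f_of_coeffs(a0, b0):
    alpha = base_alpha.at[0].set(a0)
    beta = base_beta.at[0].set(b0)
    return rational_function(x0, alpha, beta)

# Nested vmap builds the 50x50 matrix of values for a contour plot
Z = jax.vmap(lambda a0: jax.vmap(lambda b0: f_of_coeffs(a0, b0))(b0s))(a0s)
```

The matrix `Z` is then directly usable with a contour-plotting routine such as `matplotlib`'s `contourf`.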

### Tags

#autodiff #autograd #jax #python #code #math #applied #data #plots #gradient #hessian