Introduction to JAX AutoDiff
JAX is the fusion of XLA (short for Accelerated Linear Algebra) and Autograd, a library for automatically computing derivatives of numerical functions. PyTorch's autograd engine is very popular today, but JAX's is gaining popularity every day. This post is an introduction to JAX through simple examples.
A rational function
Assume you have two polynomials $P, Q$ of degree at most $d = 5$. Denote by $\vec{\alpha}, \vec{\beta}$ the coefficient vectors of $P$ and $Q$. In order to avoid singularities at the roots of $Q$, we will instead study the function
$$ f(x) = \dfrac{P(x)}{1 + Q(x)^2} = \dfrac{\alpha_0 + \alpha_1x + \alpha_2x^2 + \alpha_3 x^3 + \alpha_4x^4 + \alpha_5x^5}{1 + (\beta_0 + \beta_1x + \beta_2x^2 + \beta_3 x^3 + \beta_4x^4 + \beta_5x^5)^2} $$
The function can be computed in JAX with the following snippet:
import jax.numpy as jnp
from jax import jit

@jit
def rational_function(x, num_par, denom_par):
    # Accumulate P(x) and Q(x) term by term from their coefficient vectors
    acc_num = 0.0
    acc_denom = 0.0
    for i, par in enumerate(num_par):
        acc_num += par * jnp.float_power(x, i)
    for j, par in enumerate(denom_par):
        acc_denom += par * jnp.float_power(x, j)
    # f(x) = P(x) / (1 + Q(x)^2)
    return acc_num / (1.0 + acc_denom**2)
We can plot this function on the interval $I_{10} = (-10, +10)$, which gives us the figure below.
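As a minimal sketch of how such a plot can be produced (the coefficient values, the grid size and the variable names below are arbitrary choices made for this post, not values taken from the original figure):

import matplotlib.pyplot as plt
from jax import vmap

# Arbitrary example coefficients for P and Q (degree 5, i.e. 6 entries each)
alpha = jnp.array([1.0, -2.0, 0.5, 0.3, -0.1, 0.02])
beta = jnp.array([0.5, 1.0, -0.4, 0.1, 0.05, -0.01])

xs = jnp.linspace(-10.0, 10.0, 1000)                        # uniform grid on I_10
ys = vmap(lambda x: rational_function(x, alpha, beta))(xs)  # evaluate f on the grid

plt.plot(xs, ys)
plt.xlabel("x")
plt.ylabel("f(x)")
plt.show()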
Gradient and Hessian basics
Gradient
If one does not specify anything to JAX's grad routine, it takes a function func(arg1, arg2, arg3, ...) and returns a function with the same arguments, except that the new function now returns the gradient w.r.t. arg1. Let's try this functionality by calling grad(rational_function) and plotting the result on our $I_{10}$ grid:
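A minimal sketch of that experiment, reusing the grid xs and the example coefficients alpha, beta introduced above (which are assumptions of this post, not the original values):

from jax import grad

df = jit(grad(rational_function))               # derivative w.r.t. x, the first argument
dys = vmap(lambda x: df(x, alpha, beta))(xs)    # f'(x) on the I_10 grid

plt.plot(xs, dys)
plt.xlabel("x")
plt.ylabel("f'(x)")
plt.show()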
We obtain $f'$ as expected. Notice that the points $x$ at which $f'(x)$ vanishes are local extrema of $f$; we can count 5 of them in total if we ignore the regions where $f(x) = 0$.
Hessian
As the grad(...) call returns a function with $f$'s arguments, we can again call grad on the output of the first gradient call:
y_g = grad(rational_function)   # first derivative f'
y_gg = grad(y_g)                # second derivative f''
Do not forget to call jit to pre-compile the gradient functions and avoid performing unnecessary computations at each evaluation (y_g = jit(y_g), and the same for y_gg).
We can then plot the functions $f, f', f''$ on $I_{10}$ after normalizing the $y$-axis (because $f''$ exhibits high variability).
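A possible way to produce such a plot, assuming y_g and y_gg have been jit-compiled as described and reusing the example grid and coefficients from the earlier sketches; dividing each curve by its maximum absolute value is just one arbitrary normalization choice:

fs   = vmap(lambda x: rational_function(x, alpha, beta))(xs)
dfs  = vmap(lambda x: y_g(x, alpha, beta))(xs)
ddfs = vmap(lambda x: y_gg(x, alpha, beta))(xs)

for label, vals in [("f", fs), ("f'", dfs), ("f''", ddfs)]:
    plt.plot(xs, vals / jnp.max(jnp.abs(vals)), label=label)  # per-curve normalization
plt.legend()
plt.show()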
Gradient and arguments
As mentioned earlier, calling grad(...) without specifying the argument w.r.t. which the gradient is computed automatically differentiates w.r.t. the first argument. To change that behavior, one should use the argnums parameter. For example, the snippet
num_par_grad = jit(grad(rational_function, argnums=1))
computes the gradient $\partial f /\partial \vec{\alpha}$. Since this gradient is, at each $x$, an array of shape $1\times 6$, it is hard to visualize all of its components on the full grid $I_{10}$. To simplify the problem, we set the degree $\deg(P) = 2$ and visualize $\nabla_{\vec{\alpha}}f$ on our grid. The result can be seen below:
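A sketch of this computation with a degree-2 numerator (the three coefficient values below are arbitrary placeholders):

alpha2 = jnp.array([1.0, -2.0, 0.5])                              # degree-2 P: 3 coefficients
grads_alpha = vmap(lambda x: num_par_grad(x, alpha2, beta))(xs)   # shape (len(xs), 3)

for k in range(grads_alpha.shape[1]):
    plt.plot(xs, grads_alpha[:, k], label=rf"$\partial f/\partial\alpha_{k}$")
plt.legend()
plt.show()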
The same simplifications and computations can be used to compute $\partial f /\partial \vec{\beta}$ (by setting argnums=2), which yields the plot below.
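The corresponding sketch for the denominator coefficients, again with an arbitrary degree-2 $Q$:

beta2 = jnp.array([0.5, 1.0, -0.4])                               # degree-2 Q: 3 coefficients
denom_par_grad = jit(grad(rational_function, argnums=2))
grads_beta = vmap(lambda x: denom_par_grad(x, alpha2, beta2))(xs) # shape (len(xs), 3)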
Contours
The above curves were the result of plotting the values of $\partial f /\partial \vec{\alpha}$ as $x$ varies over a uniform grid on $I_{10}$, having fixed both $\vec{\alpha}$ and $\vec{\beta}$. We can proceed differently, setting $x = x_0$ and varying the coefficients $\vec{\alpha}, \vec{\beta}$ on a uniform grid. (TO DO)
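One possible sketch of this idea, varying only the constant coefficients $\alpha_0$ and $\beta_0$ at a fixed $x_0$ while keeping the other coefficients of the degree-2 example fixed (all numeric choices below are arbitrary placeholders):

x0 = 1.0                                                          # fixed evaluation point
a0_grid, b0_grid = jnp.meshgrid(jnp.linspace(-3, 3, 200), jnp.linspace(-3, 3, 200))

def f_of_coeffs(a0, b0):
    # vary only the constant coefficients, keep the rest of alpha2 / beta2 fixed
    return rational_function(x0, alpha2.at[0].set(a0), beta2.at[0].set(b0))

zs = vmap(vmap(f_of_coeffs))(a0_grid, b0_grid)

plt.contourf(a0_grid, b0_grid, zs, levels=30)
plt.xlabel(r"$\alpha_0$")
plt.ylabel(r"$\beta_0$")
plt.colorbar()
plt.show()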
Tags
#autodiff #autograd #jax #python #code #math #applied #data #plots #gradient #hessian