JAX是一个具备自动微分功能且能够在CPU/GPU/TPU上高性能执行的NumPy。其中自动微分想法源自autograd,多设备高性能则得益于XLA。JAX的核心能力基于composable transformations,就是可以将函数进行转换。

JAX核心变换能力

JAX为了更方便转换定义了一种中间表示语言:jaxpr。jax会先将Python中的函数转换为jaxpr,然后对jaxpr进行转换,JAX中的所有基础函数都有自定义的转换规则。

JAX核心转换能力如下:

  • grad: 自动微分
  • jit: jit编译,使用XLA进行加速
  • vmap: 向量化
  • pmap: 并行化

grad

对指定函数进行变换,得到其一阶导数函数,示例如下:

from jax import numpy as jnp
from jax import grad

def f(x):
    return jnp.sum(x * x)

grad_f = grad(f)
grad_f(jnp.array([1, 2, 3.]))

jit

显示使用XLA对函数进行优化

import numpy as np
from jax import numpy as jnp
from jax import jit

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm)

vmap

vmap可以自动让函数支持batching,原始函数表示的是向量-向量乘法,使用vmap可以得到矩阵-向量乘法的函数:

from jax import numpy as jnp
from jax import vmap


def vec_vec_dot(x, y):
    """vector-vector dot, ([a], [a]) -> []
    """
    return jnp.dot(x, y)

x = jnp.array([1,1,2])
y = jnp.array([2,1,1,])
vec_vec_dot(x, y)
# DeviceArray(5, dtype=int32)

mat_vec = vmap(vec_vec_dot, in_axes=(0, None), out_axes=0)  # ([b,a], [a]) -> [b]      (b is the mapped axis)
xx = jnp.array([[1,1,2], [1,1,2]])
mat_vec(xx, y)
# DeviceArray([5, 5], dtype=int32)

pmap

多核并行,如下例子,使用vmap会在一个设备上做SIMD优化,结果是DeviceArray;pmap会在多个核心并行加速,结果是SharedDeviceArray。

import jax
from jax import numpy as jnp
from jax import pmap

jax.device_count()  # 8核心
x = jnp.arange(8)
y = jnp.arange(8)

vmap(jnp.add)(x, y)
# DeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

pmap(jnp.add)(x, y)
# ShardedDeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

变换组合

上述几种核心的变换可以组合使用,如下:

pmap(vamp(some_func))
jit(grad(grad(vmap(some_func))))

使用约束

上述变换只能针对纯函数pure function使用。纯函数的性质:

  1. 返回值只与传入参数相关;
  2. 不改变函数外状态;

如果想在函数内使用随机数函数又保持纯函数特性可以使用jax.numpy.random下的函数;JAX不做纯函数校验,需要用户自己保证。

JAX深度学习能力

直接使用JAX搭建深度学习网络会比较复杂。所以使用JAX进行深度学习网络搭建一般使用基于JAX的深度学习框架,目前主流使用的是Flax和Optax。其中JAX提供基础运算,类似传统的numpy, Flax提供网络结果,类似torch.nn, Optax提供优化器,类似torch.optim。然后配合使用各类转换进行加速和优化。