Google JAX

Google JAX是一个用于高性能数值计算的机器学习Python库,拥有和NumPy一样易用的接口,但却支持GPU加速。

Python
Google
机器学习
初始发布时间:2018年
工具描述


Google JAX是谷歌内部开发的一个项目,它不是一个机器学习或者深度学习库,而是一个高性能的数值计算库。它的API接口是基于NumPy的,因此JAX本身简单、灵活且易于使用。

但是JAX本身支持基于GPU或者TPU的加速。因此,相比较NumPy,其速度要快很多。

除了可以完成类似NumPy的数值计算,JAX还包括一个可扩展的可组合函数转换系统,有助于机器学习的研究,包括:

  • 自动微分:基于梯度的优化是机器学习的基础。JAX原生支持通过函数转换对任意数值函数的正向和反向模式的自动微分,如grad, hessian, jacfwd 和 jacrev。
  • 矢量化:在机器学习研究中,我们经常将一个单一的函数应用于大量的数据,例如,计算整个批次的损失或评估每个实例的梯度,以实现差异化的私有学习。JAX通过vmap转换提供了自动矢量化,简化了这种形式的编程。例如,研究人员在实现新算法时不需要对批处理进行推理。JAX还通过相关的pmap转换支持大规模的数据并行化,优雅地分配那些对于单个加速器的内存来说太大的数据。
  • JIT-编译:XLA被用来在GPU和云TPU加速器上进行及时(JIT)编译和执行JAX程序。JIT编译,加上JAX的NumPy一致性API,使以前没有高性能计算经验的研究人员能够轻松地扩展到一个或多个加速器。

虽然JAX本身不是一个深度学习框架,但它肯定为深度学习的目的提供了一个更充分的基础。有许多建立在JAX之上的库,旨在建立深度学习能力,包括Flax、Haiku和Elegy。JAX对Hessians的高效计算也与深度学习有关,因为它们使高阶优化技术更加可行。

尽管JAX在逐渐成长,但它目前被运用于一系列项目中,比如除了深度学习之外,还有贝叶斯方法和机器人。DeepMind宣布了四个新的库,将加入他们的生态系统。Mctx提供AlphaZero和MuZero蒙特卡洛树搜索,KFAC-JAX是一个用于神经网络二阶优化和计算可扩展曲率近似值的库,DM_AUX是用于JAX的音频信号处理,提供频谱提取和SpecAug增强的工具,TF2JAX是一个用于将TensorFlow函数和图形转换为JAX函数的库。

总之,关于JAX的相关生态正在快速发展。甚至在某些报道中,它被当做是TensorFlow的接任者。


是否开源:

许可协议: Apache-2.0 license

官方地址: https://jax.readthedocs.io/en/latest/index.html

GitHub地址: https://github.com/google/jax

初始贡献者: Google内部人员

最佳实践指南

官方指南:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

Google-logo
pytorch-logo
推荐工具

TensorFlow - 深度学习

MindSpore - 深度学习