比OpenAI原始的Whisper快70倍的开源语音识别模型Whisper JAX发布!
Whisper是OpenAI在2022年9月份开源的自动语音识别模型。官方宣传其英语的识别水平与人类接近。而2个月后,官方就发布了Whisper V2版本,是第一个版本继续训练2.5倍得到,且加了正则化技术。而今天,一位网友Sanchit Gandhi发布了Whisper JAX,这是对原有版本的优化结果,识别速度最高达到原始模型的70倍!

Whisper模型卡信息:https://www.datalearner.com/ai-models/pretrained-models/Whisper%20JAX
Whisper JAX性能对比
从GitHub上提供的信息看,各个版本的Whisper的参数对比结果如下:
模型规模 | 参数数量 | 是否仅支持英文 | 多语言能力 |
---|---|---|---|
tiny | 3900万 | ✓ | ✓ |
base | 7400万 | ✓ | ✓ |
small | 2.44亿 | ✓ | ✓ |
medium | 7.69亿 | ✓ | ✓ |
large | 15.5 亿 | x | ✓ |
large-v2 | 15.5 亿 | x | ✓ |
Whisper是近年来OpenAI少有的开源模型,其效果很好,除了OpenAI官方的实现外,transformers也有相应的实现。作者做了不同版本的测试对比:
OpenAI | Transformers | Whisper JAX | Whisper JAX | |
---|---|---|---|---|
框架 | PyTorch | PyTorch | JAX | JAX |
使用的硬件 | GPU | GPU | GPU | TPU |
1 min | 13.8 | 4.54 | 1.72 | 0.45 |
10 min | 108.3 | 20.2 | 9.38 | 2.01 |
1 hour | 1001.0 | 126.1 | 75.3 | 13.8 |
上表中的1min、10min和1 hour分别代表不同时长的语音识别所需要的时间,单位是秒。以1个小时语音时长为例,OpenAI官方实现的PyTorch版本需要1001秒完成识别,大约是16-17分钟,transformers的实现需要126秒,而GPU版本的Whisper JAX只需要75.3秒,TPU运行的Whisper JAX只需要13.8秒!
不得不说JAX+TPU的Google全家桶性能是真的强!
Whisper JAX实际测试结果
Whisper JAX在Hugging Face上创建了一个Space,可以直接演示,我测试了中文,效果很好:

识别速度很快,效果非常棒。如果按照作者测试,GPU版本一分钟可以识别一个小时的语音,那么这方面的工作效率真的大大加强(为什么不说TPU?因为我们买不到!)
附:Google的JAX框架简介
Google\’s JAX是一个用于高性能数值计算和机器学习的Python库,它可以让用户使用类似NumPy的API进行计算,并且能够在GPU和TPU等加速设备上进行自动的优化和并行化处理。JAX提供了一些强大的功能,包括自动微分,高阶梯度计算和动态编译等,这使得它成为许多机器学习算法的理想实现平台。
与其他深度学习框架类似,JAX提供了一个自动微分系统,可以轻松地计算复杂函数的梯度,并且可以支持高阶导数的计算。此外,JAX还提供了一种称为jax.jit()的动态编译器,可以将Python函数转换为高效的机器码,以实现更快的执行速度和更低的内存占用。
JAX的一个独特之处在于它可以使用XLA(Accelerated Linear Algebra)库来自动并行化和优化计算,从而在GPU和TPU等加速设备上获得更好的性能。此外,JAX还提供了一些用于构建神经网络的高级API,例如jax.nn模块和jax.experimental.stax模块,以帮助用户更轻松地构建复杂的神经网络。
总的来说,JAX是一个非常强大的数值计算和机器学习库,可以帮助用户更轻松地构建高效的模型,并在加速设备上获得更好的性能。
欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
