比OpenAI原始的Whisper快70倍的开源语音识别模型Whisper JAX发布!

标签:#ASR##JAX##Whisper##语音识别# 时间:2023/04/24 22:50:23 作者:小木

Whisper是OpenAI在2022年9月份开源的自动语音识别模型。官方宣传其英语的识别水平与人类接近。而2个月后,官方就发布了Whisper V2版本,是第一个版本继续训练2.5倍得到,且加了正则化技术。而今天,一位网友Sanchit Gandhi发布了Whisper JAX,这是对原有版本的优化结果,识别速度最高达到原始模型的70倍!


Whisper模型卡信息:https://www.datalearner.com/ai-models/pretrained-models/Whisper%20JAX

Whisper JAX性能对比

从GitHub上提供的信息看,各个版本的Whisper的参数对比结果如下:

模型规模 参数数量 是否仅支持英文 多语言能力
tiny 3900万
base 7400万
small 2.44亿
medium 7.69亿
large 15.5 亿 x
large-v2 15.5 亿 x

Whisper是近年来OpenAI少有的开源模型,其效果很好,除了OpenAI官方的实现外,transformers也有相应的实现。作者做了不同版本的测试对比:

OpenAI Transformers Whisper JAX Whisper JAX
框架 PyTorch PyTorch JAX JAX
使用的硬件 GPU GPU GPU TPU
1 min 13.8 4.54 1.72 0.45
10 min 108.3 20.2 9.38 2.01
1 hour 1001.0 126.1 75.3 13.8

上表中的1min、10min和1 hour分别代表不同时长的语音识别所需要的时间,单位是秒。以1个小时语音时长为例,OpenAI官方实现的PyTorch版本需要1001秒完成识别,大约是16-17分钟,transformers的实现需要126秒,而GPU版本的Whisper JAX只需要75.3秒,TPU运行的Whisper JAX只需要13.8秒!

不得不说JAX+TPU的Google全家桶性能是真的强!

Whisper JAX实际测试结果

Whisper JAX在Hugging Face上创建了一个Space,可以直接演示,我测试了中文,效果很好:


识别速度很快,效果非常棒。如果按照作者测试,GPU版本一分钟可以识别一个小时的语音,那么这方面的工作效率真的大大加强(为什么不说TPU?因为我们买不到!)

附:Google的JAX框架简介

Google\’s JAX是一个用于高性能数值计算和机器学习的Python库,它可以让用户使用类似NumPy的API进行计算,并且能够在GPU和TPU等加速设备上进行自动的优化和并行化处理。JAX提供了一些强大的功能,包括自动微分,高阶梯度计算和动态编译等,这使得它成为许多机器学习算法的理想实现平台。

与其他深度学习框架类似,JAX提供了一个自动微分系统,可以轻松地计算复杂函数的梯度,并且可以支持高阶导数的计算。此外,JAX还提供了一种称为jax.jit()的动态编译器,可以将Python函数转换为高效的机器码,以实现更快的执行速度和更低的内存占用。

JAX的一个独特之处在于它可以使用XLA(Accelerated Linear Algebra)库来自动并行化和优化计算,从而在GPU和TPU等加速设备上获得更好的性能。此外,JAX还提供了一些用于构建神经网络的高级API,例如jax.nn模块和jax.experimental.stax模块,以帮助用户更轻松地构建复杂的神经网络。

总的来说,JAX是一个非常强大的数值计算和机器学习库,可以帮助用户更轻松地构建高效的模型,并在加速设备上获得更好的性能。

欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
Back to Top