雾聪 2023-10-16 12:37:02 +08:00
commit d1c9f58401
89 changed files with 104243 additions and 7 deletions

View File

@ -28,6 +28,9 @@
<a name="whats-new"></a>
## What's new:
- 2023/10/13: [SlideSpeech](https://slidespeech.github.io/): A large-scale multi-modal audio-visual corpus with a significant amount of real-time synchronized slides.
- 2023/10/10: The combined ASR and speaker-diarization pipeline [speech_campplus_speaker-diarization_common](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/demo.py) has been released. Try it to get recognition results with speaker information.
- 2023/10/07: [FunCodec](https://github.com/alibaba-damo-academy/FunCodec): A Fundamental, Reproducible and Integrable Open-source Toolkit for Neural Speech Codec.
- 2023/09/01: The offline file transcription service 2.0 (CPU) of Mandarin has been released, with added support for ffmpeg, timestamp, and hotword models. For more details, please refer to ([Deployment documentation](funasr/runtime/docs/SDK_tutorial.md)).
- 2023/08/07: The real-time transcription service (CPU) of Mandarin has been released. For more details, please refer to ([Deployment documentation](funasr/runtime/docs/SDK_tutorial_online.md)).

View File

@ -31,8 +31,10 @@ FunASR aims to build a bridge between academic research and industrial applications of speech recognition
<a name="最新动态"></a>
## What's New
- 2023/10/13: [SlideSpeech](https://slidespeech.github.io/): a large-scale multi-modal audio-visual corpus, mainly covering online meetings and online lectures, containing a large amount of slide content synchronized in real time with the speakers' speech.
- 2023.10.10: The [Paraformer-long-Spk](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr_vad_spk/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/demo.py) model has been released; it adds per-sentence speaker labels on top of long-audio speech recognition.
- 2023.10.07: [FunCodec](https://github.com/alibaba-damo-academy/FunCodec): FunCodec provides open-source models and training tools for discrete audio coding, as well as for speech recognition, speech synthesis, and other tasks based on discrete codes.
- 2023.09.01: The offline file transcription service 2.0 (CPU) for Mandarin has been released, adding support for ffmpeg, timestamp, and hotword models. For details, see the ([one-click deployment documentation](funasr/runtime/docs/SDK_tutorial_zh.md))
- 2023.08.07: The one-click-deployable real-time speech dictation service (CPU) for Mandarin has been released. For details, see the ([one-click deployment documentation](funasr/runtime/docs/SDK_tutorial_online_zh.md))
- 2023.07.17: BAT, an RNN-T model with low latency and low memory consumption, has been released. For details, see [BAT](egs/aishell/bat)
- 2023.07.03: The one-click-deployable offline file transcription service (CPU) for Mandarin has been released. For details, see the ([one-click deployment documentation](funasr/runtime/docs/SDK_tutorial_zh.md))

View File

@ -17,7 +17,8 @@ Here we provide several pretrained models on different datasets. The details of
| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Handles input wav of arbitrary length |
| [Paraformer-large-Spk](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Supports speaker diarization for ASR results, based on Paraformer-large-long |
| [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Supports hotword customization based on incentive enhancement, improving hotword recall and precision |
| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000 hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
| [Paraformer-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary) | CN & EN | Alibaba Speech Data (50000 hours) | 8404 | 68M | Online | Handles streaming input |
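
Any model in the table can be loaded through the ModelScope pipeline in the same way as the demos further down in this commit. Below is a minimal sketch for the long-audio model; the local wav path is a placeholder, and the exact `model_revision` you need may depend on the ModelScope release you have installed.

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Offline ASR pipeline using the long-audio Paraformer model from the table above
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
)

# Recognize a wav file of arbitrary length (placeholder path)
rec_result = inference_pipeline(audio_in='example.wav')
print(rec_result)
```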

View File

@ -17,7 +17,8 @@
| Model Name | Language | Training Data | Vocab Size | Parameters | Offline/Online | Notes |
|:---:|:---:|:---:|:---:|:---:|:---:|:---|
| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Duration of input wav <= 20 s |
| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Handles input wav of arbitrary length |
| [Paraformer-large-Spk](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Adds speaker diarization on top of the long-audio pipeline |
| [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Supports hotword customization based on incentive enhancement, improving hotword recall and precision; duration of input wav <= 20 s |
| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000 hours) | 8358 | 68M | Offline | Duration of input wav <= 20 s |
| [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary) | CN & EN | Alibaba Speech Data (50000 hours) | 8404 | 68M | Online | Handles streaming input |

View File

@ -99,6 +99,28 @@ print(rec_result)
```
The `fast` and `normal` decoding modes are fake streaming, which can be used to evaluate recognition accuracy.
For the full demo code, please refer to [demo](https://github.com/alibaba-damo-academy/FunASR/discussions/151)
#### [Paraformer-Spk](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/summary)
This model lets users obtain recognition results that include speaker information for each sentence. Refer to [CAM++](https://modelscope.cn/models/damo/speech_campplus_speaker-diarization_common/summary) for details about the speaker diarization model.
```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
if __name__ == '__main__':
    audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav'
    output_dir = "./results"
    # ASR pipeline with VAD, punctuation and speaker diarization
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn',
        model_revision='v0.0.2',
        vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
        punc_model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
        output_dir=output_dir,
    )
    rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000, batch_size_token_threshold_s=40, max_single_segment_time=6000)
    print(rec_result)
```
#### [RNN-T-online model]()
To be added.

View File

@ -100,6 +100,29 @@ print(rec_result)
The `fast` and `normal` decoding modes are fake streaming and can be used to evaluate recognition accuracy.
For the full demo code, see [demo](https://github.com/alibaba-damo-academy/FunASR/discussions/151)
#### [Paraformer-Spk model](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/summary)
Returns the speaker label for each sentence along with the recognition results. For details about the speaker diarization model, see [CAM++](https://modelscope.cn/models/damo/speech_campplus_speaker-diarization_common/summary).
```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
if __name__ == '__main__':
    audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav'
    output_dir = "./results"
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn',
        model_revision='v0.0.2',
        vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
        punc_model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
        output_dir=output_dir,
    )
    rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000, batch_size_token_threshold_s=40, max_single_segment_time=6000)
    print(rec_result)
```
#### [RNN-T-online model]()
To be added.

View File

@ -0,0 +1 @@
../asr/TEMPLATE

View File

@ -38,7 +38,9 @@ from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
from funasr.utils.whisper_utils.transcribe import transcribe
from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
class Speech2Text:
"""Speech2Text class
@ -1880,3 +1882,117 @@ class Speech2TextSAASR:
results.append((text, text_id, token, token_int, hyp))
return results
class Speech2TextWhisper:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
frontend_conf: dict = None,
**kwargs,
):
# 1. Build ASR model
scorers = {}
from funasr.tasks.whisper import ASRTask
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
token_list = []
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.tokenizer = tokenizer
self.device = device
self.dtype = dtype
self.frontend = frontend
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
speech = speech[0]
speech = pad_or_trim(speech)
mel = log_mel_spectrogram(speech).to(self.device)
if self.asr_model.is_multilingual:
options = DecodingOptions(fp16=False)
asr_res = decode(self.asr_model, mel, options)
text = asr_res.text
language = asr_res.language
else:
asr_res = transcribe(self.asr_model, speech, fp16=False)
text = asr_res["text"]
language = asr_res["language"]
results = [(text, language)]
return results
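
As a rough illustration of how this class would be driven, the sketch below builds `Speech2TextWhisper` and decodes a single waveform; the config and model file names are hypothetical placeholders that would in practice point at a downloaded Whisper checkpoint.

```python
import soundfile
import torch

# Hypothetical config/model paths for a Whisper checkpoint
speech2text = Speech2TextWhisper(
    asr_train_config="whisper_config.yaml",
    asr_model_file="whisper_model.pb",
    device="cpu",
)

# __call__ expects a batched waveform and returns [(text, language)]
audio, rate = soundfile.read("speech.wav", dtype="float32")
results = speech2text(torch.from_numpy(audio).unsqueeze(0))
print(results)
```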

View File

@ -29,6 +29,7 @@ from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnl
from funasr.bin.asr_infer import Speech2TextSAASR
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.asr_infer import Speech2TextWhisper
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
@ -2020,6 +2021,161 @@ def inference_sa_asr(
return _forward
def inference_whisper(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
param_dict: dict = None,
**kwargs,
):
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2TextWhisper(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = build_streaming_iterator(
task_name="asr",
preprocess_args=speech2text.asr_train_args,
data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
)
finish_count = 0
file_count = 1
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, language) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["language"][key] = language
if text is not None:
item = {'key': key, 'value': text}
asr_result_list.append(item)
finish_count += 1
if writer is not None:
ibest_writer["text"][key] = text
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}\n".format(text))
return asr_result_list
return _forward
def inference_launch(**kwargs):
if 'mode' in kwargs:
@ -2049,6 +2205,8 @@ def inference_launch(**kwargs):
return inference_transducer(**kwargs)
elif mode == "sa_asr":
return inference_sa_asr(**kwargs)
elif mode == "whisper":
return inference_whisper(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
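
Below is a hedged sketch of driving the new `whisper` mode through `inference_launch`; the config, model, and wav-list paths are placeholders, and the numeric arguments are typical defaults rather than values taken from this commit.

```python
# Build the decoding closure for the "whisper" mode (placeholder paths)
forward = inference_launch(
    mode="whisper",
    maxlenratio=0.0,
    minlenratio=0.0,
    batch_size=1,
    beam_size=1,
    ngpu=0,
    ctc_weight=0.0,
    lm_weight=0.0,
    penalty=0.0,
    log_level="INFO",
    asr_train_config="exp/whisper/config.yaml",
    asr_model_file="exp/whisper/model.pb",
    output_dir="./whisper_results",
)

# Decode the utterances listed in a Kaldi-style wav.scp
results = forward(data_path_and_name_and_type=[("wav.scp", "speech", "sound")])
print(results)  # -> [{'key': ..., 'value': decoded text}, ...]
```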

View File

@ -0,0 +1,269 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from funasr.models.base_model import FunASRModel
from funasr.utils.whisper_utils.decoding import detect_language as detect_language_function, decode as decode_function
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
return logits
class Whisper(FunASRModel):
def __init__(self, dims: dict):
super().__init__()
dims = ModelDimensions(**dims)
self.dims = dims
self.sos = 1
self.eos = 1
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
detect_language = detect_language_function
decode = decode_function
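
For context, the sketch below shows how these KV-cache hooks are typically consumed during greedy, token-by-token decoding; `model`, `tokens`, `audio_features`, and `max_new_tokens` are illustrative names, not part of this commit.

```python
import torch

# Install the hooks once per utterance, decode incrementally, then remove them
cache, hooks = model.install_kv_cache_hooks()
initial_len = tokens.shape[-1]
try:
    for _ in range(max_new_tokens):
        # After the first step, only the newest token is fed;
        # earlier keys/values are served from the cache.
        inp = tokens if tokens.shape[-1] == initial_len else tokens[:, -1:]
        logits = model.decoder(inp, audio_features, kv_cache=cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
finally:
    for hook in hooks:
        hook.remove()  # stop caching once decoding is finished
```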

View File

@ -0,0 +1,71 @@
plugins {
id 'com.android.application'
id 'org.jetbrains.kotlin.android'
}
android {
namespace 'com.yeyupiaoling.androidclient'
compileSdk 33
defaultConfig {
applicationId "com.yeyupiaoling.androidclient"
minSdk 24
targetSdk 33
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
vectorDrawables {
useSupportLibrary true
}
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = '1.8'
}
buildFeatures {
compose true
}
composeOptions {
kotlinCompilerExtensionVersion '1.4.3'
}
packaging {
resources {
excludes += '/META-INF/{AL2.0,LGPL2.1}'
}
}
}
dependencies {
implementation 'androidx.core:core-ktx:1.9.0'
implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1'
implementation 'androidx.activity:activity-compose:1.7.0'
implementation platform('androidx.compose:compose-bom:2023.03.00')
implementation 'androidx.compose.ui:ui'
implementation 'androidx.compose.ui:ui-graphics'
implementation 'androidx.compose.ui:ui-tooling-preview'
implementation 'androidx.compose.material3:material3'
implementation 'androidx.appcompat:appcompat:1.6.1'
implementation 'com.google.android.material:material:1.8.0'
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
androidTestImplementation platform('androidx.compose:compose-bom:2023.03.00')
androidTestImplementation 'androidx.compose.ui:ui-test-junit4'
debugImplementation 'androidx.compose.ui:ui-tooling'
debugImplementation 'androidx.compose.ui:ui-test-manifest'
implementation 'com.squareup.okhttp3:okhttp:4.9.1'
}

View File

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@ -0,0 +1,24 @@
package com.yeyupiaoling.androidclient
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("com.yeyupiaoling.androidclient", appContext.packageName)
}
}

View File

@ -0,0 +1,30 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.AndroidClient"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@ -0,0 +1,216 @@
package com.yeyupiaoling.androidclient;
import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.graphics.Point;
import android.util.AttributeSet;
import android.view.View;
import androidx.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class AudioView extends View {
// Number of spectrum bars
private static final int LUMP_COUNT = 128;
private static final int LUMP_WIDTH = 6;
private static final int LUMP_SPACE = 2;
private static final int LUMP_MIN_HEIGHT = LUMP_WIDTH;
private static final int LUMP_MAX_HEIGHT = 200;//TODO: HEIGHT
private static final int LUMP_SIZE = LUMP_WIDTH + LUMP_SPACE;
private static final int LUMP_COLOR = Color.parseColor("#6de8fd");
private static final int WAVE_SAMPLING_INTERVAL = 3;
private static final float SCALE = LUMP_MAX_HEIGHT / LUMP_COUNT;
private ShowStyle upShowStyle = ShowStyle.STYLE_HOLLOW_LUMP;
private ShowStyle downShowStyle = ShowStyle.STYLE_WAVE;
private byte[] waveData;
List<Point> pointList;
private Paint lumpPaint;
Path wavePath = new Path();
public AudioView(Context context) {
super(context);
init();
}
public AudioView(Context context, @Nullable AttributeSet attrs) {
super(context, attrs);
init();
}
public AudioView(Context context, @Nullable AttributeSet attrs, int defStyleAttr) {
super(context, attrs, defStyleAttr);
init();
}
private void init() {
lumpPaint = new Paint();
lumpPaint.setAntiAlias(true);
lumpPaint.setColor(LUMP_COLOR);
lumpPaint.setStrokeWidth(2);
lumpPaint.setStyle(Paint.Style.STROKE);
}
public void setWaveData(byte[] data) {
this.waveData = readyData(data);
genSamplingPoint(data);
invalidate();
}
public void setStyle(ShowStyle upShowStyle, ShowStyle downShowStyle) {
this.upShowStyle = upShowStyle;
this.downShowStyle = downShowStyle;
}
@Override
protected void onDraw(Canvas canvas) {
super.onDraw(canvas);
wavePath.reset();
for (int i = 0; i < LUMP_COUNT; i++) {
if (waveData == null) {
canvas.drawRect((LUMP_WIDTH + LUMP_SPACE) * i,
LUMP_MAX_HEIGHT - LUMP_MIN_HEIGHT,
(LUMP_WIDTH + LUMP_SPACE) * i + LUMP_WIDTH,
LUMP_MAX_HEIGHT,
lumpPaint);
continue;
}
switch (upShowStyle) {
case STYLE_HOLLOW_LUMP:
drawLump(canvas, i, false);
break;
case STYLE_WAVE:
drawWave(canvas, i, false);
break;
default:
break;
}
switch (downShowStyle) {
case STYLE_HOLLOW_LUMP:
drawLump(canvas, i, true);
break;
case STYLE_WAVE:
drawWave(canvas, i, true);
break;
default:
break;
}
}
}
/**
* Pre-process the FFT data into displayable magnitudes
*
* @return the processed byte array
*/
private static byte[] readyData(byte[] fft) {
byte[] newData = new byte[LUMP_COUNT];
byte abs;
for (int i = 0; i < LUMP_COUNT; i++) {
abs = (byte) Math.abs(fft[i]);
// Note: Math.abs overflows for the byte value -128
newData[i] = abs < 0 ? 127 : abs;
}
return newData;
}
/**
* Draw the smoothed waveform curve
*
* @param canvas
* @param i
* @param reversal
*/
private void drawWave(Canvas canvas, int i, boolean reversal) {
if (pointList == null || pointList.size() < 2) {
return;
}
float ratio = SCALE * (reversal ? -1 : 1);
if (i < pointList.size() - 2) {
Point point = pointList.get(i);
Point nextPoint = pointList.get(i + 1);
int midX = (point.x + nextPoint.x) >> 1;
if (i == 0) {
wavePath.moveTo(point.x, LUMP_MAX_HEIGHT - point.y * ratio);
}
wavePath.cubicTo(midX, LUMP_MAX_HEIGHT - point.y * ratio,
midX, LUMP_MAX_HEIGHT - nextPoint.y * ratio,
nextPoint.x, LUMP_MAX_HEIGHT - nextPoint.y * ratio);
canvas.drawPath(wavePath, lumpPaint);
}
}
/**
* Draw the rectangular bars
*/
private void drawLump(Canvas canvas, int i, boolean reversal) {
int minus = reversal ? -1 : 1;
float top = (LUMP_MAX_HEIGHT - (LUMP_MIN_HEIGHT + waveData[i] * SCALE) * minus);
canvas.drawRect(LUMP_SIZE * i,
top,
LUMP_SIZE * i + LUMP_WIDTH,
LUMP_MAX_HEIGHT,
lumpPaint);
}
/**
* Generate sampling points for the waveform to reduce the drawing workload
*
* @param data
*/
private void genSamplingPoint(byte[] data) {
if (upShowStyle != ShowStyle.STYLE_WAVE && downShowStyle != ShowStyle.STYLE_WAVE) {
return;
}
if (pointList == null) {
pointList = new ArrayList<>();
} else {
pointList.clear();
}
pointList.add(new Point(0, 0));
for (int i = WAVE_SAMPLING_INTERVAL; i < LUMP_COUNT; i += WAVE_SAMPLING_INTERVAL) {
pointList.add(new Point(LUMP_SIZE * i, waveData[i]));
}
pointList.add(new Point(LUMP_SIZE * LUMP_COUNT, 0));
}
/**
* Visualization styles
*/
public enum ShowStyle {
/**
* Hollow rectangular bars
*/
STYLE_HOLLOW_LUMP,
/**
* Smooth curve
*/
STYLE_WAVE,
/**
* Nothing is drawn
*/
STYLE_NOTHING
}
}

View File

@ -0,0 +1,248 @@
package com.yeyupiaoling.androidclient;
import android.Manifest;
import android.annotation.SuppressLint;
import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Bundle;
import android.util.Log;
import android.view.MotionEvent;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import javax.net.ssl.HostnameVerifier;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;
public class MainActivity extends AppCompatActivity {
public static final String TAG = MainActivity.class.getSimpleName();
// WebSocket address; if the server does not use SSL, use ws:// instead
public static final String ASR_HOST = "wss://192.168.0.1:10095";
// Sample rate
public static final int SAMPLE_RATE = 16000;
// Channel configuration
public static final int CHANNEL = AudioFormat.CHANNEL_IN_MONO;
// Format of the captured audio data
public static final int AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT;
private AudioRecord audioRecord;
private boolean isRecording = false;
private int minBufferSize;
private AudioView audioView;
private String allAsrText = "";
private String asrText = "";
// UI controls
private Button recordBtn;
private TextView resultText;
private WebSocket webSocket;
@SuppressLint("ClickableViewAccessibility")
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// Request permissions
if (!hasPermission()) {
requestPermission();
}
// Recording parameters
minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL, AUDIO_FORMAT);
// View that shows the recognition result
resultText = findViewById(R.id.result_text);
// View that visualizes the recording state
audioView = findViewById(R.id.audioView);
audioView.setStyle(AudioView.ShowStyle.STYLE_HOLLOW_LUMP, AudioView.ShowStyle.STYLE_NOTHING);
// Handle presses on the record button
recordBtn = findViewById(R.id.record_button);
recordBtn.setOnTouchListener((v, event) -> {
if (event.getAction() == MotionEvent.ACTION_UP) {
isRecording = false;
stopRecording();
recordBtn.setText("按下录音");
} else if (event.getAction() == MotionEvent.ACTION_DOWN) {
if (webSocket != null){
webSocket.cancel();
webSocket = null;
}
allAsrText = "";
asrText = "";
isRecording = true;
startRecording();
recordBtn.setText("录音中...");
}
return true;
});
}
// Start recording
private void startRecording() {
// Prepare the recorder
try {
// Make sure the permission has been granted
if (ActivityCompat.checkSelfPermission(this, android.Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
requestPermission();
return;
}
// Create the recorder
audioRecord = new AudioRecord(MediaRecorder.AudioSource.MIC, SAMPLE_RATE, CHANNEL, AUDIO_FORMAT, minBufferSize);
} catch (IllegalStateException e) {
e.printStackTrace();
}
// Start a thread that reads the recorded audio and streams it to the server
Thread recordingAudioThread = new Thread(() -> {
try {
setAudioData();
} catch (Exception e) {
e.printStackTrace();
}
});
recordingAudioThread.start();
// Start the recorder
audioRecord.startRecording();
audioView.setVisibility(View.VISIBLE);
}
// Stop the recorder
private void stopRecording() {
audioRecord.stop();
audioRecord.release();
audioRecord = null;
audioView.setVisibility(View.GONE);
}
// Read the recorded audio and stream it over WebSocket
private void setAudioData() throws Exception {
// If the server uses a properly signed wss certificate, this can be removed
HostnameVerifier hostnameVerifier = (hostname, session) -> {
// Always return true, i.e. skip hostname verification
return true;
};
// Establish the WebSocket connection
OkHttpClient client = new OkHttpClient.Builder()
.hostnameVerifier(hostnameVerifier)
.build();
Request request = new Request.Builder()
.url(ASR_HOST)
.build();
webSocket = client.newWebSocket(request, new WebSocketListener() {
@Override
public void onOpen(@NonNull WebSocket webSocket, @NonNull Response response) {
// Called when the connection is established
Log.d(TAG, "WebSocket连接成功");
}
@Override
public void onMessage(@NonNull WebSocket webSocket, @NonNull String text) {
// Called when a message is received
Log.d(TAG, "WebSocket接收到消息: " + text);
try {
JSONObject jsonObject = new JSONObject(text);
String t = jsonObject.getString("text");
boolean isFinal = jsonObject.getBoolean("is_final");
if (!t.equals("")) {
// Concatenate the recognition results
String mode = jsonObject.getString("mode");
if (mode.equals("2pass-offline")) {
asrText = "";
allAsrText = allAsrText + t;
// Logic to automatically stop recording/recognition could be added here
} else {
asrText = asrText + t;
}
}
// Display the recognition result
if (!(allAsrText + asrText).equals("")) {
runOnUiThread(() -> resultText.setText(allAsrText + asrText));
}
// Close the WebSocket connection once the final result has been received
if (isFinal) {
webSocket.close(1000, "关闭WebSocket连接");
}
} catch (JSONException e) {
throw new RuntimeException(e);
}
}
@Override
public void onClosing(@NonNull WebSocket webSocket, int code, @NonNull String reason) {
// Called when the connection is closing
Log.d(TAG, "WebSocket关闭连接: " + reason);
}
@Override
public void onFailure(@NonNull WebSocket webSocket, @NonNull Throwable t, Response response) {
// Called when the connection fails
Log.d(TAG, "WebSocket连接失败: " + t + ": " + response);
}
});
String message = getMessage("2pass", "5, 10, 5", 10, true);
webSocket.send(message);
audioRecord.startRecording();
byte[] bytes = new byte[minBufferSize];
while (isRecording) {
int readSize = audioRecord.read(bytes, 0, minBufferSize);
if (readSize > 0) {
ByteString byteString = ByteString.of(bytes);
webSocket.send(byteString);
audioView.post(() -> audioView.setWaveData(bytes));
}
}
JSONObject obj = new JSONObject();
obj.put("is_speaking", false);
webSocket.send(obj.toString());
// webSocket.close(1000, "关闭WebSocket连接");
}
// Build the initial JSON handshake message
public String getMessage(String mode, String strChunkSize, int chunkInterval, boolean isSpeaking) {
try {
JSONObject obj = new JSONObject();
obj.put("mode", mode);
JSONArray array = new JSONArray();
String[] chunkList = strChunkSize.split(",");
for (String s : chunkList) {
array.put(Integer.valueOf(s.trim()));
}
obj.put("chunk_size", array);
obj.put("chunk_interval", chunkInterval);
obj.put("wav_name", "default");
// Hotwords
obj.put("hotwords", "阿里巴巴 达摩院");
obj.put("wav_format", "pcm");
obj.put("is_speaking", isSpeaking);
return obj.toString();
} catch (Exception e) {
e.printStackTrace();
}
return "";
}
// Check permissions
private boolean hasPermission() {
return checkSelfPermission(android.Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED &&
checkSelfPermission(android.Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED;
}
// Request permissions
private void requestPermission() {
requestPermissions(new String[]{android.Manifest.permission.RECORD_AUDIO,
Manifest.permission.WRITE_EXTERNAL_STORAGE}, 1);
}
}
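
For reference, the streaming protocol this Android client implements — an initial JSON handshake, raw PCM chunks, then an `is_speaking: false` message — can also be exercised from Python. The sketch below uses the third-party `websockets` and `soundfile` packages; the server address and wav path are placeholders, and TLS handling for a self-signed certificate is left out.

```python
import asyncio
import json

import soundfile
import websockets  # pip install websockets


async def stream_wav(uri="wss://192.168.0.1:10095", wav_path="speech.wav"):
    async with websockets.connect(uri) as ws:
        # 1) Initial handshake, mirroring getMessage() in the Android client
        await ws.send(json.dumps({
            "mode": "2pass",
            "chunk_size": [5, 10, 5],
            "chunk_interval": 10,
            "wav_name": "default",
            "wav_format": "pcm",
            "is_speaking": True,
        }))
        # 2) Stream 16 kHz, 16-bit mono PCM in small chunks
        audio, _ = soundfile.read(wav_path, dtype="int16")
        step = 960  # 60 ms at 16 kHz
        for i in range(0, len(audio), step):
            await ws.send(audio[i:i + step].tobytes())
        # 3) Tell the server the utterance has ended
        await ws.send(json.dumps({"is_speaking": False}))
        # 4) Print recognition messages until the final result arrives
        async for message in ws:
            result = json.loads(message)
            print(result.get("mode"), result.get("text"))
            if result.get("is_final"):
                break


asyncio.run(stream_wav())
```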

View File

@ -0,0 +1,170 @@
<?xml version="1.0" encoding="utf-8"?>
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillColor="#3DDC84"
android:pathData="M0,0h108v108h-108z" />
<path
android:fillColor="#00000000"
android:pathData="M9,0L9,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,0L19,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,0L29,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,0L39,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,0L49,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,0L59,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,0L69,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,0L79,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M89,0L89,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M99,0L99,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,9L108,9"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,19L108,19"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,29L108,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,39L108,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,49L108,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,59L108,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,69L108,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,79L108,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,89L108,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,99L108,99"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,29L89,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,39L89,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,49L89,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,59L89,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,69L89,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,79L89,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,19L29,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,19L39,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,19L49,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,19L59,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,19L69,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,19L79,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
</vector>

View File

@ -0,0 +1,30 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
<aapt:attr name="android:fillColor">
<gradient
android:endX="85.84757"
android:endY="92.4963"
android:startX="42.9492"
android:startY="49.59793"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
android:strokeWidth="1"
android:strokeColor="#00000000" />
</vector>

View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<Button
android:id="@+id/record_button"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:layout_marginLeft="10dp"
android:layout_marginRight="10dp"
android:layout_marginBottom="10dp"
android:text="按下录音" />
<com.yeyupiaoling.androidclient.AudioView
android:id="@+id/audioView"
android:layout_width="match_parent"
android:layout_height="100dp"
android:layout_above="@id/record_button"
android:layout_marginStart="10dp"
android:visibility="gone" />
<TextView
android:id="@+id/result_text"
android:layout_above="@id/record_button"
android:layout_width="match_parent"
android:hint="显示识别结果"
android:textSize="22sp"
android:layout_height="match_parent"/>
</RelativeLayout>

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>

Binary files not shown (new launcher icon images; sizes: 1.4 KiB, 2.8 KiB, 982 B, 1.7 KiB, 1.9 KiB, 3.8 KiB, 2.8 KiB, 5.8 KiB, 3.8 KiB, 7.6 KiB).

View File

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="purple_200">#FFBB86FC</color>
<color name="purple_500">#FF6200EE</color>
<color name="purple_700">#FF3700B3</color>
<color name="teal_200">#FF03DAC5</color>
<color name="teal_700">#FF018786</color>
<color name="black">#FF000000</color>
<color name="white">#FFFFFFFF</color>
</resources>

View File

@ -0,0 +1,3 @@
<resources>
<string name="app_name">FunASR</string>
</resources>

View File

@ -0,0 +1,16 @@
<resources xmlns:tools="http://schemas.android.com/tools">
<!-- Base application theme. -->
<style name="Theme.AndroidClient" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
<!-- Primary brand color. -->
<item name="colorPrimary">@color/purple_500</item>
<item name="colorPrimaryVariant">@color/purple_700</item>
<item name="colorOnPrimary">@color/white</item>
<!-- Secondary brand color. -->
<item name="colorSecondary">@color/teal_200</item>
<item name="colorSecondaryVariant">@color/teal_700</item>
<item name="colorOnSecondary">@color/black</item>
<!-- Status bar color. -->
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
<!-- Customize your theme here. -->
</style>
</resources>

View File

@ -0,0 +1,13 @@
<?xml version="1.0" encoding="utf-8"?><!--
Sample backup rules file; uncomment and customize as necessary.
See https://developer.android.com/guide/topics/data/autobackup
for details.
Note: This file is ignored for devices older than API 31
See https://developer.android.com/about/versions/12/backup-restore
-->
<full-backup-content>
<!--
<include domain="sharedpref" path="."/>
<exclude domain="sharedpref" path="device.xml"/>
-->
</full-backup-content>

View File

@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8"?><!--
Sample data extraction rules file; uncomment and customize as necessary.
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
for details.
-->
<data-extraction-rules>
<cloud-backup>
<!--
<include .../>
<exclude .../>
-->
</cloud-backup>
<!--
<device-transfer>
<include .../>
<exclude .../>
</device-transfer>
-->
</data-extraction-rules>

View File

@ -0,0 +1,17 @@
package com.yeyupiaoling.androidclient
import org.junit.Test
import org.junit.Assert.*
/**
* Example local unit test, which will execute on the development machine (host).
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
class ExampleUnitTest {
@Test
fun addition_isCorrect() {
assertEquals(4, 2 + 2)
}
}

View File

@ -0,0 +1,5 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
id 'com.android.application' version '8.1.2' apply false
id 'org.jetbrains.kotlin.android' version '1.8.10' apply false
}

View File

@ -0,0 +1,23 @@
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Kotlin code style for this project: "official" or "obsolete":
kotlin.code.style=official
# Enables namespacing of each library's R class so that its R class includes only the
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true

View File

@ -0,0 +1,6 @@
#Fri Oct 13 14:55:29 CST 2023
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -0,0 +1,185 @@
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"

View File

@ -0,0 +1,89 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

View File

@ -0,0 +1,17 @@
pluginManagement {
repositories {
google()
mavenCentral()
gradlePluginPortal()
}
}
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
}
}
rootProject.name = "AndroidClient"
include ':app'

Binary file not shown (new image, 44 KiB)

View File

@ -0,0 +1,13 @@
# AndroidClient
Note first that this project uses a WebSocket connection to the server-side speech recognition service; it does not deploy FunASR on Android itself. For how to start the service, see [SDK_advanced_guide_online_zh.md](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/docs/SDK_advanced_guide_online_zh.md).
Open the `AndroidClient` project with the latest Android Studio and run it. Before running, you also need to change the `ASR_HOST` parameter, which is the WebSocket endpoint of the speech recognition service, to your own service address.
The app has a single function: press and hold the button to start recognition, and release it to stop.
App screenshot:
<div align="center">
<img src="./images/demo.png" alt="App screenshot" width="300">
</div>
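The snippet below is a hedged, Python-side sketch for checking that the `ASR_HOST` WebSocket endpoint is reachable before pointing the app at it. The address is a placeholder, and the actual audio/message protocol is documented in SDK_advanced_guide_online_zh.md rather than reproduced here.
```python
# Minimal reachability check for the ASR_HOST endpoint (placeholder address).
# The real streaming protocol is described in SDK_advanced_guide_online_zh.md;
# this sketch only verifies that the WebSocket service accepts connections.
import asyncio

import websockets  # pip install websockets

ASR_HOST = "ws://127.0.0.1:10095"  # placeholder: replace with your deployed service address


async def check_endpoint() -> None:
    async with websockets.connect(ASR_HOST):
        print("connected to", ASR_HOST)


asyncio.run(check_endpoint())
```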

View File

@ -149,3 +149,34 @@ Note: '--quantize false' means fp32, otherwise it will be int8
| 64 (onnx int8) | 81s | 0.002232 | 448 |
| 96 (onnx fp32) | 117s | 0.003257 | 307 |
| 96 (onnx int8) | 81s | 0.002258 | 442 |
## [FSMN-VAD](https://www.modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) + [Paraformer-en](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-en-16k-common-vocab10020-onnx/summary) + [CT-Transformer](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)
```shell
./funasr-onnx-offline-rtf \
--model-dir ./asrmodel/speech_paraformer-large_asr_nat-en-16k-common-vocab10020-onnx \
--quantize true \
--vad-dir ./asrmodel/speech_fsmn_vad_zh-cn-16k-common-pytorch \
--punc-dir ./asrmodel/punc_ct-transformer_zh-cn-common-vocab272727-pytorch \
--wav-path ./librispeech_test_clean.scp \
--thread-num 32
Note: '--quantize false' means fp32, otherwise it will be int8
```
### Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz 16core-32processor with avx512_vnni
| concurrent-tasks | processing time(s) | RTF | Speedup Rate |
|---------------------|:------------------:|----------|:------------:|
| 1 (onnx fp32) | 1327s | 0.0682 | 15 |
| 1 (onnx int8) | 734s | 0.0377 | 26 |
| 8 (onnx fp32) | 169s | 0.0087 | 114 |
| 8 (onnx int8) | 94s | 0.0048 | 205 |
| 16 (onnx fp32) | 89s | 0.0046 | 217 |
| 16 (onnx int8) | 50s | 0.0025 | 388 |
| 32 (onnx fp32) | 78s | 0.0040 | 248 |
| 32 (onnx int8) | 43s | 0.0022 | 448 |
| 64 (onnx fp32) | 79s | 0.0041 | 243 |
| 64 (onnx int8) | 44s | 0.0022 | 438 |
| 96 (onnx fp32) | 80s | 0.0041 | 240 |
| 96 (onnx int8) | 45s | 0.0023 | 428 |
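For readers reproducing these tables, RTF and speedup follow the usual definitions: processing time divided by total audio duration, and its inverse. The sketch below recomputes the first row; the roughly 5.4-hour total duration of librispeech test-clean is an assumption used for illustration and is not printed by the tool itself.
```python
# Hedged sketch: how the RTF and Speedup Rate columns above are typically derived.
def rtf_and_speedup(processing_seconds: float, total_audio_seconds: float):
    rtf = processing_seconds / total_audio_seconds
    speedup = total_audio_seconds / processing_seconds  # equivalently 1 / rtf
    return rtf, speedup

# First row of the table (1 concurrent task, onnx fp32), assuming ~5.4 h of audio.
rtf, speedup = rtf_and_speedup(1327.0, 5.4 * 3600)
print(f"RTF={rtf:.4f}, speedup={speedup:.0f}x")  # ~0.068, ~15x
```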

View File

@ -1,13 +1,17 @@
# ONNXRuntime-python
- ## Install `funasr_onnx`
+ ## Install `funasr-onnx`
install from pip
```shell
- pip install -U funasr_onnx
+ pip install -U funasr-onnx
# For the users in China, you could install with the command:
- # pip install -U funasr_onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
+ # pip install -U funasr-onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
+ # If you want to export .onnx file, you should install modelscope and funasr
+ pip install -U modelscope funasr
+ # For the users in China, you could install with the command:
+ # pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
or install from source code
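After installation, a minimal usage sketch (hedged: it assumes the `Paraformer` class exported by the `funasr_onnx` module and an already-exported local ONNX model directory; check the package documentation for the exact constructor arguments):
```python
# Hedged sketch: offline recognition with the funasr_onnx runtime.
# The model directory is a placeholder for a locally exported ONNX model,
# and constructor arguments may differ slightly across package versions.
from funasr_onnx import Paraformer

model_dir = "./export/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"  # placeholder
model = Paraformer(model_dir, batch_size=1, quantize=True)

result = model(["./example.wav"])  # list of wav paths
print(result)
```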

funasr/tasks/whisper.py Normal file
View File

@ -0,0 +1,675 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
LightweightConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
LightweightConvolutionTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_asr_bat import BATModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.models.whisper_models.model import Whisper, AudioEncoder, TextDecoder
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
multichannelfrontend=MultiChannelFrontend,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
specaug_lfr=SpecAugLFR,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
model_choices = ClassChoices(
"model",
classes=dict(
asr=ASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_online=ParaformerOnline,
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
neatcontextual_paraformer=NeatContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
bat=BATModel,
sa_asr=SAASRModel,
whisper=Whisper,
),
type_check=FunASRModel,
default="asr",
)
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
sinc=LightweightSincConvs,
linear=LinearProjection,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
chunk_conformer=ConformerChunkEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
encoder_choices2 = ClassChoices(
"encoder2",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
),
type_check=AbsEncoder,
default="rnn",
)
asr_encoder_choices = ClassChoices(
"asr_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
spk_encoder_choices = ClassChoices(
"spk_encoder",
classes=dict(
resnet34_diar=ResNet34Diar,
),
default="resnet34_diar",
)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
hugging_face_transformers=HuggingFaceTransformersPostEncoder,
),
type_check=AbsPostEncoder,
default=None,
optional=True,
)
decoder_choices = ClassChoices(
"decoder",
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
)
decoder_choices2 = ClassChoices(
"decoder2",
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
),
type_check=AbsDecoder,
default="rnn",
)
rnnt_decoder_choices = ClassChoices(
"rnnt_decoder",
classes=dict(
rnnt=RNNTDecoder,
),
type_check=RNNTDecoder,
default="rnnt",
)
joint_network_choices = ClassChoices(
name="joint_network",
classes=dict(
joint_network=JointNetwork,
),
default="joint_network",
optional=True,
)
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
cif_predictor=CifPredictor,
ctc_predictor=None,
cif_predictor_v2=CifPredictorV2,
cif_predictor_v3=CifPredictorV3,
bat_predictor=BATPredictor,
),
type_check=None,
default="cif_predictor",
optional=True,
)
predictor_choices2 = ClassChoices(
name="predictor2",
classes=dict(
cif_predictor=CifPredictor,
ctc_predictor=None,
cif_predictor_v2=CifPredictorV2,
),
type_check=None,
default="cif_predictor",
optional=True,
)
stride_conv_choices = ClassChoices(
name="stride_conv",
classes=dict(
stride_conv1d=Conv1dSubsampling
),
type_check=None,
default="stride_conv1d",
optional=True,
)
class ASRTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --postencoder and --postencoder_conf
postencoder_choices,
# --decoder and --decoder_conf
decoder_choices,
# --predictor and --predictor_conf
predictor_choices,
# --encoder2 and --encoder2_conf
encoder_choices2,
# --decoder2 and --decoder2_conf
decoder_choices2,
# --predictor2 and --predictor2_conf
predictor_choices2,
# --stride_conv and --stride_conv_conf
stride_conv_choices,
# --rnnt_decoder and --rnnt_decoder_conf
rnnt_decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--max_spk_num",
type=int_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="seg_dict_file for text processing",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group.add_argument(
"--ctc_conf",
action=NestedDictAction,
default=get_default_kwargs(CTC),
help="The keyword arguments for CTC class.",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word", "phn"],
help="The text will be tokenized " "in the specified level token",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file of sentencepiece",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
default=None,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "text")
else:
# Recognition mode
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
if args.token_list is not None:
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
else:
vocab_size = args.vocab_size
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 9. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("asr")
model = model_class(
args.whisper_dims,
)
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
device: Device type, "cpu", "cuda", or "cuda:N".
"""
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
if model_file is not None:
model_dict = torch.load(model_file, map_location=device)
args.whisper_dims = model_dict["dims"]
model = cls.build_model(args)
if not isinstance(model, FunASRModel):
raise RuntimeError(
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
model_name_pth = None
if model_file is not None:
logging.info("model_file is {}".format(model_file))
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
model_dir = os.path.dirname(model_file)
model_name = os.path.basename(model_file)
model_dict = torch.load(model_file, map_location=device)
model.load_state_dict(model_dict["model_state_dict"])
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
logging.info("model_file is saved to pth: {}".format(model_name_pth))
return model, args
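For reference, a hedged sketch of driving the inference path of the task class above. The file names are placeholders, `config_file=None` is resolved to the `config.yaml` next to the checkpoint, and the checkpoint is expected to contain the `dims` and `model_state_dict` entries loaded in `build_model_from_file`.
```python
# Hedged usage sketch for the ASRTask defined in this file.
# Paths are placeholders; the checkpoint must provide "dims" and
# "model_state_dict", as required by build_model_from_file above.
from funasr.tasks.whisper import ASRTask

model, args = ASRTask.build_model_from_file(
    config_file=None,                   # falls back to <checkpoint dir>/config.yaml
    model_file="exp/whisper/model.pb",  # placeholder checkpoint path
    device="cpu",
)
model.eval()
```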

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

View File

@ -0,0 +1 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@ -0,0 +1 @@
{"<|endoftext|>": 50257}

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

View File

@ -0,0 +1 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,124 @@
import os
from functools import lru_cache
from typing import Union
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
from funasr.utils.whisper_utils.utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
"""
Compute the log-Mel spectrogram of the input audio.
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
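A short sketch of how these helpers compose for one input file; the wav path is a placeholder, and the ffmpeg CLI must be available on PATH for `load_audio`.
```python
# Hedged sketch composing the helpers defined above.
from funasr.utils.whisper_utils.audio import (
    N_SAMPLES,
    load_audio,
    log_mel_spectrogram,
    pad_or_trim,
)

audio = load_audio("example.wav")      # float32 mono waveform at 16 kHz (placeholder path)
audio = pad_or_trim(audio, N_SAMPLES)  # pad or cut to one 30-second chunk
mel = log_mel_spectrogram(audio)       # torch.Tensor of shape (80, 3000)
print(mel.shape)
```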

View File

@ -0,0 +1,710 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from funasr.utils.whisper_utils.audio import CHUNK_LENGTH
from funasr.utils.whisper_utils.tokenizer import Tokenizer, get_tokenizer
from funasr.utils.whisper_utils.utils import compression_ratio
if TYPE_CHECKING:
from funasr.models.whisper_models.model import Whisper
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
audio_features: Tensor
language: str
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""Perform a forward pass on the decoder and return per-token logits"""
raise NotImplementedError
def rearrange_kv_cache(self, source_indices) -> None:
"""Update the key-value cache according to the updated beams"""
raise NotImplementedError
def cleanup_caching(self) -> None:
"""Clean up any resources or hooks after decoding is finished"""
pass
class PyTorchInference(Inference):
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.hooks = []
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices):
for module, tensor in self.kv_cache.items():
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = tensor[source_indices].detach()
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : Tensor, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
True if all sequences has reached the end of text
"""
raise NotImplementedError
def finalize(
self, tokens: Tensor, sum_logprobs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : Tensor, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = torch.tensor(next_tokens, device=tokens.device)
self.inference.rearrange_kv_cache(source_indices)
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
previously_finished[seq] = newly_finished[seq]
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
if len(sequences) >= self.beam_size:
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
]
return tokens, sum_logprobs
class LogitFilter:
def apply(self, logits: Tensor, tokens: Tensor) -> None:
"""Apply any filtering or masking to logits in-place
Parameters
----------
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
def apply(self, logits: Tensor, tokens: Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int]):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: Tensor, tokens: Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(
options.beam_size, tokenizer.eot, self.inference, options.patience
)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt
if prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
def _get_audio_features(self, mel: Tensor):
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
return audio_features
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# apply the logit filters, e.g. for suppressing or applying penalty to certain tokens
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
result = DecodingTask(model, options).run(mel)
if single:
result = result[0]
return result
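A minimal sketch of calling `decode()` with the options above. The `model` argument is assumed to be an already-loaded funasr `Whisper` instance (building it is not shown here), and the import path assumes this file sits next to `audio.py` as `funasr/utils/whisper_utils/decoding.py`.
```python
# Hedged sketch: greedy, timestamp-free decoding of one 30-second chunk.
# fp16=False assumes a float32 (e.g. CPU) model; the import paths assume this
# file is funasr/utils/whisper_utils/decoding.py, alongside audio.py.
from funasr.utils.whisper_utils.audio import load_audio, log_mel_spectrogram, pad_or_trim
from funasr.utils.whisper_utils.decoding import DecodingOptions, decode


def transcribe_chunk(model, wav_path: str) -> str:
    """Decode a single padded/trimmed chunk and return its text."""
    mel = log_mel_spectrogram(pad_or_trim(load_audio(wav_path)))
    mel = mel.to(next(model.parameters()).device)
    options = DecodingOptions(language="en", without_timestamps=True, fp16=False)
    result = decode(model, mel, options)  # a single DecodingResult
    return result.text
```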

View File

@ -0,0 +1,334 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import GPT2TokenizerFast
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
@dataclass(frozen=True)
class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast"
language: Optional[str]
sot_sequence: Tuple[int]
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
This method decodes the given tokens with timestamp tokens annotated inline, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
@property
@lru_cache()
def eot(self) -> int:
return self.tokenizer.eos_token_id
@property
@lru_cache()
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
@property
@lru_cache()
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
@property
@lru_cache()
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
@property
@lru_cache()
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
@property
@lru_cache()
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
@property
@lru_cache()
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
@property
@lru_cache()
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError("This tokenizer does not have a language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
raise KeyError(f"Language {self.language} not found in tokenizer.")
@property
@lru_cache()
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
@property
@lru_cache()
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
@property
@lru_cache()
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2", resource_path: Optional[str] = None):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if resource_path is not None:
path = os.path.join(resource_path, name)
else:
path = os.path.join(os.path.dirname(__file__), "assets", name)
tokenizer = GPT2TokenizerFast.from_pretrained(path)
specials = [
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
language: Optional[str] = None,
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
language = language or "en"
else:
tokenizer_name = "gpt2"
task = None
language = None
tokenizer = build_tokenizer(name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
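# Usage sketch (hedged): building a multilingual tokenizer and decoding tokens with
# timestamp annotations; `result_tokens` stands in for a token-id list from a decoding
# result, and the shown output is only an example:
#
#     tokenizer = get_tokenizer(multilingual=True, task="transcribe", language="zh")
#     print(tokenizer.sot_sequence)                           # (sot, <|zh|>, <|transcribe|>)
#     print(tokenizer.decode_with_timestamps(result_tokens))  # e.g. "<|0.00|>...<|1.08|>"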

View File

@ -0,0 +1,316 @@
import argparse
import os
import warnings
from typing import Optional, Tuple, Union, TYPE_CHECKING
import numpy as np
import torch
import tqdm
from funasr.utils.whisper_utils.audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from funasr.utils.whisper_utils.decoding import DecodingOptions, DecodingResult
from funasr.utils.whisper_utils.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from funasr.utils.whisper_utils.utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
if TYPE_CHECKING:
from .model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details;
if False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
mel = log_mel_spectrogram(audio)
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
needs_fallback = True # average log probability is too low
if not needs_fallback:
break
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt)
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
all_segments.append(
{
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment.shape[-1] # fast-forward to the next segment boundary
continue
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_position = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
)
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result,
)
seek += segment.shape[-1]
all_tokens.extend(tokens.tolist())
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
def cli():
from . import available_models
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
args["language"] = "en"
temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
else:
temperature = [temperature]
threads = args.pop("threads")
if threads > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
if __name__ == '__main__':
cli()
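# Usage sketch (hedged): calling transcribe() directly instead of the CLI; the import of
# `load_model` from funasr.utils.whisper_utils is assumed to match how cli() resolves it
# from the package, and "example.wav" is a placeholder path:
#
#     from funasr.utils.whisper_utils import load_model
#     model = load_model("small", device="cpu")
#     result = transcribe(model, "example.wav", language="en", fp16=False)
#     print(result["text"])
#     for seg in result["segments"]:
#         print(seg["start"], seg["end"], seg["text"])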

View File

@ -0,0 +1,163 @@
import json
import os
import sys
import zlib
from typing import Callable, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
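# Worked example (approximate): highly repetitive text compresses well under zlib, so its
# ratio is large, e.g. compression_ratio("la la la " * 50) comes out well above 10, while
# ordinary transcripts typically stay below the default 2.4 threshold that transcribe()
# uses to flag a failed decode.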
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
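# Worked examples:
#     format_timestamp(3661.5)                                               -> "01:01:01.500"
#     format_timestamp(7.25, always_include_hours=True, decimal_marker=",")  -> "00:00:07,250"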
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
def write_result(self, result: dict, file: TextIO):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment['text'].strip(), file=file, flush=True)
class WriteVTT(ResultWriter):
extension: str = "vtt"
def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file)
for segment in result["segments"]:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteSRT(ResultWriter):
extension: str = "srt"
def write_result(self, result: dict, file: TextIO):
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO):
json.dump(result, file)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO):
for writer in all_writers:
writer(result, file)
return write_all
return writers[output_format](output_dir)
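# Usage sketch (hedged): despite the Callable[[dict, TextIO], None] annotation, the returned
# writer is called with the source audio path; "outputs" and "example.wav" are placeholders:
#
#     writer = get_writer("all", output_dir="outputs")
#     writer(result, "example.wav")   # writes example.wav.txt/.vtt/.srt/.tsv/.json into outputs/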

View File

@ -25,6 +25,7 @@ requirements = {
"sentencepiece",
"jieba",
"rotary_embedding_torch",
"ffmpeg",
# TTS
"pypinyin>=0.44.0",
"espnet_tts_frontend",
@ -41,6 +42,7 @@ requirements = {
"protobuf",
"tqdm",
"hdbscan",
"umap",
],
# train: The modules invoked when training only.
"train": [