Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

add
This commit is contained in:
游雁 2023-09-13 09:33:54 +08:00
commit 33d3d20844
9 changed files with 229 additions and 9 deletions

View File

@ -415,7 +415,7 @@ def inference_paraformer(
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
if use_timestamp and timestamp is not None:
if use_timestamp and timestamp is not None and len(timestamp):
postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
@ -427,7 +427,7 @@ def inference_paraformer(
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
item = {'key': key, 'value': text_postprocessed}
if timestamp_postprocessed != "":
if timestamp_postprocessed != "" or len(timestamp) == 0:
item['timestamp'] = timestamp_postprocessed
asr_result_list.append(item)
finish_count += 1
@ -692,7 +692,7 @@ def inference_paraformer_vad_punc(
text, token, token_int = result[0], result[1], result[2]
time_stamp = result[4] if len(result[4]) > 0 else None
if use_timestamp and time_stamp is not None:
if use_timestamp and time_stamp is not None and len(time_stamp):
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
@ -717,7 +717,7 @@ def inference_paraformer_vad_punc(
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
if time_stamp_postprocessed != "" or len(time_stamp) == 0:
item['time_stamp'] = time_stamp_postprocessed
item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)

View File

@ -117,12 +117,25 @@ class Text2Punc:
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
if (i==0 or self.punc_list[punctuations[i-1]] == "" or self.punc_list[punctuations[i-1]] == "") and len(mini_sentence[i][0].encode()) == 1:
mini_sentence[i] = mini_sentence[i].capitalize()
if i == 0:
if len(mini_sentence[i][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
words_with_punc.append(self.punc_list[punctuations[i]])
punc_res = self.punc_list[punctuations[i]]
if len(mini_sentence[i][0].encode()) == 1:
if punc_res == "":
punc_res = ","
elif punc_res == "":
punc_res = "."
elif punc_res == "":
punc_res = "?"
words_with_punc.append(punc_res)
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
@ -131,9 +144,15 @@ class Text2Punc:
if new_mini_sentence[-1] == "" or new_mini_sentence[-1] == "":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "":
elif new_mini_sentence[-1] == ",":
new_mini_sentence_out = new_mini_sentence[:-1] + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "" and len(new_mini_sentence[-1].encode())==0:
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out

View File

@ -254,7 +254,7 @@ class ModelExport:
if not os.path.exists(quant_model_path):
onnx_model = onnx.load(model_path)
nodes = [n.name for n in onnx_model.graph.node]
nodes_to_exclude = [m for m in nodes if 'output' in m]
nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m]
quantize_dynamic(
model_input=model_path,
model_output=quant_model_path,

View File

@ -0,0 +1,47 @@
# Service with http-python
## Server
1. Install requirements
```shell
cd funasr/runtime/python/http
pip install -r requirements.txt
```
2. Start server
```shell
python server.py --port 8000
```
More parameters:
```shell
python server.py \
--host [host ip] \
--port [server port] \
--asr_model [asr model_name] \
--punc_model [punc model_name] \
--ngpu [0 or 1] \
--ncpu [1 or 4] \
--certfile [path of certfile for ssl] \
--keyfile [path of keyfile for ssl] \
--temp_dir [upload file temp dir]
```
## Client
```shell
# get test audio file
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav
python client.py --host=127.0.0.1 --port=8000 --audio_path=asr_example_zh.wav
```
More parameters:
```shell
python server.py \
--host [sever ip] \
--port [sever port] \
--add_pun [add pun to result] \
--audio_path [use audio path]
```

View File

@ -0,0 +1,34 @@
import requests
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--host",
type=str,
default="127.0.0.1",
required=False,
help="sever ip")
parser.add_argument("--port",
type=int,
default=8000,
required=False,
help="server port")
parser.add_argument("--add_pun",
type=int,
default=1,
required=False,
help="add pun to result")
parser.add_argument("--audio_path",
type=str,
default='asr_example_zh.wav',
required=False,
help="use audio path")
args = parser.parse_args()
url = f'http://{args.host}:{args.port}/recognition'
data = {'add_pun': args.add_pun}
headers = {}
files = [('audio', ('file', open(args.audio_path, 'rb'), 'application/octet-stream'))]
response = requests.post(url, headers=headers, data=data, files=files)
print(response.text)

View File

@ -0,0 +1,6 @@
modelscope>=1.8.4
fastapi>=0.95.1
ffmpeg-python
aiofiles
uvicorn
requests

View File

@ -0,0 +1,107 @@
import argparse
import logging
import os
import random
import time
import aiofiles
import ffmpeg
import uvicorn
from fastapi import FastAPI, File, UploadFile, Body
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
parser = argparse.ArgumentParser()
parser.add_argument("--host",
type=str,
default="0.0.0.0",
required=False,
help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port",
type=int,
default=8000,
required=False,
help="server port")
parser.add_argument("--asr_model",
type=str,
default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model from modelscope")
parser.add_argument("--punc_model",
type=str,
default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="model from modelscope")
parser.add_argument("--ngpu",
type=int,
default=1,
help="0 for cpu, 1 for gpu")
parser.add_argument("--ncpu",
type=int,
default=4,
help="cpu cores")
parser.add_argument("--certfile",
type=str,
default=None,
required=False,
help="certfile for ssl")
parser.add_argument("--keyfile",
type=str,
default=None,
required=False,
help="keyfile for ssl")
parser.add_argument("--temp_dir",
type=str,
default="temp_dir/",
required=False,
help="temp dir")
args = parser.parse_args()
os.makedirs(args.temp_dir, exist_ok=True)
print("model loading")
# asr
inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
model=args.asr_model,
ngpu=args.ngpu,
ncpu=args.ncpu,
model_revision=None)
print(f'loaded asr models.')
if args.punc_model != "":
inference_pipeline_punc = pipeline(task=Tasks.punctuation,
model=args.punc_model,
model_revision="v1.0.2",
ngpu=args.ngpu,
ncpu=args.ncpu)
print(f'loaded pun models.')
else:
inference_pipeline_punc = None
app = FastAPI(title="FunASR")
@app.post("/recognition")
async def api_recognition(audio: UploadFile = File(..., description="audio file"),
add_pun: int = Body(1, description="add punctuation", embed=True)):
suffix = audio.filename.split('.')[-1]
audio_path = f'{args.temp_dir}/{int(time.time() * 1000)}_{random.randint(100, 999)}.{suffix}'
async with aiofiles.open(audio_path, 'wb') as out_file:
content = await audio.read()
await out_file.write(content)
audio_bytes, _ = (
ffmpeg.input(audio_path, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
rec_result = inference_pipeline_asr(audio_in=audio_bytes, param_dict={})
if add_pun:
rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
ret = {"results": rec_result['text'], "code": 0}
return ret
if __name__ == '__main__':
uvicorn.run(app, host=args.host, port=args.port, ssl_keyfile=args.keyfile, ssl_certfile=args.certfile)

View File

@ -5,7 +5,7 @@ model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
result = model(wav_path, hotwords)
print(result)

View File

@ -333,7 +333,14 @@ class ContextualParaformer(Paraformer):
hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
# hotwords.append('<s>')
def word_map(word):
return torch.tensor([self.vocab[i] for i in word])
hotwords = []
for c in word:
if c not in self.vocab.keys():
hotwords.append(8403)
logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
else:
hotwords.append(self.vocab[c])
return torch.tensor(hotwords)
hotword_int = [word_map(i) for i in hotwords]
# import pdb; pdb.set_trace()
hotword_int.append(torch.tensor([1]))