mirror of
https://github.com/FunAudioLLM/SenseVoice.git
synced 2025-09-15 15:08:35 +08:00
Fix: api audio_fs is inconsistent
This commit is contained in:
parent
98e97e6216
commit
84b75f4d5e
39
api.py
39
api.py
@ -2,7 +2,7 @@
|
||||
# export SENSEVOICE_DEVICE=cuda:1
|
||||
|
||||
import os, re
|
||||
from fastapi import FastAPI, File, Form
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from fastapi.responses import HTMLResponse
|
||||
from typing_extensions import Annotated
|
||||
from typing import List
|
||||
@ -12,6 +12,8 @@ from model import SenseVoiceSmall
|
||||
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
||||
from io import BytesIO
|
||||
|
||||
TARGET_FS = 16000
|
||||
|
||||
|
||||
class Language(str, Enum):
|
||||
auto = "auto"
|
||||
@ -22,6 +24,7 @@ class Language(str, Enum):
|
||||
ko = "ko"
|
||||
nospeech = "nospeech"
|
||||
|
||||
|
||||
model_dir = "iic/SenseVoiceSmall"
|
||||
m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=os.getenv("SENSEVOICE_DEVICE", "cuda:0"))
|
||||
m.eval()
|
||||
@ -46,29 +49,41 @@ async def root():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@app.post("/api/v1/asr")
|
||||
async def turn_audio_to_text(files: Annotated[List[bytes], File(description="wav or mp3 audios in 16KHz")], keys: Annotated[str, Form(description="name of each audio joined with comma")], lang: Annotated[Language, Form(description="language of audio content")] = "auto"):
|
||||
async def turn_audio_to_text(
|
||||
files: Annotated[List[UploadFile], File(description="wav or mp3 audios in 16KHz")],
|
||||
keys: Annotated[str, Form(description="name of each audio joined with comma")] = None,
|
||||
lang: Annotated[Language, Form(description="language of audio content")] = "auto",
|
||||
):
|
||||
audios = []
|
||||
audio_fs = 0
|
||||
for file in files:
|
||||
file_io = BytesIO(file)
|
||||
file_io = BytesIO(await file.read())
|
||||
data_or_path_or_list, audio_fs = torchaudio.load(file_io)
|
||||
|
||||
# transform to target sample
|
||||
if audio_fs != TARGET_FS:
|
||||
resampler = torchaudio.transforms.Resample(orig_freq=audio_fs, new_freq=TARGET_FS)
|
||||
data_or_path_or_list = resampler(data_or_path_or_list)
|
||||
|
||||
data_or_path_or_list = data_or_path_or_list.mean(0)
|
||||
audios.append(data_or_path_or_list)
|
||||
file_io.close()
|
||||
|
||||
if lang == "":
|
||||
lang = "auto"
|
||||
if keys == "":
|
||||
key = ["wav_file_tmp_name"]
|
||||
|
||||
if not keys:
|
||||
key = [f.filename for f in files]
|
||||
else:
|
||||
key = keys.split(",")
|
||||
|
||||
res = m.inference(
|
||||
data_in=audios,
|
||||
language=lang, # "zh", "en", "yue", "ja", "ko", "nospeech"
|
||||
language=lang, # "zh", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
ban_emo_unk=False,
|
||||
key=key,
|
||||
fs=audio_fs,
|
||||
fs=TARGET_FS,
|
||||
**kwargs,
|
||||
)
|
||||
if len(res) == 0:
|
||||
@ -78,3 +93,9 @@ async def turn_audio_to_text(files: Annotated[List[bytes], File(description="wav
|
||||
it["clean_text"] = re.sub(regex, "", it["text"], 0, re.MULTILINE)
|
||||
it["text"] = rich_transcription_postprocess(it["text"])
|
||||
return {"result": res[0]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=50000)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user