diff --git a/api.py b/api.py
index 2550dc8..88d0398 100644
--- a/api.py
+++ b/api.py
@@ -2,7 +2,7 @@
 # export SENSEVOICE_DEVICE=cuda:1

 import os, re
-from fastapi import FastAPI, File, Form
+from fastapi import FastAPI, File, Form, UploadFile
 from fastapi.responses import HTMLResponse
 from typing_extensions import Annotated
 from typing import List
@@ -12,6 +12,8 @@ from model import SenseVoiceSmall
 from funasr.utils.postprocess_utils import rich_transcription_postprocess
 from io import BytesIO

+TARGET_FS = 16000
+

 class Language(str, Enum):
     auto = "auto"
@@ -22,6 +24,7 @@ class Language(str, Enum):
     ko = "ko"
     nospeech = "nospeech"

+
 model_dir = "iic/SenseVoiceSmall"
 m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=os.getenv("SENSEVOICE_DEVICE", "cuda:0"))
 m.eval()
@@ -46,29 +49,41 @@ async def root():
     """


+
 @app.post("/api/v1/asr")
-async def turn_audio_to_text(files: Annotated[List[bytes], File(description="wav or mp3 audios in 16KHz")], keys: Annotated[str, Form(description="name of each audio joined with comma")], lang: Annotated[Language, Form(description="language of audio content")] = "auto"):
+async def turn_audio_to_text(
+    files: Annotated[List[UploadFile], File(description="wav or mp3 audios in 16KHz")],
+    keys: Annotated[str, Form(description="name of each audio joined with comma")] = None,
+    lang: Annotated[Language, Form(description="language of audio content")] = "auto",
+):
     audios = []
-    audio_fs = 0
     for file in files:
-        file_io = BytesIO(file)
+        file_io = BytesIO(await file.read())
         data_or_path_or_list, audio_fs = torchaudio.load(file_io)
+
+        # transform to target sample
+        if audio_fs != TARGET_FS:
+            resampler = torchaudio.transforms.Resample(orig_freq=audio_fs, new_freq=TARGET_FS)
+            data_or_path_or_list = resampler(data_or_path_or_list)
+
         data_or_path_or_list = data_or_path_or_list.mean(0)
         audios.append(data_or_path_or_list)
-        file_io.close()
+
     if lang == "":
         lang = "auto"
-    if keys == "":
-        key = ["wav_file_tmp_name"]
+
+    if not keys:
+        key = [f.filename for f in files]
     else:
         key = keys.split(",")
+
     res = m.inference(
         data_in=audios,
-        language=lang, # "zh", "en", "yue", "ja", "ko", "nospeech"
+        language=lang,  # "zh", "en", "yue", "ja", "ko", "nospeech"
         use_itn=False,
         ban_emo_unk=False,
         key=key,
-        fs=audio_fs,
+        fs=TARGET_FS,
         **kwargs,
     )
     if len(res) == 0:
@@ -78,3 +93,9 @@ async def turn_audio_to_text(files: Annotated[List[bytes], File(description="wav
         it["clean_text"] = re.sub(regex, "", it["text"], 0, re.MULTILINE)
         it["text"] = rich_transcription_postprocess(it["text"])
     return {"result": res[0]}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=50000)
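
Minimal client sketch (not part of the patch): one way to exercise the updated endpoint, assuming the server was started with `python api.py` so it listens on port 50000 as in the `__main__` block above. The file names and the use of the `requests` library are illustrative only. Inputs no longer need to be 16 kHz, since the handler now resamples to `TARGET_FS`, and `keys` may be omitted so the server falls back to the uploaded file names.

    import requests  # illustrative HTTP client; any client that can send multipart forms works

    url = "http://127.0.0.1:50000/api/v1/asr"  # port taken from the __main__ block above

    # Each audio is a separate part of the same multipart field "files";
    # "sample1.wav" / "sample2.mp3" are hypothetical local files.
    files = [
        ("files", ("sample1.wav", open("sample1.wav", "rb"), "audio/wav")),
        ("files", ("sample2.mp3", open("sample2.mp3", "rb"), "audio/mpeg")),
    ]

    # "keys" is optional now (the server uses the uploaded file names when it is missing);
    # "lang" defaults to "auto" on the server side as well.
    data = {"lang": "auto"}

    resp = requests.post(url, files=files, data=data)
    print(resp.json())  # {"result": [...]} with "text" and "clean_text" per item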