Merge pull request #69 from iflamed/api

Add fastapi server
zhifu gao 2024-07-25 23:41:10 +08:00 committed by GitHub
commit 8306f2e7a0
5 changed files with 93 additions and 4 deletions

README.md

@@ -228,7 +228,11 @@ Note: Libtorch model is exported to the original model directory.
## Service
To be done
### Deployment with FastAPI
```shell
export SENSEVOICE_DEVICE=cuda:0
fastapi run --port 50000
```
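
Once the server is running, the `/api/v1/asr` endpoint defined in `api.py` accepts multipart form uploads. A minimal request sketch, assuming the server is reachable on localhost and `sample.wav` is a placeholder name for a 16 kHz recording:
```shell
curl -X POST "http://127.0.0.1:50000/api/v1/asr" \
  -F "files=@sample.wav" \
  -F "keys=sample" \
  -F "lang=auto"
```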
## Finetune

README_ja.md

@@ -234,7 +234,11 @@ print([rich_transcription_postprocess(i) for i in res])
### Deployment
To be done
### Deployment with FastAPI
```shell
export SENSEVOICE_DEVICE=cuda:0
fastapi run --port 50000
```
## Finetune

README_zh.md

@@ -237,7 +237,11 @@ print([rich_transcription_postprocess(i) for i in res])
### Deployment
To be done
### Deployment with FastAPI
```shell
export SENSEVOICE_DEVICE=cuda:0
fastapi run --port 50000
```
## Finetune

api.py (new file)

@@ -0,0 +1,76 @@
# Set the device with environment, default is cuda:0
# export SENSEVOICE_DEVICE=cuda:1
import os, re
from io import BytesIO

from fastapi import FastAPI, File, Form
from fastapi.responses import HTMLResponse
from typing_extensions import Annotated
from typing import List
from enum import Enum
import torchaudio

from model import SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess


class Language(str, Enum):
    auto = "auto"
    zh = "zh"
    en = "en"
    yue = "yue"
    ja = "ja"
    ko = "ko"
    nospeech = "nospeech"


model_dir = "iic/SenseVoiceSmall"
m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=os.getenv("SENSEVOICE_DEVICE", "cuda:0"))
m.eval()

# Matches rich-transcription tags such as <|zh|> or <|HAPPY|>
regex = r"<\|.*\|>"

app = FastAPI()


@app.get("/", response_class=HTMLResponse)
async def root():
    return """
    <!DOCTYPE html>
    <html>
        <head>
            <meta charset="utf-8">
            <title>API information</title>
        </head>
        <body>
            <a href='./docs'>Documents of API</a>
        </body>
    </html>
    """


@app.post("/api/v1/asr")
async def turn_audio_to_text(
    files: Annotated[List[bytes], File(description="wav or mp3 audios in 16 kHz")],
    keys: Annotated[str, Form(description="name of each audio joined with comma")],
    lang: Annotated[Language, Form(description="language of audio content")] = Language.auto,
):
    audios = []
    audio_fs = 0
    for file in files:
        # torchaudio.load expects a path or file-like object, so wrap the raw upload bytes
        file_io = BytesIO(file)
        data_or_path_or_list, audio_fs = torchaudio.load(file_io)
        # Downmix multi-channel audio to mono
        data_or_path_or_list = data_or_path_or_list.mean(0)
        audios.append(data_or_path_or_list)
    if lang == "":
        lang = Language.auto
    if keys == "":
        key = ["wav_file_tmp_name"]
    else:
        key = keys.split(",")
    res = m.inference(
        data_in=audios,
        language=lang,  # "zh", "en", "yue", "ja", "ko", "nospeech"
        use_itn=False,
        ban_emo_unk=False,
        key=key,
        fs=audio_fs,
        **kwargs,
    )
    if len(res) == 0:
        return {"result": []}
    for it in res[0]:
        # Keep the raw tagged output, a tag-stripped copy, and the postprocessed text
        it["raw_text"] = it["text"]
        it["clean_text"] = re.sub(regex, "", it["text"], 0, re.MULTILINE)
        it["text"] = rich_transcription_postprocess(it["text"])
    return {"result": res[0]}
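
For reference, the `files`, `keys`, and `lang` parameters above map onto a multipart request like the following sketch; the host, port, and audio filenames are placeholders, assuming the server was started as described in the README:
```shell
# Transcribe two audio files in one request; "keys" names them in order
curl -X POST "http://127.0.0.1:50000/api/v1/asr" \
  -F "files=@first.wav" \
  -F "files=@second.wav" \
  -F "keys=first,second" \
  -F "lang=auto"
```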

requirements.txt

@@ -5,4 +5,5 @@ huggingface
huggingface_hub
funasr>=1.1.3
numpy<=1.26.4
gradio
fastapi>=0.111.1