diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 821f69429..ff8bb8c77 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -8,6 +8,7 @@ import os import codecs import tempfile import requests +import yaml from pathlib import Path from typing import Optional from typing import Sequence @@ -19,7 +20,6 @@ from typing import List import numpy as np import torch -import torchaudio from typeguard import check_argument_types from funasr.fileio.datadir_writer import DatadirWriter @@ -40,11 +40,12 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export + np.set_printoptions(threshold=np.inf) + class Speech2Text: """Speech2Text class @@ -89,7 +90,7 @@ class Speech2Text: ) frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -189,8 +190,7 @@ class Speech2Text: @torch.no_grad() def __call__( - self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, - begin_time: int = 0, end_time: int = None, + self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None ): """Inference @@ -201,38 +201,59 @@ class Speech2Text: """ assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None + results = [] + cache_en = cache["encoder"] + if speech.shape[1] < 16 * 60 and cache_en["is_final"]: + cache_en["tail_chunk"] = True + feats = cache_en["feats"] + feats_len = torch.tensor([feats.shape[1]]) + results = self.infer(feats, feats_len, cache) + return results else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - feats_len = cache["encoder"]["stride"] + cache["encoder"]["pad_left"] + cache["encoder"]["pad_right"] - feats = feats[:,cache["encoder"]["start_idx"]:cache["encoder"]["start_idx"]+feats_len,:] - feats_len = torch.tensor([feats_len]) - batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache} + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"]) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths - # a. To device + if feats.shape[1] != 0: + if cache_en["is_final"]: + if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]: + cache_en["last_chunk"] = True + else: + # first chunk + feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :] + feats_len = torch.tensor([feats_chunk1.shape[1]]) + results_chunk1 = self.infer(feats_chunk1, feats_len, cache) + + # last chunk + cache_en["last_chunk"] = True + feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :] + feats_len = torch.tensor([feats_chunk2.shape[1]]) + results_chunk2 = self.infer(feats_chunk2, feats_len, cache) + + return ["".join(results_chunk1 + results_chunk2)] + + results = self.infer(feats, feats_len, cache) + + return results + + @torch.no_grad() + def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None): + batch = {"speech": feats, "speech_lengths": feats_len} batch = to_device(batch, device=self.device) - # b. Forward Encoder - enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache) + enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache) if isinstance(enc, tuple): enc = enc[0] # assert len(enc) == 1, len(enc) enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.floor().long() + pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1] if torch.max(pre_token_length) < 1: return [] decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache) @@ -279,166 +300,12 @@ class Speech2Text: text = self.tokenizer.tokens2text(token) else: text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + results.append(text) # assert check_return_type(results) return results -class Speech2TextExport: - """Speech2TextExport class - - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): - - # 1. Build ASR model - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() - - token_list = asr_model.token_list - - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - # self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - - model = Paraformer_export(asr_model, onnx=False) - self.asr_model = model - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference - - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - - enc_len_batch_total = feats_len.sum() - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - - decoder_outs = self.asr_model(**batch) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - am_scores = decoder_out[i, :ys_pad_lens[i], :] - - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - yseq.tolist(), device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) - - return results - - def inference( maxlenratio: float, minlenratio: float, @@ -536,8 +403,6 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() - ncpu = kwargs.get("ncpu", 1) - torch.set_num_threads(ncpu) if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") @@ -580,11 +445,9 @@ def inference_modelscope( penalty=penalty, nbest=nbest, ) - if export_mode: - speech2text = Speech2TextExport(**speech2text_kwargs) - else: - speech2text = Speech2Text(**speech2text_kwargs) - + + speech2text = Speech2Text(**speech2text_kwargs) + def _load_bytes(input): middle_data = np.frombuffer(input, dtype=np.int16) middle_data = np.asarray(middle_data) @@ -599,7 +462,46 @@ def inference_modelscope( offset = i.min + abs_max array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array - + + def _read_yaml(yaml_path: Union[str, Path]) -> Dict: + if not Path(yaml_path).exists(): + raise FileExistsError(f'The {yaml_path} does not exist.') + + with open(str(yaml_path), 'rb') as f: + data = yaml.load(f, Loader=yaml.Loader) + return data + + def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): + if len(cache) > 0: + return cache + config = _read_yaml(asr_train_config) + enc_output_size = config["encoder_conf"]["output_size"] + feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"] + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False} + cache["encoder"] = cache_en + + cache_de = {"decode_fsmn": None} + cache["decoder"] = cache_de + + return cache + + def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): + if len(cache) > 0: + config = _read_yaml(asr_train_config) + enc_output_size = config["encoder_conf"]["output_size"] + feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"] + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False} + cache["encoder"] = cache_en + + cache_de = {"decode_fsmn": None} + cache["decoder"] = cache_de + + return cache + def _forward( data_path_and_name_and_type, raw_inputs: Union[np.ndarray, torch.Tensor] = None, @@ -610,123 +512,35 @@ def inference_modelscope( ): # 3. Build data-iterator + if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": + raw_inputs = _load_bytes(data_path_and_name_and_type[0]) + raw_inputs = torch.tensor(raw_inputs) + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, np.ndarray): + raw_inputs = torch.tensor(raw_inputs) is_final = False cache = {} + chunk_size = [5, 10, 5] if param_dict is not None and "cache" in param_dict: cache = param_dict["cache"] if param_dict is not None and "is_final" in param_dict: is_final = param_dict["is_final"] + if param_dict is not None and "chunk_size" in param_dict: + chunk_size = param_dict["chunk_size"] - if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": - raw_inputs = _load_bytes(data_path_and_name_and_type[0]) - raw_inputs = torch.tensor(raw_inputs) - if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound": - raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0] - is_final = True - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, np.ndarray): - raw_inputs = torch.tensor(raw_inputs) # 7 .Start for-loop # FIXME(kamo): The output format should be discussed about + raw_inputs = torch.unsqueeze(raw_inputs, axis=0) + input_lens = torch.tensor([raw_inputs.shape[1]]) asr_result_list = [] - results = [] - asr_result = "" - wait = True - if len(cache) == 0: - cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None, "is_final": is_final, "left": 0, "right": 0} - cache_de = {"decode_fsmn": None} - cache["decoder"] = cache_de - cache["first_chunk"] = True - cache["speech"] = [] - cache["accum_speech"] = 0 - if raw_inputs is not None: - if len(cache["speech"]) == 0: - cache["speech"] = raw_inputs - else: - cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0) - cache["accum_speech"] += len(raw_inputs) - while cache["accum_speech"] >= 960: - if cache["first_chunk"]: - if cache["accum_speech"] >= 14400: - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["stride"] = 10 - cache["encoder"]["left"] = 5 - cache["encoder"]["right"] = 0 - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 4800 - cache["first_chunk"] = False - cache["encoder"]["start_idx"] = -5 - cache["encoder"]["is_final"] = False - wait = False - else: - if is_final: - cache["encoder"]["stride"] = len(cache["speech"]) // 960 - cache["encoder"]["pad_left"] = 0 - cache["encoder"]["pad_right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] = 0 - wait = False - else: - break - else: - if cache["accum_speech"] >= 19200: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = 10 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 9600 - wait = False - else: - if is_final: - cache["encoder"]["is_final"] = True - if cache["accum_speech"] >= 14400: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = 10 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = cache["accum_speech"] // 960 - 15 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 9600 - wait = False - else: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = cache["accum_speech"] // 960 - 5 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 0 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] = 0 - wait = False - else: - break - - if len(results) >= 1: - asr_result += results[0][0] - if asr_result == "": - asr_result = "sil" - if wait: - asr_result = "waiting_for_more_voice" - item = {'key': "utt", 'value': asr_result} - asr_result_list.append(item) - else: - return [] + cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1) + cache["encoder"]["is_final"] = is_final + asr_result = speech2text(cache, raw_inputs, input_lens) + item = {'key': "utt", 'value': asr_result} + asr_result_list.append(item) + if is_final: + cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1) return asr_result_list return _forward @@ -920,5 +734,3 @@ if __name__ == "__main__": # # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') # print(rec_result) - - diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 699d85fdb..d02783f49 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -712,9 +712,9 @@ class ParaformerOnline(Paraformer): def calc_predictor_chunk(self, encoder_out, cache=None): - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \ + pre_acoustic_embeds, pre_token_length = \ self.predictor.forward_chunk(encoder_out, cache["encoder"]) - return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + return pre_acoustic_embeds, pre_token_length def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): decoder_outs = self.decoder.forward_chunk( diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index f2502bbb6..969ddadf2 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -6,9 +6,11 @@ from typing import Union import logging import torch import torch.nn as nn +import torch.nn.functional as F from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk from typeguard import check_argument_types import numpy as np +from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder @@ -349,6 +351,23 @@ class SANMEncoder(AbsEncoder): return (xs_pad, intermediate_outs), olens, None return xs_pad, olens, None + def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): + if len(cache) == 0: + return feats + # process last chunk + cache["feats"] = to_device(cache["feats"], device=feats.device) + overlap_feats = torch.cat((cache["feats"], feats), dim=1) + if cache["is_final"]: + cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :] + if not cache["last_chunk"]: + padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1] + overlap_feats = overlap_feats.transpose(1, 2) + overlap_feats = F.pad(overlap_feats, (0, padding_length)) + overlap_feats = overlap_feats.transpose(1, 2) + else: + cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] + return overlap_feats + def forward_chunk(self, xs_pad: torch.Tensor, ilens: torch.Tensor, @@ -360,7 +379,10 @@ class SANMEncoder(AbsEncoder): xs_pad = xs_pad else: xs_pad = self.embed(xs_pad, cache) - + if cache["tail_chunk"]: + xs_pad = cache["feats"] + else: + xs_pad = self._add_overlap_chunk(xs_pad, cache) encoder_outs = self.encoders0(xs_pad, None, None, None, None) xs_pad, masks = encoder_outs[0], encoder_outs[1] intermediate_outs = [] diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index a5273f841..c59e24502 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -2,6 +2,7 @@ import torch from torch import nn import logging import numpy as np +from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.streaming_utils.utils import sequence_mask @@ -200,7 +201,7 @@ class CifPredictorV2(nn.Module): return acoustic_embeds, token_num, alphas, cif_peak def forward_chunk(self, hidden, cache=None): - b, t, d = hidden.size() + batch_size, len_time, hidden_size = hidden.shape h = hidden context = h.transpose(1, 2) queries = self.pad(context) @@ -211,58 +212,81 @@ class CifPredictorV2(nn.Module): alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) alphas = alphas.squeeze(-1) - mask_chunk_predictor = None - if cache is not None: - mask_chunk_predictor = None - mask_chunk_predictor = torch.zeros_like(alphas) - mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0 - - if mask_chunk_predictor is not None: - alphas = alphas * mask_chunk_predictor - - if cache is not None: - if cache["is_final"]: - alphas[:, cache["stride"] + cache["pad_left"] - 1] += 0.45 - if cache["cif_hidden"] is not None: - hidden = torch.cat((cache["cif_hidden"], hidden), 1) - if cache["cif_alphas"] is not None: - alphas = torch.cat((cache["cif_alphas"], alphas), -1) - token_num = alphas.sum(-1) - acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) - len_time = alphas.size(-1) - last_fire_place = len_time - 1 - last_fire_remainds = 0.0 - pre_alphas_length = 0 - last_fire = False - - mask_chunk_peak_predictor = None - if cache is not None: - mask_chunk_peak_predictor = None - mask_chunk_peak_predictor = torch.zeros_like(cif_peak) - if cache["cif_alphas"] is not None: - pre_alphas_length = cache["cif_alphas"].size(-1) - mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0 - mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0 - - if mask_chunk_peak_predictor is not None: - cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1) - - for i in range(len_time): - if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold: - last_fire_place = len_time - 1 - i - last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold - last_fire = True - break - if last_fire: - last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device) - cache["cif_hidden"] = hidden[:, last_fire_place:, :] - cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1) - else: - cache["cif_hidden"] = hidden - cache["cif_alphas"] = alphas - token_num_int = token_num.floor().type(torch.int32).item() - return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak + token_length = [] + list_fires = [] + list_frames = [] + cache_alphas = [] + cache_hiddens = [] + + if cache is not None and "chunk_size" in cache: + alphas[:, :cache["chunk_size"][0]] = 0.0 + alphas[:, sum(cache["chunk_size"][:2]):] = 0.0 + if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache: + cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device) + cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device) + hidden = torch.cat((cache["cif_hidden"], hidden), dim=1) + alphas = torch.cat((cache["cif_alphas"], alphas), dim=1) + if cache is not None and "last_chunk" in cache and cache["last_chunk"]: + tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device) + tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device) + tail_alphas = torch.tile(tail_alphas, (batch_size, 1)) + hidden = torch.cat((hidden, tail_hidden), dim=1) + alphas = torch.cat((alphas, tail_alphas), dim=1) + + len_time = alphas.shape[1] + for b in range(batch_size): + integrate = 0.0 + frames = torch.zeros((hidden_size), device=hidden.device) + list_frame = [] + list_fire = [] + for t in range(len_time): + alpha = alphas[b][t] + if alpha + integrate < self.threshold: + integrate += alpha + list_fire.append(integrate) + frames += alpha * hidden[b][t] + else: + frames += (self.threshold - integrate) * hidden[b][t] + list_frame.append(frames) + integrate += alpha + list_fire.append(integrate) + integrate -= self.threshold + frames = integrate * hidden[b][t] + + cache_alphas.append(integrate) + if integrate > 0.0: + cache_hiddens.append(frames / integrate) + else: + cache_hiddens.append(frames) + + token_length.append(torch.tensor(len(list_frame), device=alphas.device)) + list_fires.append(list_fire) + list_frames.append(list_frame) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + + max_token_len = max(token_length) + if max_token_len == 0: + return hidden, torch.stack(token_length, 0) + list_ls = [] + for b in range(batch_size): + pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device) + if token_length[b] == 0: + list_ls.append(pad_frames) + else: + list_frames[b] = torch.stack(list_frames[b]) + list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0)) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + return torch.stack(list_ls, 0), torch.stack(token_length, 0) + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): b, t, d = hidden.size() diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index c347e24f1..aaac80a7d 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -425,21 +425,14 @@ class StreamSinusoidalPositionEncoder(torch.nn.Module): return encoding.type(dtype) def forward(self, x, cache=None): - start_idx = 0 - pad_left = 0 - pad_right = 0 batch_size, timesteps, input_dim = x.size() + start_idx = 0 if cache is not None: start_idx = cache["start_idx"] - pad_left = cache["left"] - pad_right = cache["right"] + cache["start_idx"] += timesteps positions = torch.arange(1, timesteps+start_idx+1)[None, :] position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) - outputs = x + position_encoding[:, start_idx: start_idx + timesteps] - outputs = outputs.transpose(1, 2) - outputs = F.pad(outputs, (pad_left, pad_right)) - outputs = outputs.transpose(1, 2) - return outputs + return x + position_encoding[:, start_idx: start_idx + timesteps] class StreamingRelPositionalEncoding(torch.nn.Module): """Relative positional encoding. diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py deleted file mode 100644 index fe6798127..000000000 --- a/funasr/runtime/python/websocket/ASR_client.py +++ /dev/null @@ -1,100 +0,0 @@ -import pyaudio -# import websocket #区别服务端这里是 websocket-client库 -import time -import websockets -import asyncio -from queue import Queue -# import threading -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--host", - type=str, - default="localhost", - required=False, - help="host ip, localhost, 0.0.0.0") -parser.add_argument("--port", - type=int, - default=10095, - required=False, - help="grpc server port") -parser.add_argument("--chunk_size", - type=int, - default=300, - help="ms") - -args = parser.parse_args() - -voices = Queue() - - - -# 其他函数可以通过调用send(data)来发送数据,例如: -async def record(): - #print("2") - global voices - FORMAT = pyaudio.paInt16 - CHANNELS = 1 - RATE = 16000 - CHUNK = int(RATE / 1000 * args.chunk_size) - - p = pyaudio.PyAudio() - - stream = p.open(format=FORMAT, - channels=CHANNELS, - rate=RATE, - input=True, - frames_per_buffer=CHUNK) - - while True: - - data = stream.read(CHUNK) - - voices.put(data) - #print(voices.qsize()) - - await asyncio.sleep(0.01) - - - -async def ws_send(): - global voices - global websocket - print("started to sending data!") - while True: - while not voices.empty(): - data = voices.get() - voices.task_done() - try: - await websocket.send(data) # 通过ws对象发送数据 - except Exception as e: - print('Exception occurred:', e) - await asyncio.sleep(0.01) - await asyncio.sleep(0.01) - - - -async def message(): - global websocket - while True: - try: - print(await websocket.recv()) - except Exception as e: - print("Exception:", e) - - - -async def ws_client(): - global websocket # 定义一个全局变量ws,用于保存websocket连接对象 - # uri = "ws://11.167.134.197:8899" - uri = "ws://{}:{}".format(args.host, args.port) - #ws = await websockets.connect(uri, subprotocols=["binary"]) # 创建一个长连接 - async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None): - task = asyncio.create_task(record()) # 创建一个后台任务录音 - task2 = asyncio.create_task(ws_send()) # 创建一个后台任务发送 - task3 = asyncio.create_task(message()) # 创建一个后台接收消息的任务 - await asyncio.gather(task, task2, task3) - - -asyncio.get_event_loop().run_until_complete(ws_client()) # 启动协程 -asyncio.get_event_loop().run_forever() diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py deleted file mode 100644 index 827df7b58..000000000 --- a/funasr/runtime/python/websocket/ASR_server.py +++ /dev/null @@ -1,185 +0,0 @@ -import asyncio -import websockets -import time -from queue import Queue -import threading -import argparse - -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger -import logging -import tracemalloc -tracemalloc.start() - -logger = get_logger(log_level=logging.CRITICAL) -logger.setLevel(logging.CRITICAL) - - -websocket_users = set() #维护客户端列表 - -parser = argparse.ArgumentParser() -parser.add_argument("--host", - type=str, - default="0.0.0.0", - required=False, - help="host ip, localhost, 0.0.0.0") -parser.add_argument("--port", - type=int, - default=10095, - required=False, - help="grpc server port") -parser.add_argument("--asr_model", - type=str, - default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="model from modelscope") -parser.add_argument("--vad_model", - type=str, - default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - help="model from modelscope") - -parser.add_argument("--punc_model", - type=str, - default="", - help="model from modelscope") -parser.add_argument("--ngpu", - type=int, - default=1, - help="0 for cpu, 1 for gpu") - -args = parser.parse_args() - -print("model loading") - - -# vad -inference_pipeline_vad = pipeline( - task=Tasks.voice_activity_detection, - model=args.vad_model, - model_revision=None, - output_dir=None, - batch_size=1, - mode='online', - ngpu=args.ngpu, -) -# param_dict_vad = {'in_cache': dict(), "is_final": False} - -# asr -param_dict_asr = {} -# param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 -inference_pipeline_asr = pipeline( - task=Tasks.auto_speech_recognition, - model=args.asr_model, - param_dict=param_dict_asr, - ngpu=args.ngpu, -) -if args.punc_model != "": - # param_dict_punc = {'cache': list()} - inference_pipeline_punc = pipeline( - task=Tasks.punctuation, - model=args.punc_model, - model_revision=None, - ngpu=args.ngpu, - ) -else: - inference_pipeline_punc = None - -print("model loaded") - - - -async def ws_serve(websocket, path): - #speek = Queue() - frames = [] # 存储所有的帧数据 - buffer = [] # 存储缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - global websocket_users - speech_start, speech_end = False, False - # 调用asr函数 - websocket.param_dict_vad = {'in_cache': dict(), "is_final": False} - websocket.param_dict_punc = {'cache': list()} - websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包 - websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端 - websocket_users.add(websocket) - ss = threading.Thread(target=asr, args=(websocket,)) - ss.start() - - try: - async for message in websocket: - #voices.put(message) - #print("put") - #await websocket.send("123") - buffer.append(message) - if len(buffer) > 2: - buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个 - - if speech_start: - frames.append(message) - RECORD_NUM += 1 - speech_start_i, speech_end_i = vad(message, websocket) - #print(speech_start_i, speech_end_i) - if speech_start_i: - speech_start = speech_start_i - frames = [] - frames.extend(buffer) # 把之前2个语音数据快加入 - if speech_end_i or RECORD_NUM > 300: - speech_start = False - audio_in = b"".join(frames) - websocket.speek.put(audio_in) - frames = [] # 清空所有的帧数据 - buffer = [] # 清空缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - if not websocket.send_msg.empty(): - await websocket.send(websocket.send_msg.get()) - websocket.send_msg.task_done() - - - except websockets.ConnectionClosed: - print("ConnectionClosed...", websocket_users) # 链接断开 - websocket_users.remove(websocket) - except websockets.InvalidState: - print("InvalidState...") # 无效状态 - except Exception as e: - print("Exception:", e) - - -def asr(websocket): # ASR推理 - global inference_pipeline_asr, inference_pipeline_punc - # global param_dict_punc - global websocket_users - while websocket in websocket_users: - if not websocket.speek.empty(): - audio_in = websocket.speek.get() - websocket.speek.task_done() - if len(audio_in) > 0: - rec_result = inference_pipeline_asr(audio_in=audio_in) - if inference_pipeline_punc is not None and 'text' in rec_result: - rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) - # print(rec_result) - if "text" in rec_result: - websocket.send_msg.put(rec_result["text"]) # 存入发送队列 直接调用send发送不了 - - time.sleep(0.1) - -def vad(data, websocket): # VAD推理 - global inference_pipeline_vad - #print(type(data)) - # print(param_dict_vad) - segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad) - # print(segments_result) - # print(param_dict_vad) - speech_start = False - speech_end = False - - if len(segments_result) == 0 or len(segments_result["text"]) > 1: - return speech_start, speech_end - if segments_result["text"][0][0] != -1: - speech_start = True - if segments_result["text"][0][1] != -1: - speech_end = True - return speech_start, speech_end - - -start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md index 73f8aebc3..723782f74 100644 --- a/funasr/runtime/python/websocket/README.md +++ b/funasr/runtime/python/websocket/README.md @@ -5,7 +5,7 @@ The audio data is in streaming, the asr inference process is in offline. ## For the Server -Install the modelscope and funasr +### Install the modelscope and funasr ```shell pip install -U modelscope funasr @@ -14,18 +14,34 @@ pip install -U modelscope funasr git clone https://github.com/alibaba/FunASR.git && cd FunASR ``` -Install the requirements for server +### Install the requirements for server ```shell cd funasr/runtime/python/websocket pip install -r requirements_server.txt ``` -Start server +### Start server +#### ASR offline server +[//]: # (```shell) + +[//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") + +[//]: # (```) +#### ASR streaming server ```shell -python ASR_server.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +python ws_server_online.py --host "0.0.0.0" --port 10095 ``` +#### + +#### ASR offline/online 2pass server + +[//]: # (```shell) + +[//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") + +[//]: # (```) ## For the client @@ -39,8 +55,10 @@ pip install -r requirements_client.txt Start client ```shell -python ASR_client.py --host "127.0.0.1" --port 10095 --chunk_size 300 +# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms +python ws_client.py --host "127.0.0.1" --port 10096 --chunk_size "5,10,5" ``` ## Acknowledge -1. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service. +1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR). +2. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service. diff --git a/funasr/runtime/python/websocket/parse_args.py b/funasr/runtime/python/websocket/parse_args.py new file mode 100644 index 000000000..2528a7624 --- /dev/null +++ b/funasr/runtime/python/websocket/parse_args.py @@ -0,0 +1,35 @@ +# -*- encoding: utf-8 -*- +import argparse +parser = argparse.ArgumentParser() +parser.add_argument("--host", + type=str, + default="0.0.0.0", + required=False, + help="host ip, localhost, 0.0.0.0") +parser.add_argument("--port", + type=int, + default=10095, + required=False, + help="grpc server port") +parser.add_argument("--asr_model", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + help="model from modelscope") +parser.add_argument("--asr_model_online", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", + help="model from modelscope") +parser.add_argument("--vad_model", + type=str, + default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + help="model from modelscope") +parser.add_argument("--punc_model", + type=str, + default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + help="model from modelscope") +parser.add_argument("--ngpu", + type=int, + default=1, + help="0 for cpu, 1 for gpu") + +args = parser.parse_args() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py new file mode 100644 index 000000000..8bbf1032d --- /dev/null +++ b/funasr/runtime/python/websocket/ws_client.py @@ -0,0 +1,182 @@ +# -*- encoding: utf-8 -*- +import os +import time +import websockets +import asyncio +# import threading +import argparse +import json + +parser = argparse.ArgumentParser() +parser.add_argument("--host", + type=str, + default="localhost", + required=False, + help="host ip, localhost, 0.0.0.0") +parser.add_argument("--port", + type=int, + default=10095, + required=False, + help="grpc server port") +parser.add_argument("--chunk_size", + type=str, + default="5, 10, 5", + help="chunk") +parser.add_argument("--chunk_interval", + type=int, + default=10, + help="chunk") +parser.add_argument("--audio_in", + type=str, + default=None, + help="audio_in") + +args = parser.parse_args() +args.chunk_size = [int(x) for x in args.chunk_size.split(",")] + +# voices = asyncio.Queue() +from queue import Queue +voices = Queue() + +# 其他函数可以通过调用send(data)来发送数据,例如: +async def record_microphone(): + is_finished = False + import pyaudio + #print("2") + global voices + FORMAT = pyaudio.paInt16 + CHANNELS = 1 + RATE = 16000 + chunk_size = 60*args.chunk_size[1]/args.chunk_interval + CHUNK = int(RATE / 1000 * chunk_size) + + p = pyaudio.PyAudio() + + stream = p.open(format=FORMAT, + channels=CHANNELS, + rate=RATE, + input=True, + frames_per_buffer=CHUNK) + is_speaking = True + while True: + + data = stream.read(CHUNK) + data = data.decode('ISO-8859-1') + message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished}) + + voices.put(message) + #print(voices.qsize()) + + await asyncio.sleep(0.005) + +# 其他函数可以通过调用send(data)来发送数据,例如: +async def record_from_scp(): + import wave + global voices + is_finished = False + if args.audio_in.endswith(".scp"): + f_scp = open(args.audio_in) + wavs = f_scp.readlines() + else: + wavs = [args.audio_in] + for wav in wavs: + wav_splits = wav.strip().split() + wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0] + # bytes_f = open(wav_path, "rb") + # bytes_data = bytes_f.read() + with wave.open(wav_path, "rb") as wav_file: + # 获取音频参数 + params = wav_file.getparams() + # 获取头信息的长度 + # header_length = wav_file.getheaders()[0][1] + # 读取音频帧数据,跳过头信息 + # wav_file.setpos(header_length) + frames = wav_file.readframes(wav_file.getnframes()) + + # 将音频帧数据转换为字节类型的数据 + audio_bytes = bytes(frames) + # stride = int(args.chunk_size/1000*16000*2) + stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2) + chunk_num = (len(audio_bytes)-1)//stride + 1 + # print(stride) + is_speaking = True + for i in range(chunk_num): + if i == chunk_num-1: + is_speaking = False + beg = i*stride + data = audio_bytes[beg:beg+stride] + data = data.decode('ISO-8859-1') + message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished}) + voices.put(message) + # print("data_chunk: ", len(data_chunk)) + # print(voices.qsize()) + + await asyncio.sleep(60*args.chunk_size[1]/args.chunk_interval/1000) + + is_finished = True + message = json.dumps({"is_finished": is_finished}) + voices.put(message) + +async def ws_send(): + global voices + global websocket + print("started to sending data!") + while True: + while not voices.empty(): + data = voices.get() + voices.task_done() + try: + await websocket.send(data) # 通过ws对象发送数据 + except Exception as e: + print('Exception occurred:', e) + await asyncio.sleep(0.005) + await asyncio.sleep(0.005) + + + +async def message(): + global websocket + text_print = "" + while True: + try: + meg = await websocket.recv() + meg = json.loads(meg) + # print(meg, end = '') + # print("\r") + text = meg["text"][0] + text_print += text + text_print = text_print[-55:] + os.system('clear') + print("\r"+text_print) + except Exception as e: + print("Exception:", e) + + +async def print_messge(): + global websocket + while True: + try: + meg = await websocket.recv() + meg = json.loads(meg) + print(meg) + except Exception as e: + print("Exception:", e) + + +async def ws_client(): + global websocket # 定义一个全局变量ws,用于保存websocket连接对象 + # uri = "ws://11.167.134.197:8899" + uri = "ws://{}:{}".format(args.host, args.port) + #ws = await websockets.connect(uri, subprotocols=["binary"]) # 创建一个长连接 + async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None): + if args.audio_in is not None: + task = asyncio.create_task(record_from_scp()) # 创建一个后台任务录音 + else: + task = asyncio.create_task(record_microphone()) # 创建一个后台任务录音 + task2 = asyncio.create_task(ws_send()) # 创建一个后台任务发送 + task3 = asyncio.create_task(message()) # 创建一个后台接收消息的任务 + await asyncio.gather(task, task2, task3) + + +asyncio.get_event_loop().run_until_complete(ws_client()) # 启动协程 +asyncio.get_event_loop().run_forever() diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py new file mode 100644 index 000000000..7ef0e2125 --- /dev/null +++ b/funasr/runtime/python/websocket/ws_server_online.py @@ -0,0 +1,108 @@ +import asyncio +import json +import websockets +import time +from queue import Queue +import threading +import logging +import tracemalloc +import numpy as np + +from parse_args import args +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from funasr_onnx.utils.frontend import load_bytes + +tracemalloc.start() + +logger = get_logger(log_level=logging.CRITICAL) +logger.setLevel(logging.CRITICAL) + + +websocket_users = set() + + +print("model loading") + +inference_pipeline_asr_online = pipeline( + task=Tasks.auto_speech_recognition, + model=args.asr_model_online, + model_revision='v1.0.4') + +print("model loaded") + + + +async def ws_serve(websocket, path): + frames_online = [] + global websocket_users + websocket.send_msg = Queue() + websocket_users.add(websocket) + websocket.param_dict_asr_online = {"cache": dict()} + websocket.speek_online = Queue() + ss_online = threading.Thread(target=asr_online, args=(websocket,)) + ss_online.start() + + try: + async for message in websocket: + message = json.loads(message) + is_finished = message["is_finished"] + if not is_finished: + audio = bytes(message['audio'], 'ISO-8859-1') + + is_speaking = message["is_speaking"] + websocket.param_dict_asr_online["is_final"] = not is_speaking + + websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"] + + + frames_online.append(audio) + + if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking: + + audio_in = b"".join(frames_online) + websocket.speek_online.put(audio_in) + frames_online = [] + + if not websocket.send_msg.empty(): + await websocket.send(websocket.send_msg.get()) + websocket.send_msg.task_done() + + + except websockets.ConnectionClosed: + print("ConnectionClosed...", websocket_users) # 链接断开 + websocket_users.remove(websocket) + except websockets.InvalidState: + print("InvalidState...") # 无效状态 + except Exception as e: + print("Exception:", e) + + + +def asr_online(websocket): # ASR推理 + global websocket_users + while websocket in websocket_users: + if not websocket.speek_online.empty(): + audio_in = websocket.speek_online.get() + websocket.speek_online.task_done() + if len(audio_in) > 0: + # print(len(audio_in)) + audio_in = load_bytes(audio_in) + rec_result = inference_pipeline_asr_online(audio_in=audio_in, + param_dict=websocket.param_dict_asr_online) + if websocket.param_dict_asr_online["is_final"]: + websocket.param_dict_asr_online["cache"] = dict() + + if "text" in rec_result: + if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": + print(rec_result["text"]) + message = json.dumps({"mode": "online", "text": rec_result["text"]}) + websocket.send_msg.put(message) + + time.sleep(0.005) + + +start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() \ No newline at end of file