diff --git a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml index 47bc6bdb6..18614ddf3 100644 --- a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml +++ b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml @@ -10,6 +10,7 @@ frontend_conf: lfr_m: 1 lfr_n: 1 use_channel: 0 + mc: False # encoder related asr_encoder: conformer diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py index 6718f3f6c..abbcd1b05 100644 --- a/funasr/models/frontend/default.py +++ b/funasr/models/frontend/default.py @@ -77,8 +77,8 @@ class DefaultFrontend(AbsFrontend): htk=htk, ) self.n_mels = n_mels - self.frontend_type = "default" self.use_channel = use_channel + self.frontend_type = "default" def output_size(self) -> int: return self.n_mels @@ -146,9 +146,11 @@ class MultiChannelFrontend(AbsFrontend): def __init__( self, fs: Union[int, str] = 16000, - n_fft: int = 400, - frame_length: int = 25, - frame_shift: int = 10, + n_fft: int = 512, + win_length: int = None, + hop_length: int = None, + frame_length: int = None, + frame_shift: int = None, window: Optional[str] = "hann", center: bool = True, normalized: bool = False, @@ -162,7 +164,8 @@ class MultiChannelFrontend(AbsFrontend): use_channel: int = None, lfr_m: int = 1, lfr_n: int = 1, - cmvn_file: str = None + cmvn_file: str = None, + mc: bool = True ): assert check_argument_types() super().__init__() @@ -171,8 +174,18 @@ class MultiChannelFrontend(AbsFrontend): # Deepcopy (In general, dict shouldn't be used as default arg) frontend_conf = copy.deepcopy(frontend_conf) - self.win_length = frame_length * 16 - self.hop_length = frame_shift * 16 + if win_length is None and hop_length is None: + self.win_length = frame_length * 16 + self.hop_length = frame_shift * 16 + elif frame_length is None and frame_shift is None: + self.win_length = win_length + self.hop_length = hop_length + else: + logging.error( + "Only one of (win_length, 
hop_length) and (frame_length, frame_shift)" + "can be set." + ) + exit(1) if apply_stft: self.stft = Stft( @@ -202,17 +215,19 @@ class MultiChannelFrontend(AbsFrontend): htk=htk, ) self.n_mels = n_mels - self.frontend_type = "default" self.use_channel = use_channel - if self.use_channel is not None: - logging.info("use the channel %d" % (self.use_channel)) - else: - logging.info("random select channel") - self.cmvn_file = cmvn_file - if self.cmvn_file is not None: - mean, std = self._load_cmvn(self.cmvn_file) - self.register_buffer("mean", torch.from_numpy(mean)) - self.register_buffer("std", torch.from_numpy(std)) + self.mc = mc + if not self.mc: + if self.use_channel is not None: + logging.info("use the channel %d" % (self.use_channel)) + else: + logging.info("random select channel") + self.cmvn_file = cmvn_file + if self.cmvn_file is not None: + mean, std = self._load_cmvn(self.cmvn_file) + self.register_buffer("mean", torch.from_numpy(mean)) + self.register_buffer("std", torch.from_numpy(std)) + self.frontend_type = "multichannelfrontend" def output_size(self) -> int: return self.n_mels @@ -233,8 +248,8 @@ class MultiChannelFrontend(AbsFrontend): # input_stft: (Batch, Length, [Channel], Freq) input_stft, _, mask = self.frontend(input_stft, feats_lens) - # 3. [Multi channel case]: Select a channel - if input_stft.dim() == 4: + # 3. [Multi channel case]: Select a channel(sa_asr) + if input_stft.dim() == 4 and not self.mc: # h: (B, T, C, F) -> h: (B, T, F) if self.training: if self.use_channel is not None: @@ -256,27 +271,37 @@ class MultiChannelFrontend(AbsFrontend): # input_power: (Batch, [Channel,] Length, Freq) # -> input_feats: (Batch, Length, Dim) input_feats, _ = self.logmel(input_power, feats_lens) - - # 6. 
Apply CMVN - if self.cmvn_file is not None: - if feats_lens is None: - feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1)) - self.mean = self.mean.to(input_feats.device, input_feats.dtype) - self.std = self.std.to(input_feats.device, input_feats.dtype) - mask = make_pad_mask(feats_lens, input_feats, 1) - - if input_feats.requires_grad: - input_feats = input_feats + self.mean + if self.mc: + # MFCCA + if input_feats.dim() ==4: + bt = input_feats.size(0) + channel_size = input_feats.size(2) + input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous() + feats_lens = feats_lens.repeat(1,channel_size).squeeze() else: - input_feats += self.mean - if input_feats.requires_grad: - input_feats = input_feats.masked_fill(mask, 0.0) - else: - input_feats.masked_fill_(mask, 0.0) + channel_size = 1 + return input_feats, feats_lens, channel_size + else: + # 6. Apply CMVN + if self.cmvn_file is not None: + if feats_lens is None: + feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1)) + self.mean = self.mean.to(input_feats.device, input_feats.dtype) + self.std = self.std.to(input_feats.device, input_feats.dtype) + mask = make_pad_mask(feats_lens, input_feats, 1) - input_feats *= self.std + if input_feats.requires_grad: + input_feats = input_feats + self.mean + else: + input_feats += self.mean + if input_feats.requires_grad: + input_feats = input_feats.masked_fill(mask, 0.0) + else: + input_feats.masked_fill_(mask, 0.0) - return input_feats, feats_lens + input_feats *= self.std + + return input_feats, feats_lens def _compute_stft( self, input: torch.Tensor, input_lengths: torch.Tensor @@ -313,4 +338,4 @@ class MultiChannelFrontend(AbsFrontend): continue means = np.array(means_list).astype(np.float) vars = np.array(vars_list).astype(np.float) - return means, vars \ No newline at end of file + return means, vars diff --git a/funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj 
b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj new file mode 100644 index 000000000..b494bb50c --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj @@ -0,0 +1,18 @@ + + + + Exe + net6.0 + enable + enable + + + + + + + + + + + diff --git a/funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs new file mode 100644 index 000000000..dd3bf78b9 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs @@ -0,0 +1,61 @@ +using AliFsmnVadSharp; +using AliFsmnVadSharp.Model; +using NAudio.Wave; + +internal static class Program +{ + [STAThread] + private static void Main() + { + string applicationBase = AppDomain.CurrentDomain.BaseDirectory; + string modelFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"; + string configFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.yaml"; + string mvnFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.mvn"; + int batchSize = 2; + TimeSpan start_time0 = new TimeSpan(DateTime.Now.Ticks); + AliFsmnVad aliFsmnVad = new AliFsmnVad(modelFilePath, configFilePath, mvnFilePath, batchSize); + TimeSpan end_time0 = new TimeSpan(DateTime.Now.Ticks); + double elapsed_milliseconds0 = end_time0.TotalMilliseconds - start_time0.TotalMilliseconds; + Console.WriteLine("load model and init config elapsed_milliseconds:{0}", elapsed_milliseconds0.ToString()); + List samples = new List(); + TimeSpan total_duration = new TimeSpan(0L); + for (int i = 0; i < 2; i++) + { + string wavFilePath = string.Format(applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/example/{0}.wav", i.ToString());//vad_example + if (!File.Exists(wavFilePath)) + { + continue; + } + AudioFileReader _audioFileReader = new AudioFileReader(wavFilePath); + byte[] datas = new byte[_audioFileReader.Length]; + _audioFileReader.Read(datas, 0, 
datas.Length); + TimeSpan duration = _audioFileReader.TotalTime; + float[] wavdata = new float[datas.Length / 4]; + Buffer.BlockCopy(datas, 0, wavdata, 0, datas.Length); + float[] sample = wavdata.Select((float x) => x * 32768f).ToArray(); + samples.Add(wavdata); + total_duration += duration; + } + TimeSpan start_time = new TimeSpan(DateTime.Now.Ticks); + //SegmentEntity[] segments_duration = aliFsmnVad.GetSegments(samples); + SegmentEntity[] segments_duration = aliFsmnVad.GetSegmentsByStep(samples); + TimeSpan end_time = new TimeSpan(DateTime.Now.Ticks); + Console.WriteLine("vad infer result:"); + foreach (SegmentEntity segment in segments_duration) + { + Console.Write("["); + foreach (var x in segment.Segment) + { + Console.Write("[" + string.Join(",", x.ToArray()) + "]"); + } + Console.Write("]\r\n"); + } + + double elapsed_milliseconds = end_time.TotalMilliseconds - start_time.TotalMilliseconds; + double rtf = elapsed_milliseconds / total_duration.TotalMilliseconds; + Console.WriteLine("elapsed_milliseconds:{0}", elapsed_milliseconds.ToString()); + Console.WriteLine("total_duration:{0}", total_duration.TotalMilliseconds.ToString()); + Console.WriteLine("rtf:{1}", "0".ToString(), rtf.ToString()); + Console.WriteLine("------------------------"); + } +} \ No newline at end of file diff --git a/funasr/runtime/csharp/AliFsmnVadSharp.sln b/funasr/runtime/csharp/AliFsmnVadSharp.sln new file mode 100644 index 000000000..8bf24aa81 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp.sln @@ -0,0 +1,37 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.1.32210.238 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AliFsmnVadSharp", "AliFsmnVadSharp\AliFsmnVadSharp.csproj", "{BFB82F2E-AD5B-405C-AAFF-3CE33C548748}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AliFsmnVadSharp.Examples", 
"AliFsmnVadSharp.Examples\AliFsmnVadSharp.Examples.csproj", "{2FFA4D03-A62B-435B-B57B-7E49209810E1}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{212561CC-9836-4F45-A31B-298EF576F519}" + ProjectSection(SolutionItems) = preProject + license = license + README.md = README.md + EndProjectSection +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Release|Any CPU.Build.0 = Release|Any CPU + {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {FCC1BBCC-91A3-4223-B368-D272FB5108B6} + EndGlobalSection +EndGlobal diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs new file mode 100644 index 000000000..f42bfb12e --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs @@ -0,0 +1,387 @@ +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.Tensors; +using Microsoft.Extensions.Logging; +using AliFsmnVadSharp.Model; +using 
AliFsmnVadSharp.Utils; + +namespace AliFsmnVadSharp +{ + public class AliFsmnVad + { + private InferenceSession _onnxSession; + private readonly ILogger _logger; + private string _frontend; + private WavFrontend _wavFrontend; + private int _batchSize = 1; + private int _max_end_sil = int.MinValue; + private EncoderConfEntity _encoderConfEntity; + private VadPostConfEntity _vad_post_conf; + + public AliFsmnVad(string modelFilePath, string configFilePath, string mvnFilePath, int batchSize = 1) + { + Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions(); + options.AppendExecutionProvider_CPU(0); + options.InterOpNumThreads = 1; + _onnxSession = new InferenceSession(modelFilePath, options); + + VadYamlEntity vadYamlEntity = YamlHelper.ReadYaml(configFilePath); + _wavFrontend = new WavFrontend(mvnFilePath, vadYamlEntity.frontend_conf); + _frontend = vadYamlEntity.frontend; + _vad_post_conf = vadYamlEntity.vad_post_conf; + _batchSize = batchSize; + _max_end_sil = _max_end_sil != int.MinValue ? 
_max_end_sil : vadYamlEntity.vad_post_conf.max_end_silence_time; + _encoderConfEntity = vadYamlEntity.encoder_conf; + + ILoggerFactory loggerFactory = new LoggerFactory(); + _logger = new Logger(loggerFactory); + } + + public SegmentEntity[] GetSegments(List samples) + { + int waveform_nums = samples.Count; + _batchSize = Math.Min(waveform_nums, _batchSize); + SegmentEntity[] segments = new SegmentEntity[waveform_nums]; + for (int beg_idx = 0; beg_idx < waveform_nums; beg_idx += _batchSize) + { + int end_idx = Math.Min(waveform_nums, beg_idx + _batchSize); + List waveform_list = new List(); + for (int i = beg_idx; i < end_idx; i++) + { + waveform_list.Add(samples[i]); + } + List vadInputEntitys = ExtractFeats(waveform_list); + try + { + int t_offset = 0; + int step = Math.Min(waveform_list.Max(x => x.Length), 6000); + bool is_final = true; + List vadOutputEntitys = Infer(vadInputEntitys); + for (int batch_num = beg_idx; batch_num < end_idx; batch_num++) + { + var scores = vadOutputEntitys[batch_num - beg_idx].Scores; + SegmentEntity[] segments_part = vadInputEntitys[batch_num].VadScorer.DefaultCall(scores, waveform_list[batch_num - beg_idx], is_final: is_final, max_end_sil: _max_end_sil, online: false); + if (segments_part.Length > 0) + { +#pragma warning disable CS8602 // 解引用可能出现空引用。 + if (segments[batch_num] == null) + { + segments[batch_num] = new SegmentEntity(); + } + segments[batch_num].Segment.AddRange(segments_part[0].Segment); // +#pragma warning restore CS8602 // 解引用可能出现空引用。 + + } + } + } + catch (OnnxRuntimeException ex) + { + _logger.LogWarning("input wav is silence or noise"); + segments = null; + } +// for (int batch_num = 0; batch_num < _batchSize; batch_num++) +// { +// List segment_waveforms = new List(); +// foreach (int[] segment in segments[beg_idx + batch_num].Segment) +// { +// // (int)(16000 * (segment[0] / 1000.0) * 2); +// int frame_length = (((6000 * 400) / 400 - 1) * 160 + 400) / 60 / 1000; +// int frame_start = segment[0] * frame_length; 
+// int frame_end = segment[1] * frame_length; +// float[] segment_waveform = new float[frame_end - frame_start]; +// Array.Copy(waveform_list[batch_num], frame_start, segment_waveform, 0, segment_waveform.Length); +// segment_waveforms.Add(segment_waveform); +// } +// segments[beg_idx + batch_num].Waveform.AddRange(segment_waveforms); +// } + } + + return segments; + } + + public SegmentEntity[] GetSegmentsByStep(List samples) + { + int waveform_nums = samples.Count; + _batchSize=Math.Min(waveform_nums, _batchSize); + SegmentEntity[] segments = new SegmentEntity[waveform_nums]; + for (int beg_idx = 0; beg_idx < waveform_nums; beg_idx += _batchSize) + { + int end_idx = Math.Min(waveform_nums, beg_idx + _batchSize); + List waveform_list = new List(); + for (int i = beg_idx; i < end_idx; i++) + { + waveform_list.Add(samples[i]); + } + List vadInputEntitys = ExtractFeats(waveform_list); + int feats_len = vadInputEntitys.Max(x => x.SpeechLength); + List in_cache = new List(); + in_cache = PrepareCache(in_cache); + try + { + int step = Math.Min(vadInputEntitys.Max(x => x.SpeechLength), 6000 * 400); + bool is_final = true; + for (int t_offset = 0; t_offset < (int)(feats_len); t_offset += Math.Min(step, feats_len - t_offset)) + { + + if (t_offset + step >= feats_len - 1) + { + step = feats_len - t_offset; + is_final = true; + } + else + { + is_final = false; + } + List vadInputEntitys_step = new List(); + foreach (VadInputEntity vadInputEntity in vadInputEntitys) + { + VadInputEntity vadInputEntity_step = new VadInputEntity(); + float[]? feats = vadInputEntity.Speech; + int curr_step = Math.Min(feats.Length - t_offset, step); + if (curr_step <= 0) + { + vadInputEntity_step.Speech = new float[32000]; + vadInputEntity_step.SpeechLength = 0; + vadInputEntity_step.InCaches = in_cache; + vadInputEntity_step.Waveform = new float[(((int)(32000) / 400 - 1) * 160 + 400)]; + vadInputEntitys_step.Add(vadInputEntity_step); + continue; + } + float[]? 
feats_step = new float[curr_step]; + Array.Copy(feats, t_offset, feats_step, 0, feats_step.Length); + float[]? waveform = vadInputEntity.Waveform; + float[]? waveform_step = new float[Math.Min(waveform.Length, ((int)(t_offset + step) / 400 - 1) * 160 + 400) - t_offset / 400 * 160]; + Array.Copy(waveform, t_offset / 400 * 160, waveform_step, 0, waveform_step.Length); + vadInputEntity_step.Speech = feats_step; + vadInputEntity_step.SpeechLength = feats_step.Length; + vadInputEntity_step.InCaches = vadInputEntity.InCaches; + vadInputEntity_step.Waveform = waveform_step; + vadInputEntitys_step.Add(vadInputEntity_step); + } + List vadOutputEntitys = Infer(vadInputEntitys_step); + for (int batch_num = 0; batch_num < _batchSize; batch_num++) + { + vadInputEntitys[batch_num].InCaches = vadOutputEntitys[batch_num].OutCaches; + var scores = vadOutputEntitys[batch_num].Scores; + SegmentEntity[] segments_part = vadInputEntitys[batch_num].VadScorer.DefaultCall(scores, vadInputEntitys_step[batch_num].Waveform, is_final: is_final, max_end_sil: _max_end_sil, online: false); + if (segments_part.Length > 0) + { + +#pragma warning disable CS8602 // 解引用可能出现空引用。 + if (segments[beg_idx + batch_num] == null) + { + segments[beg_idx + batch_num] = new SegmentEntity(); + } + if (segments_part[0] != null) + { + segments[beg_idx + batch_num].Segment.AddRange(segments_part[0].Segment); + } +#pragma warning restore CS8602 // 解引用可能出现空引用。 + + } + } + } + } + catch (OnnxRuntimeException ex) + { + _logger.LogWarning("input wav is silence or noise"); + segments = null; + } +// for (int batch_num = 0; batch_num < _batchSize; batch_num++) +// { +// List segment_waveforms=new List(); +// foreach (int[] segment in segments[beg_idx + batch_num].Segment) +// { +// // (int)(16000 * (segment[0] / 1000.0) * 2); +// int frame_length = (((6000 * 400) / 400 - 1) * 160 + 400) / 60 / 1000; +// int frame_start = segment[0] * frame_length; +// int frame_end = segment[1] * frame_length; +// if(frame_end > 
waveform_list[batch_num].Length) +// { +// break; +// } +// float[] segment_waveform = new float[frame_end - frame_start]; +// Array.Copy(waveform_list[batch_num], frame_start, segment_waveform, 0, segment_waveform.Length); +// segment_waveforms.Add(segment_waveform); +// } +// segments[beg_idx + batch_num].Waveform.AddRange(segment_waveforms); +// } + + } + return segments; + } + + private List PrepareCache(List in_cache) + { + if (in_cache.Count > 0) + { + return in_cache; + } + + int fsmn_layers = _encoderConfEntity.fsmn_layers; + + int proj_dim = _encoderConfEntity.proj_dim; + int lorder = _encoderConfEntity.lorder; + + for (int i = 0; i < fsmn_layers; i++) + { + float[] cache = new float[1 * proj_dim * (lorder - 1) * 1]; + in_cache.Add(cache); + } + return in_cache; + } + + private List ExtractFeats(List waveform_list) + { + List in_cache = new List(); + in_cache = PrepareCache(in_cache); + List vadInputEntitys = new List(); + foreach (var waveform in waveform_list) + { + float[] fbanks = _wavFrontend.GetFbank(waveform); + float[] features = _wavFrontend.LfrCmvn(fbanks); + VadInputEntity vadInputEntity = new VadInputEntity(); + vadInputEntity.Waveform = waveform; + vadInputEntity.Speech = features; + vadInputEntity.SpeechLength = features.Length; + vadInputEntity.InCaches = in_cache; + vadInputEntity.VadScorer = new E2EVadModel(_vad_post_conf); + vadInputEntitys.Add(vadInputEntity); + } + return vadInputEntitys; + } + /// + /// 一维数组转3维数组 + /// + /// + /// 一维长 + /// 二维长 + /// + public static T[,,] DimOneToThree(T[] oneDimObj, int len, int wid) + { + if (oneDimObj.Length % (len * wid) != 0) + return null; + int height = oneDimObj.Length / (len * wid); + T[,,] threeDimObj = new T[len, wid, height]; + + for (int i = 0; i < oneDimObj.Length; i++) + { + threeDimObj[i / (wid * height), (i / height) % wid, i % height] = oneDimObj[i]; + } + return threeDimObj; + } + + private List Infer(List vadInputEntitys) + { + List vadOutputEntities = new List(); + foreach 
(VadInputEntity vadInputEntity in vadInputEntitys) + { + int batchSize = 1;//_batchSize + var inputMeta = _onnxSession.InputMetadata; + var container = new List(); + int[] dim = new int[] { batchSize, vadInputEntity.Speech.Length / 400 / batchSize, 400 }; + var tensor = new DenseTensor(vadInputEntity.Speech, dim, false); + container.Add(NamedOnnxValue.CreateFromTensor("speech", tensor)); + + int i = 0; + foreach (var cache in vadInputEntity.InCaches) + { + int[] cache_dim = new int[] { 1, 128, cache.Length / 128 / 1, 1 }; + var cache_tensor = new DenseTensor(cache, cache_dim, false); + container.Add(NamedOnnxValue.CreateFromTensor("in_cache" + i.ToString(), cache_tensor)); + i++; + } + + IDisposableReadOnlyCollection results = _onnxSession.Run(container); + var resultsArray = results.ToArray(); + VadOutputEntity vadOutputEntity = new VadOutputEntity(); + for (int j = 0; j < resultsArray.Length; j++) + { + if (resultsArray[j].Name.Equals("logits")) + { + Tensor tensors = resultsArray[0].AsTensor(); + var _scores = DimOneToThree(tensors.ToArray(), 1, tensors.Dimensions[1]); + vadOutputEntity.Scores = _scores; + } + if (resultsArray[j].Name.StartsWith("out_cache")) + { + vadOutputEntity.OutCaches.Add(resultsArray[j].AsEnumerable().ToArray()); + } + + } + vadOutputEntities.Add(vadOutputEntity); + } + + return vadOutputEntities; + } + + private float[] PadSequence(List modelInputs) + { + int max_speech_length = modelInputs.Max(x => x.SpeechLength); + int speech_length = max_speech_length * modelInputs.Count; + float[] speech = new float[speech_length]; + float[,] xxx = new float[modelInputs.Count, max_speech_length]; + for (int i = 0; i < modelInputs.Count; i++) + { + if (max_speech_length == modelInputs[i].SpeechLength) + { + for (int j = 0; j < xxx.GetLength(1); j++) + { +#pragma warning disable CS8602 // 解引用可能出现空引用。 + xxx[i, j] = modelInputs[i].Speech[j]; +#pragma warning restore CS8602 // 解引用可能出现空引用。 + } + continue; + } + float[] nullspeech = new 
float[max_speech_length - modelInputs[i].SpeechLength]; + float[]? curr_speech = modelInputs[i].Speech; + float[] padspeech = new float[max_speech_length]; + // /////////////////////////////////////////////////// + var arr_neg_mean = _onnxSession.ModelMetadata.CustomMetadataMap["neg_mean"].ToString().Split(',').ToArray(); + double[] neg_mean = arr_neg_mean.Select(x => (double)Convert.ToDouble(x)).ToArray(); + var arr_inv_stddev = _onnxSession.ModelMetadata.CustomMetadataMap["inv_stddev"].ToString().Split(',').ToArray(); + double[] inv_stddev = arr_inv_stddev.Select(x => (double)Convert.ToDouble(x)).ToArray(); + + int dim = neg_mean.Length; + for (int j = 0; j < max_speech_length; j++) + { + int k = new Random().Next(0, dim); + padspeech[j] = (float)((float)(0 + neg_mean[k]) * inv_stddev[k]); + } + Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length); + for (int j = 0; j < padspeech.Length; j++) + { +#pragma warning disable CS8602 // 解引用可能出现空引用。 + xxx[i, j] = padspeech[j]; +#pragma warning restore CS8602 // 解引用可能出现空引用。 + } + + } + int s = 0; + for (int i = 0; i < xxx.GetLength(0); i++) + { + for (int j = 0; j < xxx.GetLength(1); j++) + { + speech[s] = xxx[i, j]; + s++; + } + } + return speech; + } + + + + + + + + + + + + + } +} \ No newline at end of file diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj new file mode 100644 index 000000000..49915173e --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj @@ -0,0 +1,37 @@ + + + + net6.0 + enable + enable + + + + + + + + + + + PreserveNewest + kaldi-native-fbank-dll.dll + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs new file mode 100644 index 000000000..af0ad3664 --- /dev/null +++ 
b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Runtime.InteropServices; +using AliFsmnVadSharp.Struct; + +namespace AliFsmnVadSharp.DLL +{ + public static class KaldiNativeFbank + { + private const string dllName = @"kaldi-native-fbank-dll"; + + [DllImport(dllName, EntryPoint = "GetFbankOptions", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr GetFbankOptions(float dither, bool snip_edges, float sample_rate, int num_bins, float frame_shift = 10.0f, float frame_length = 25.0f, float energy_floor = 0.0f, bool debug_mel = false, string window_type = "hamming"); + + [DllImport(dllName, EntryPoint = "GetOnlineFbank", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern KnfOnlineFbank GetOnlineFbank(IntPtr opts); + + [DllImport(dllName, EntryPoint = "AcceptWaveform", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern void AcceptWaveform(KnfOnlineFbank knfOnlineFbank, float sample_rate, float[] samples, int samples_size); + + [DllImport(dllName, EntryPoint = "InputFinished", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern void InputFinished(KnfOnlineFbank knfOnlineFbank); + + [DllImport(dllName, EntryPoint = "GetNumFramesReady", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern int GetNumFramesReady(KnfOnlineFbank knfOnlineFbank); + + [DllImport(dllName, EntryPoint = "AcceptWaveformxxx", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern FbankDatas AcceptWaveformxxx(KnfOnlineFbank knfOnlineFbank, float sample_rate, float[] samples, int samples_size); + + [DllImport(dllName, EntryPoint = "GetFbank", CharSet = CharSet.Ansi, 
CallingConvention = CallingConvention.Cdecl)] + internal static extern void GetFbank(KnfOnlineFbank knfOnlineFbank,int frame, ref FbankData pData); + + [DllImport(dllName, EntryPoint = "GetFbanks", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + internal static extern void GetFbanks(KnfOnlineFbank knfOnlineFbank, int framesNum, ref FbankDatas fbankDatas); + + } +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs new file mode 100644 index 000000000..45549b2bf --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; + +namespace AliFsmnVadSharp.DLL +{ + internal struct FbankData + { + public IntPtr data; + public int data_length; + }; + + internal struct FbankDatas + { + public IntPtr data; + public int data_length; + }; + + internal struct KnfOnlineFbank + { + public IntPtr impl; + }; +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs b/funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs new file mode 100644 index 000000000..ce519b1e9 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs @@ -0,0 +1,717 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using AliFsmnVadSharp.Model; + +namespace AliFsmnVadSharp +{ + enum VadStateMachine + { + kVadInStateStartPointNotDetected = 1, + kVadInStateInSpeechSegment = 2, + kVadInStateEndPointDetected = 3, + } + enum VadDetectMode + { + kVadSingleUtteranceDetectMode = 0, + kVadMutipleUtteranceDetectMode = 1, + } + + + internal class E2EVadModel + { + private VadPostConfEntity _vad_opts = new VadPostConfEntity(); + private WindowDetector _windows_detector = new WindowDetector(); + 
private bool _is_final = false; + private int _data_buf_start_frame = 0; + private int _frm_cnt = 0; + private int _latest_confirmed_speech_frame = 0; + private int _lastest_confirmed_silence_frame = -1; + private int _continous_silence_frame_count = 0; + private int _vad_state_machine = (int)VadStateMachine.kVadInStateStartPointNotDetected; + private int _confirmed_start_frame = -1; + private int _confirmed_end_frame = -1; + private int _number_end_time_detected = 0; + private int _sil_frame = 0; + private int[] _sil_pdf_ids = new int[0]; + private double _noise_average_decibel = -100.0D; + private bool _pre_end_silence_detected = false; + private bool _next_seg = true; + + private List _output_data_buf; + private int _output_data_buf_offset = 0; + private List _frame_probs = new List(); + private int _max_end_sil_frame_cnt_thresh = 800 - 150; + private float _speech_noise_thres = 0.6F; + private float[,,] _scores = null; + private int _idx_pre_chunk = 0; + private bool _max_time_out = false; + private List _decibel = new List(); + private int _data_buf_size = 0; + private int _data_buf_all_size = 0; + + public E2EVadModel(VadPostConfEntity vadPostConfEntity) + { + _vad_opts = vadPostConfEntity; + _windows_detector = new WindowDetector(_vad_opts.window_size_ms, + _vad_opts.sil_to_speech_time_thres, + _vad_opts.speech_to_sil_time_thres, + _vad_opts.frame_in_ms); + AllResetDetection(); + } + + private void AllResetDetection() + { + _is_final = false; + _data_buf_start_frame = 0; + _frm_cnt = 0; + _latest_confirmed_speech_frame = 0; + _lastest_confirmed_silence_frame = -1; + _continous_silence_frame_count = 0; + _vad_state_machine = (int)VadStateMachine.kVadInStateStartPointNotDetected; + _confirmed_start_frame = -1; + _confirmed_end_frame = -1; + _number_end_time_detected = 0; + _sil_frame = 0; + _sil_pdf_ids = _vad_opts.sil_pdf_ids; + _noise_average_decibel = -100.0F; + _pre_end_silence_detected = false; + _next_seg = true; + + _output_data_buf = new List(); + 
_output_data_buf_offset = 0; + _frame_probs = new List(); + _max_end_sil_frame_cnt_thresh = _vad_opts.max_end_silence_time - _vad_opts.speech_to_sil_time_thres; + _speech_noise_thres = _vad_opts.speech_noise_thres; + _scores = null; + _idx_pre_chunk = 0; + _max_time_out = false; + _decibel = new List(); + _data_buf_size = 0; + _data_buf_all_size = 0; + ResetDetection(); + } + + private void ResetDetection() + { + _continous_silence_frame_count = 0; + _latest_confirmed_speech_frame = 0; + _lastest_confirmed_silence_frame = -1; + _confirmed_start_frame = -1; + _confirmed_end_frame = -1; + _vad_state_machine = (int)VadStateMachine.kVadInStateStartPointNotDetected; + _windows_detector.Reset(); + _sil_frame = 0; + _frame_probs = new List(); + } + + private void ComputeDecibel(float[] waveform) + { + int frame_sample_length = (int)(_vad_opts.frame_length_ms * _vad_opts.sample_rate / 1000); + int frame_shift_length = (int)(_vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000); + if (_data_buf_all_size == 0) + { + _data_buf_all_size = waveform.Length; + _data_buf_size = _data_buf_all_size; + } + else + { + _data_buf_all_size += waveform.Length; + } + + for (int offset = 0; offset < waveform.Length - frame_sample_length + 1; offset += frame_shift_length) + { + float[] _waveform_chunk = new float[frame_sample_length]; + Array.Copy(waveform, offset, _waveform_chunk, 0, _waveform_chunk.Length); + float[] _waveform_chunk_pow = _waveform_chunk.Select(x => (float)Math.Pow((double)x, 2)).ToArray(); + _decibel.Add( + 10 * Math.Log10( + _waveform_chunk_pow.Sum() + 0.000001 + ) + ); + } + + } + + private void ComputeScores(float[,,] scores) + { + _vad_opts.nn_eval_block_size = scores.GetLength(1); + _frm_cnt += scores.GetLength(1); + _scores = scores; + } + + private void PopDataBufTillFrame(int frame_idx)// need check again + { + while (_data_buf_start_frame < frame_idx) + { + if (_data_buf_size >= (int)(_vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000)) + { + 
_data_buf_start_frame += 1; + _data_buf_size = _data_buf_all_size - _data_buf_start_frame * (int)(_vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000); + } + } + } + + private void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, + bool last_frm_is_end_point, bool end_point_is_sent_end) + { + PopDataBufTillFrame(start_frm); + int expected_sample_number = (int)(frm_cnt * _vad_opts.sample_rate * _vad_opts.frame_in_ms / 1000); + if (last_frm_is_end_point) + { + int extra_sample = Math.Max(0, (int)(_vad_opts.frame_length_ms * _vad_opts.sample_rate / 1000 - _vad_opts.sample_rate * _vad_opts.frame_in_ms / 1000)); + expected_sample_number += (int)(extra_sample); + } + + if (end_point_is_sent_end) + { + expected_sample_number = Math.Max(expected_sample_number, _data_buf_size); + } + if (_data_buf_size < expected_sample_number) + { + Console.WriteLine("error in calling pop data_buf\n"); + } + + if (_output_data_buf.Count == 0 || first_frm_is_start_point) + { + _output_data_buf.Add(new E2EVadSpeechBufWithDoaEntity()); + _output_data_buf.Last().Reset(); + _output_data_buf.Last().start_ms = start_frm * _vad_opts.frame_in_ms; + _output_data_buf.Last().end_ms = _output_data_buf.Last().start_ms; + _output_data_buf.Last().doa = 0; + } + + E2EVadSpeechBufWithDoaEntity cur_seg = _output_data_buf.Last(); + if (cur_seg.end_ms != start_frm * _vad_opts.frame_in_ms) + { + Console.WriteLine("warning\n"); + } + + int out_pos = cur_seg.buffer.Length; // cur_seg.buff现在没做任何操作 + int data_to_pop = 0; + if (end_point_is_sent_end) + { + data_to_pop = expected_sample_number; + } + else + { + data_to_pop = (int)(frm_cnt * _vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000); + } + if (data_to_pop > _data_buf_size) + { + Console.WriteLine("VAD data_to_pop is bigger than _data_buf_size!!!\n"); + data_to_pop = _data_buf_size; + expected_sample_number = _data_buf_size; + } + + + cur_seg.doa = 0; + for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; 
sample_cpy_out++) + { + out_pos += 1; + } + for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++) + { + out_pos += 1; + } + + if (cur_seg.end_ms != start_frm * _vad_opts.frame_in_ms) + { + Console.WriteLine("Something wrong with the VAD algorithm\n"); + } + + _data_buf_start_frame += frm_cnt; + cur_seg.end_ms = (start_frm + frm_cnt) * _vad_opts.frame_in_ms; + if (first_frm_is_start_point) + { + cur_seg.contain_seg_start_point = true; + } + + if (last_frm_is_end_point) + { + cur_seg.contain_seg_end_point = true; + } + } + + private void OnSilenceDetected(int valid_frame) + { + _lastest_confirmed_silence_frame = valid_frame; + if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected) + { + PopDataBufTillFrame(valid_frame); + } + + } + + private void OnVoiceDetected(int valid_frame) + { + _latest_confirmed_speech_frame = valid_frame; + PopDataToOutputBuf(valid_frame, 1, false, false, false); + } + + private void OnVoiceStart(int start_frame, bool fake_result = false) + { + if (_vad_opts.do_start_point_detection) + { + //do nothing + } + if (_confirmed_start_frame != -1) + { + + Console.WriteLine("not reset vad properly\n"); + } + else + { + _confirmed_start_frame = start_frame; + } + if (!fake_result || _vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected) + { + + PopDataToOutputBuf(_confirmed_start_frame, 1, true, false, false); + } + } + + private void OnVoiceEnd(int end_frame, bool fake_result, bool is_last_frame) + { + for (int t = _latest_confirmed_speech_frame + 1; t < end_frame; t++) + { + OnVoiceDetected(t); + } + if (_vad_opts.do_end_point_detection) + { + //do nothing + } + if (_confirmed_end_frame != -1) + { + Console.WriteLine("not reset vad properly\n"); + } + else + { + _confirmed_end_frame = end_frame; + } + if (!fake_result) + { + _sil_frame = 0; + PopDataToOutputBuf(_confirmed_end_frame, 1, false, true, is_last_frame); + } + _number_end_time_detected += 1; + } + 
+ private void MaybeOnVoiceEndIfLastFrame(bool is_final_frame, int cur_frm_idx) + { + if (is_final_frame) + { + OnVoiceEnd(cur_frm_idx, false, true); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + } + + private int GetLatency() + { + return (int)(LatencyFrmNumAtStartPoint() * _vad_opts.frame_in_ms); + } + + private int LatencyFrmNumAtStartPoint() + { + int vad_latency = _windows_detector.GetWinSize(); + if (_vad_opts.do_extend != 0) + { + vad_latency += (int)(_vad_opts.lookback_time_start_point / _vad_opts.frame_in_ms); + } + return vad_latency; + } + + private FrameState GetFrameState(int t) + { + + FrameState frame_state = FrameState.kFrameStateInvalid; + double cur_decibel = _decibel[t]; + double cur_snr = cur_decibel - _noise_average_decibel; + if (cur_decibel < _vad_opts.decibel_thres) + { + frame_state = FrameState.kFrameStateSil; + DetectOneFrame(frame_state, t, false); + return frame_state; + } + + + double sum_score = 0.0D; + double noise_prob = 0.0D; + Trace.Assert(_sil_pdf_ids.Length == _vad_opts.silence_pdf_num, ""); + if (_sil_pdf_ids.Length > 0) + { + Trace.Assert(_scores.GetLength(0) == 1, "只支持batch_size = 1的测试"); // 只支持batch_size = 1的测试 + float[] sil_pdf_scores = new float[_sil_pdf_ids.Length]; + int j = 0; + foreach (int sil_pdf_id in _sil_pdf_ids) + { + sil_pdf_scores[j] = _scores[0,t - _idx_pre_chunk,sil_pdf_id]; + j++; + } + sum_score = sil_pdf_scores.Length == 0 ? 
0 : sil_pdf_scores.Sum(); + noise_prob = Math.Log(sum_score) * _vad_opts.speech_2_noise_ratio; + double total_score = 1.0D; + sum_score = total_score - sum_score; + } + double speech_prob = Math.Log(sum_score); + if (_vad_opts.output_frame_probs) + { + E2EVadFrameProbEntity frame_prob = new E2EVadFrameProbEntity(); + frame_prob.noise_prob = noise_prob; + frame_prob.speech_prob = speech_prob; + frame_prob.score = sum_score; + frame_prob.frame_id = t; + _frame_probs.Add(frame_prob); + } + + if (Math.Exp(speech_prob) >= Math.Exp(noise_prob) + _speech_noise_thres) + { + if (cur_snr >= _vad_opts.snr_thres && cur_decibel >= _vad_opts.decibel_thres) + { + frame_state = FrameState.kFrameStateSpeech; + } + else + { + frame_state = FrameState.kFrameStateSil; + } + } + else + { + frame_state = FrameState.kFrameStateSil; + if (_noise_average_decibel < -99.9) + { + _noise_average_decibel = cur_decibel; + } + else + { + _noise_average_decibel = (cur_decibel + _noise_average_decibel * (_vad_opts.noise_frame_num_used_for_snr - 1)) / _vad_opts.noise_frame_num_used_for_snr; + } + } + return frame_state; + } + + public SegmentEntity[] DefaultCall(float[,,] score, float[] waveform, + bool is_final = false, int max_end_sil = 800, bool online = false + ) + { + _max_end_sil_frame_cnt_thresh = max_end_sil - _vad_opts.speech_to_sil_time_thres; + // compute decibel for each frame + ComputeDecibel(waveform); + ComputeScores(score); + if (!is_final) + { + DetectCommonFrames(); + } + else + { + DetectLastFrames(); + } + int batchSize = score.GetLength(0); + SegmentEntity[] segments = new SegmentEntity[batchSize]; + for (int batch_num = 0; batch_num < batchSize; batch_num++) // only support batch_size = 1 now + { + List segment_batch = new List(); + if (_output_data_buf.Count > 0) + { + for (int i = _output_data_buf_offset; i < _output_data_buf.Count; i++) + { + int start_ms; + int end_ms; + if (online) + { + if (!_output_data_buf[i].contain_seg_start_point) + { + continue; + } + if (!_next_seg 
&& !_output_data_buf[i].contain_seg_end_point) + { + continue; + } + start_ms = _next_seg ? _output_data_buf[i].start_ms : -1; + if (_output_data_buf[i].contain_seg_end_point) + { + end_ms = _output_data_buf[i].end_ms; + _next_seg = true; + _output_data_buf_offset += 1; + } + else + { + end_ms = -1; + _next_seg = false; + } + } + else + { + if (!is_final && (!_output_data_buf[i].contain_seg_start_point || !_output_data_buf[i].contain_seg_end_point)) + { + continue; + } + start_ms = _output_data_buf[i].start_ms; + end_ms = _output_data_buf[i].end_ms; + _output_data_buf_offset += 1; + + } + int[] segment_ms = new int[] { start_ms, end_ms }; + segment_batch.Add(segment_ms); + + } + + } + + if (segment_batch.Count > 0) + { + if (segments[batch_num] == null) + { + segments[batch_num] = new SegmentEntity(); + } + segments[batch_num].Segment.AddRange(segment_batch); + } + } + + if (is_final) + { + // reset class variables and clear the dict for the next query + AllResetDetection(); + } + + return segments; + } + + private int DetectCommonFrames() + { + if (_vad_state_machine == (int)VadStateMachine.kVadInStateEndPointDetected) + { + return 0; + } + for (int i = _vad_opts.nn_eval_block_size - 1; i > -1; i += -1) + { + FrameState frame_state = FrameState.kFrameStateInvalid; + frame_state = GetFrameState(_frm_cnt - 1 - i); + DetectOneFrame(frame_state, _frm_cnt - 1 - i, false); + } + + _idx_pre_chunk += _scores.GetLength(1)* _scores.GetLength(0); //_scores.shape[1]; + return 0; + } + + private int DetectLastFrames() + { + if (_vad_state_machine == (int)VadStateMachine.kVadInStateEndPointDetected) + { + return 0; + } + for (int i = _vad_opts.nn_eval_block_size - 1; i > -1; i += -1) + { + FrameState frame_state = FrameState.kFrameStateInvalid; + frame_state = GetFrameState(_frm_cnt - 1 - i); + if (i != 0) + { + DetectOneFrame(frame_state, _frm_cnt - 1 - i, false); + } + else + { + DetectOneFrame(frame_state, _frm_cnt - 1, true); + } + + + } + + return 0; + } + + private void 
DetectOneFrame(FrameState cur_frm_state, int cur_frm_idx, bool is_final_frame) + { + FrameState tmp_cur_frm_state = FrameState.kFrameStateInvalid; + if (cur_frm_state == FrameState.kFrameStateSpeech) + { + if (Math.Abs(1.0) > _vad_opts.fe_prior_thres)//Fabs + { + tmp_cur_frm_state = FrameState.kFrameStateSpeech; + } + else + { + tmp_cur_frm_state = FrameState.kFrameStateSil; + } + } + else if (cur_frm_state == FrameState.kFrameStateSil) + { + tmp_cur_frm_state = FrameState.kFrameStateSil; + } + + AudioChangeState state_change = _windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx); + int frm_shift_in_ms = _vad_opts.frame_in_ms; + if (AudioChangeState.kChangeStateSil2Speech == state_change) + { + int silence_frame_count = _continous_silence_frame_count; // no used + _continous_silence_frame_count = 0; + _pre_end_silence_detected = false; + int start_frame = 0; + if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected) + { + start_frame = Math.Max(_data_buf_start_frame, cur_frm_idx - LatencyFrmNumAtStartPoint()); + OnVoiceStart(start_frame); + _vad_state_machine = (int)VadStateMachine.kVadInStateInSpeechSegment; + for (int t = start_frame + 1; t < cur_frm_idx + 1; t++) + { + OnVoiceDetected(t); + } + + } + else if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment) + { + for (int t = _latest_confirmed_speech_frame + 1; t < cur_frm_idx; t++) + { + OnVoiceDetected(t); + } + if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms) + { + OnVoiceEnd(cur_frm_idx, false, false); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + + else if (!is_final_frame) + { + OnVoiceDetected(cur_frm_idx); + } + else + { + MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx); + } + + } + else + { + return; + } + } + else if (AudioChangeState.kChangeStateSpeech2Sil == state_change) + { + _continous_silence_frame_count = 0; + if (_vad_state_machine == 
(int)VadStateMachine.kVadInStateStartPointNotDetected) + { return; } + else if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment) + { + if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms) + { + OnVoiceEnd(cur_frm_idx, false, false); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + else if (!is_final_frame) + { + OnVoiceDetected(cur_frm_idx); + } + else + { + MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx); + } + + } + else + { + return; + } + } + else if (AudioChangeState.kChangeStateSpeech2Speech == state_change) + { + _continous_silence_frame_count = 0; + if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment) + { + if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms) + { + _max_time_out = true; + OnVoiceEnd(cur_frm_idx, false, false); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + else if (!is_final_frame) + { + OnVoiceDetected(cur_frm_idx); + } + else + { + MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx); + } + } + else + { + return; + } + + } + else if (AudioChangeState.kChangeStateSil2Sil == state_change) + { + _continous_silence_frame_count += 1; + if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected) + { + // silence timeout, return zero length decision + if (((_vad_opts.detect_mode == (int)VadDetectMode.kVadSingleUtteranceDetectMode) && ( + _continous_silence_frame_count * frm_shift_in_ms > _vad_opts.max_start_silence_time)) || (is_final_frame && _number_end_time_detected == 0)) + { + for (int t = _lastest_confirmed_silence_frame + 1; t < cur_frm_idx; t++) + { + OnSilenceDetected(t); + } + OnVoiceStart(0, true); + OnVoiceEnd(0, true, false); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + else + { + if (cur_frm_idx >= LatencyFrmNumAtStartPoint()) + { + OnSilenceDetected(cur_frm_idx - 
LatencyFrmNumAtStartPoint()); + } + } + } + else if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment) + { + if (_continous_silence_frame_count * frm_shift_in_ms >= _max_end_sil_frame_cnt_thresh) + { + int lookback_frame = (int)(_max_end_sil_frame_cnt_thresh / frm_shift_in_ms); + if (_vad_opts.do_extend != 0) + { + lookback_frame -= (int)(_vad_opts.lookahead_time_end_point / frm_shift_in_ms); + lookback_frame -= 1; + lookback_frame = Math.Max(0, lookback_frame); + } + + OnVoiceEnd(cur_frm_idx - lookback_frame, false, false); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + else if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms) + { + OnVoiceEnd(cur_frm_idx, false, false); + _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected; + } + + else if (_vad_opts.do_extend != 0 && !is_final_frame) + { + if (_continous_silence_frame_count <= (int)(_vad_opts.lookahead_time_end_point / frm_shift_in_ms)) + { + OnVoiceDetected(cur_frm_idx); + } + } + + else + { + MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx); + } + } + else + { + return; + } + + } + + if (_vad_state_machine == (int)VadStateMachine.kVadInStateEndPointDetected && _vad_opts.detect_mode == (int)VadDetectMode.kVadMutipleUtteranceDetectMode) + { + ResetDetection(); + } + + } + + } +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll b/funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll new file mode 100644 index 000000000..cddc94074 Binary files /dev/null and b/funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll differ diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs new file mode 100644 index 000000000..2f93df1b2 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs @@ -0,0 +1,17 @@ +using System; +using 
System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>CMVN statistics: per-dimension means and variances.</summary>
    internal class CmvnEntity
    {
        // FIX(extraction): generic type arguments were dropped during capture.
        // float is assumed for CMVN statistics — TODO confirm against the loader.
        private List<float> _means = new List<float>();
        private List<float> _vars = new List<float>();

        public List<float> Means { get => _means; set => _means = value; }
        public List<float> Vars { get => _vars; set => _vars = value; }
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs new file mode 100644 index 000000000..58a4ca9a0 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs @@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>Per-frame VAD probabilities, for debugging/inspection output.</summary>
    internal class E2EVadFrameProbEntity
    {
        private double _noise_prob = 0.0F;
        private double _speech_prob = 0.0F;
        private double _score = 0.0F;
        private int _frame_id = 0;
        private int _frm_state = 0;

        public double noise_prob { get => _noise_prob; set => _noise_prob = value; }
        public double speech_prob { get => _speech_prob; set => _speech_prob = value; }
        public double score { get => _score; set => _score = value; }
        public int frame_id { get => _frame_id; set => _frame_id = value; }
        public int frm_state { get => _frm_state; set => _frm_state = value; }
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs new file mode 100644 index 000000000..8c2e7f7a9 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs @@ -0,0 +1,98 @@
// AliFsmnVadSharp, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null
// AliFsmnVadSharp.Model.E2EVadSpeechBufWithDoaEntity
// One detected speech segment: start/end times (ms), optional audio buffer,
// open/closed boundary flags, and direction-of-arrival.
internal class E2EVadSpeechBufWithDoaEntity
{
    private int _start_ms = 0;

    private int _end_ms = 0;

    private
byte[]? _buffer;

    private bool _contain_seg_start_point = false;

    private bool _contain_seg_end_point = false;

    private int _doa = 0;

    // Segment start time in milliseconds.
    public int start_ms
    {
        get
        {
            return _start_ms;
        }
        set
        {
            _start_ms = value;
        }
    }

    // Segment end time in milliseconds.
    public int end_ms
    {
        get
        {
            return _end_ms;
        }
        set
        {
            _end_ms = value;
        }
    }

    // Raw audio of the segment; null until Reset() is called.
    public byte[]? buffer
    {
        get
        {
            return _buffer;
        }
        set
        {
            _buffer = value;
        }
    }

    // True once the segment's start boundary is confirmed.
    public bool contain_seg_start_point
    {
        get
        {
            return _contain_seg_start_point;
        }
        set
        {
            _contain_seg_start_point = value;
        }
    }

    // True once the segment's end boundary is confirmed.
    public bool contain_seg_end_point
    {
        get
        {
            return _contain_seg_end_point;
        }
        set
        {
            _contain_seg_end_point = value;
        }
    }

    // Direction of arrival; always 0 in the current pipeline.
    public int doa
    {
        get
        {
            return _doa;
        }
        set
        {
            _doa = value;
        }
    }

    // Re-initializes the entity for reuse as a fresh, empty segment.
    public void Reset()
    {
        _start_ms = 0;
        _end_ms = 0;
        _buffer = new byte[0];
        _contain_seg_start_point = false;
        _contain_seg_end_point = false;
        _doa = 0;
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs new file mode 100644 index 000000000..8365b1206 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs @@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>FSMN encoder hyper-parameters (dims, memory orders, strides).</summary>
    public class EncoderConfEntity
    {
        private int _input_dim = 400;
        private int _input_affineDim = 140;
        private int _fsmn_layers = 4;
        private int _linear_dim = 250;
        private int _proj_dim = 128;
        private int _lorder = 20;          // left memory order
        private int _rorder = 0;           // right memory order
        private int _lstride = 1;
        private int _rstride = 0;
        // NOTE(review): "_output_dffine_dim" looks like a typo of "affine";
        // private name only, exposed as output_affine_dim below.
        private int _output_dffine_dim = 140;
        private int _output_dim = 248;

        public int input_dim { get => _input_dim; set => _input_dim = value; }
        public int input_affine_dim { get => _input_affineDim; set => _input_affineDim = value; }
public int fsmn_layers { get => _fsmn_layers; set => _fsmn_layers = value; }
        public int linear_dim { get => _linear_dim; set => _linear_dim = value; }
        public int proj_dim { get => _proj_dim; set => _proj_dim = value; }
        public int lorder { get => _lorder; set => _lorder = value; }
        public int rorder { get => _rorder; set => _rorder = value; }
        public int lstride { get => _lstride; set => _lstride = value; }
        public int rstride { get => _rstride; set => _rstride = value; }
        public int output_affine_dim { get => _output_dffine_dim; set => _output_dffine_dim = value; }
        public int output_dim { get => _output_dim; set => _output_dim = value; }
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs new file mode 100644 index 000000000..22bb35aed --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs @@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>Feature-extraction front-end configuration (fbank + LFR).</summary>
    public class FrontendConfEntity
    {
        private int _fs = 16000;              // sample rate in Hz
        private string _window = "hamming";
        private int _n_mels = 80;
        private int _frame_length = 25;       // ms
        private int _frame_shift = 10;        // ms
        private float _dither = 0.0F;
        private int _lfr_m = 5;               // low-frame-rate: frames stacked
        private int _lfr_n = 1;               // low-frame-rate: frame skip

        public int fs { get => _fs; set => _fs = value; }
        public string window { get => _window; set => _window = value; }
        public int n_mels { get => _n_mels; set => _n_mels = value; }
        public int frame_length { get => _frame_length; set => _frame_length = value; }
        public int frame_shift { get => _frame_shift; set => _frame_shift = value; }
        public float dither { get => _dither; set => _dither = value; }
        public int lfr_m { get => _lfr_m; set => _lfr_m = value; }
        public int lfr_n { get => _lfr_n; set => _lfr_n = value; }
    }
}
diff --git
a/funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs new file mode 100644 index 000000000..bdb715d8d --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs @@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>VAD result for one batch item: list of [start_ms, end_ms] segments.</summary>
    public class SegmentEntity
    {
        // FIX(extraction): generic type arguments were dropped during capture.
        // int[] is grounded by E2EVadModel, which adds int[2] {start_ms, end_ms}
        // entries via Segment.AddRange; the waveform element type is assumed
        // to be float[] per segment — TODO confirm against callers.
        private List<int[]> _segment = new List<int[]>();
        private List<float[]> _waveform = new List<float[]>();

        public List<int[]> Segment { get => _segment; set => _segment = value; }
        public List<float[]> Waveform { get => _waveform; set => _waveform = value; }
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs new file mode 100644 index 000000000..fcd63d892 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs @@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>Input bundle for one VAD inference call.</summary>
    internal class VadInputEntity
    {
        private float[]? _speech;
        private int _speechLength;
        // FIX(extraction): element type lost during capture; float[] is assumed
        // for FSMN state caches — TODO confirm against the inference code.
        private List<float[]> _inCaches = new List<float[]>();
        private float[]? _waveform;
        private E2EVadModel _vad_scorer;

        public float[]?
Speech { get => _speech; set => _speech = value; }
        public int SpeechLength { get => _speechLength; set => _speechLength = value; }
        // FIX(extraction): element type lost during capture; float[] assumed
        // for FSMN state caches — TODO confirm against the inference code.
        public List<float[]> InCaches { get => _inCaches; set => _inCaches = value; }
        // NOTE(review): backing field is declared float[]? — nullable mismatch
        // with this non-nullable property is inherited from the original.
        public float[] Waveform { get => _waveform; set => _waveform = value; }
        internal E2EVadModel VadScorer { get => _vad_scorer; set => _vad_scorer = value; }
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs new file mode 100644 index 000000000..fa8639e61 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs @@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>Output bundle of one VAD inference call: scores, caches, waveform.</summary>
    internal class VadOutputEntity
    {
        private float[,,]? _scores;   // [batch, frame, pdf]
        // FIX(extraction): element type lost during capture; float[] assumed
        // for FSMN state caches — TODO confirm against the inference code.
        private List<float[]> _outCaches = new List<float[]>();
        private float[]? _waveform;

        public float[,,]? Scores { get => _scores; set => _scores = value; }
        public List<float[]> OutCaches { get => _outCaches; set => _outCaches = value; }
        public float[] Waveform { get => _waveform; set => _waveform = value; }
    }
}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs new file mode 100644 index 000000000..e566cf2b1 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs @@ -0,0 +1,72 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace AliFsmnVadSharp.Model
{
    /// <summary>Post-processing thresholds and timings for the VAD state machine.</summary>
    public class VadPostConfEntity
    {
        private int _sample_rate = 16000;
        private int _detect_mode = 1;
        private int _snr_mode = 0;
        private int _max_end_silence_time = 800;     // ms of silence that closes a segment
        private int _max_start_silence_time = 3000;  // ms of leading silence before giving up
        private bool _do_start_point_detection = true;
        private bool _do_end_point_detection = true;
        private int _window_size_ms = 200;
private int _sil_to_speech_time_thres = 150; + private int _speech_to_sil_time_thres = 150; + private float _speech_2_noise_ratio = 1.0F; + private int _do_extend = 1; + private int _lookback_time_start_point = 200; + private int _lookahead_time_end_point = 100; + private int _max_single_segment_time = 60000; + private int _nn_eval_block_size = 8; + private int _dcd_block_size = 4; + private float _snr_thres = -100.0F; + private int _noise_frame_num_used_for_snr = 100; + private float _decibel_thres = -100.0F; + private float _speech_noise_thres = 0.6F; + private float _fe_prior_thres = 0.0001F; + private int _silence_pdf_num = 1; + private int[] _sil_pdf_ids = new int[] {0}; + private float _speech_noise_thresh_low = -0.1F; + private float _speech_noise_thresh_high = 0.3F; + private bool _output_frame_probs = false; + private int _frame_in_ms = 10; + private int _frame_length_ms = 25; + + public int sample_rate { get => _sample_rate; set => _sample_rate = value; } + public int detect_mode { get => _detect_mode; set => _detect_mode = value; } + public int snr_mode { get => _snr_mode; set => _snr_mode = value; } + public int max_end_silence_time { get => _max_end_silence_time; set => _max_end_silence_time = value; } + public int max_start_silence_time { get => _max_start_silence_time; set => _max_start_silence_time = value; } + public bool do_start_point_detection { get => _do_start_point_detection; set => _do_start_point_detection = value; } + public bool do_end_point_detection { get => _do_end_point_detection; set => _do_end_point_detection = value; } + public int window_size_ms { get => _window_size_ms; set => _window_size_ms = value; } + public int sil_to_speech_time_thres { get => _sil_to_speech_time_thres; set => _sil_to_speech_time_thres = value; } + public int speech_to_sil_time_thres { get => _speech_to_sil_time_thres; set => _speech_to_sil_time_thres = value; } + public float speech_2_noise_ratio { get => _speech_2_noise_ratio; set => _speech_2_noise_ratio 
= value; } + public int do_extend { get => _do_extend; set => _do_extend = value; } + public int lookback_time_start_point { get => _lookback_time_start_point; set => _lookback_time_start_point = value; } + public int lookahead_time_end_point { get => _lookahead_time_end_point; set => _lookahead_time_end_point = value; } + public int max_single_segment_time { get => _max_single_segment_time; set => _max_single_segment_time = value; } + public int nn_eval_block_size { get => _nn_eval_block_size; set => _nn_eval_block_size = value; } + public int dcd_block_size { get => _dcd_block_size; set => _dcd_block_size = value; } + public float snr_thres { get => _snr_thres; set => _snr_thres = value; } + public int noise_frame_num_used_for_snr { get => _noise_frame_num_used_for_snr; set => _noise_frame_num_used_for_snr = value; } + public float decibel_thres { get => _decibel_thres; set => _decibel_thres = value; } + public float speech_noise_thres { get => _speech_noise_thres; set => _speech_noise_thres = value; } + public float fe_prior_thres { get => _fe_prior_thres; set => _fe_prior_thres = value; } + public int silence_pdf_num { get => _silence_pdf_num; set => _silence_pdf_num = value; } + public int[] sil_pdf_ids { get => _sil_pdf_ids; set => _sil_pdf_ids = value; } + public float speech_noise_thresh_low { get => _speech_noise_thresh_low; set => _speech_noise_thresh_low = value; } + public float speech_noise_thresh_high { get => _speech_noise_thresh_high; set => _speech_noise_thresh_high = value; } + public bool output_frame_probs { get => _output_frame_probs; set => _output_frame_probs = value; } + public int frame_in_ms { get => _frame_in_ms; set => _frame_in_ms = value; } + public int frame_length_ms { get => _frame_length_ms; set => _frame_length_ms = value; } + + } +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs new file mode 100644 index 000000000..65e77eda8 --- /dev/null 
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace AliFsmnVadSharp.Model +{ + internal class VadYamlEntity + { + private int _input_size; + private string _frontend = "wav_frontend"; + private FrontendConfEntity _frontend_conf=new FrontendConfEntity(); + private string _model = "e2evad"; + private string _encoder = "fsmn"; + private EncoderConfEntity _encoder_conf=new EncoderConfEntity(); + private VadPostConfEntity _vad_post_conf=new VadPostConfEntity(); + + public int input_size { get => _input_size; set => _input_size = value; } + public string frontend { get => _frontend; set => _frontend = value; } + public string model { get => _model; set => _model = value; } + public string encoder { get => _encoder; set => _encoder = value; } + public FrontendConfEntity frontend_conf { get => _frontend_conf; set => _frontend_conf = value; } + public EncoderConfEntity encoder_conf { get => _encoder_conf; set => _encoder_conf = value; } + public VadPostConfEntity vad_post_conf { get => _vad_post_conf; set => _vad_post_conf = value; } + } +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs new file mode 100644 index 000000000..bbad3dc95 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs @@ -0,0 +1,6 @@ +using System.Runtime.InteropServices; + +namespace AliFsmnVadSharp.Struct +{ + +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs new file mode 100644 index 000000000..0b460ff16 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Text.Json; +using 
YamlDotNet.Serialization; + +namespace AliFsmnVadSharp.Utils +{ + internal class YamlHelper + { + public static T ReadYaml(string yamlFilePath) + { + if (!File.Exists(yamlFilePath)) + { +#pragma warning disable CS8603 // 可能返回 null 引用。 + return default(T); +#pragma warning restore CS8603 // 可能返回 null 引用。 + } + StreamReader yamlReader = File.OpenText(yamlFilePath); + Deserializer yamlDeserializer = new Deserializer(); + T info = yamlDeserializer.Deserialize(yamlReader); + yamlReader.Close(); + return info; + } + } +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs b/funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs new file mode 100644 index 000000000..2c5b50fbb --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs @@ -0,0 +1,185 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using AliFsmnVadSharp.Model; +using AliFsmnVadSharp.DLL; +using AliFsmnVadSharp.Struct; +using System.Runtime.InteropServices; + +namespace AliFsmnVadSharp +{ + internal class WavFrontend + { + private string _mvnFilePath; + private FrontendConfEntity _frontendConfEntity; + IntPtr _opts = IntPtr.Zero; + private CmvnEntity _cmvnEntity; + + private static int _fbank_beg_idx = 0; + + public WavFrontend(string mvnFilePath, FrontendConfEntity frontendConfEntity) + { + _mvnFilePath = mvnFilePath; + _frontendConfEntity = frontendConfEntity; + _fbank_beg_idx = 0; + _opts = KaldiNativeFbank.GetFbankOptions( + dither: _frontendConfEntity.dither, + snip_edges: true, + sample_rate: _frontendConfEntity.fs, + num_bins: _frontendConfEntity.n_mels + ); + _cmvnEntity = LoadCmvn(mvnFilePath); + } + + public float[] GetFbank(float[] samples) + { + float sample_rate = _frontendConfEntity.fs; + samples = samples.Select((float x) => x * 32768f).ToArray(); + // method1 + //FbankDatas fbankDatas = new FbankDatas(); + //KaldiNativeFbank.GetFbanks(_knfOnlineFbank, framesNum,ref fbankDatas); + // method2 + 
KnfOnlineFbank _knfOnlineFbank = KaldiNativeFbank.GetOnlineFbank(_opts); + KaldiNativeFbank.AcceptWaveform(_knfOnlineFbank, sample_rate, samples, samples.Length); + KaldiNativeFbank.InputFinished(_knfOnlineFbank); + int framesNum = KaldiNativeFbank.GetNumFramesReady(_knfOnlineFbank); + float[] fbanks = new float[framesNum * 80]; + for (int i = 0; i < framesNum; i++) + { + FbankData fbankData = new FbankData(); + KaldiNativeFbank.GetFbank(_knfOnlineFbank, i, ref fbankData); + float[] _fbankData = new float[fbankData.data_length]; + Marshal.Copy(fbankData.data, _fbankData, 0, fbankData.data_length); + Array.Copy(_fbankData, 0, fbanks, i * 80, _fbankData.Length); + fbankData.data = IntPtr.Zero; + _fbankData = null; + } + + samples = null; + GC.Collect(); + return fbanks; + } + + + public float[] LfrCmvn(float[] fbanks) + { + float[] features = fbanks; + if (_frontendConfEntity.lfr_m != 1 || _frontendConfEntity.lfr_n != 1) + { + features = ApplyLfr(fbanks, _frontendConfEntity.lfr_m, _frontendConfEntity.lfr_n); + } + if (_cmvnEntity != null) + { + features = ApplyCmvn(features); + } + return features; + } + + private float[] ApplyCmvn(float[] inputs) + { + var arr_neg_mean = _cmvnEntity.Means; + float[] neg_mean = arr_neg_mean.Select(x => (float)Convert.ToDouble(x)).ToArray(); + var arr_inv_stddev = _cmvnEntity.Vars; + float[] inv_stddev = arr_inv_stddev.Select(x => (float)Convert.ToDouble(x)).ToArray(); + + int dim = neg_mean.Length; + int num_frames = inputs.Length / dim; + + for (int i = 0; i < num_frames; i++) + { + for (int k = 0; k != dim; ++k) + { + inputs[dim * i + k] = (inputs[dim * i + k] + neg_mean[k]) * inv_stddev[k]; + } + } + return inputs; + } + + public float[] ApplyLfr(float[] inputs, int lfr_m, int lfr_n) + { + int t = inputs.Length / 80; + int t_lfr = (int)Math.Floor((double)(t / lfr_n)); + float[] input_0 = new float[80]; + Array.Copy(inputs, 0, input_0, 0, 80); + int tile_x = (lfr_m - 1) / 2; + t = t + tile_x; + float[] inputs_temp = new float[t * 
80]; + for (int i = 0; i < tile_x; i++) + { + Array.Copy(input_0, 0, inputs_temp, tile_x * 80, 80); + } + Array.Copy(inputs, 0, inputs_temp, tile_x * 80, inputs.Length); + inputs = inputs_temp; + + float[] LFR_outputs = new float[t_lfr * lfr_m * 80]; + for (int i = 0; i < t_lfr; i++) + { + if (lfr_m <= t - i * lfr_n) + { + Array.Copy(inputs, i * lfr_n * 80, LFR_outputs, i* lfr_m * 80, lfr_m * 80); + } + else + { + // process last LFR frame + int num_padding = lfr_m - (t - i * lfr_n); + float[] frame = new float[lfr_m * 80]; + Array.Copy(inputs, i * lfr_n * 80, frame, 0, (t - i * lfr_n) * 80); + + for (int j = 0; j < num_padding; j++) + { + Array.Copy(inputs, (t - 1) * 80, frame, (lfr_m - num_padding + j) * 80, 80); + } + Array.Copy(frame, 0, LFR_outputs, i * lfr_m * 80, frame.Length); + } + } + return LFR_outputs; + } + + private CmvnEntity LoadCmvn(string mvnFilePath) + { + List means_list = new List(); + List vars_list = new List(); + FileStreamOptions options = new FileStreamOptions(); + options.Access = FileAccess.Read; + options.Mode = FileMode.Open; + StreamReader srtReader = new StreamReader(mvnFilePath, options); + int i = 0; + while (!srtReader.EndOfStream) + { + string? 
strLine = srtReader.ReadLine(); + if (!string.IsNullOrEmpty(strLine)) + { + if (strLine.StartsWith("")) + { + i=1; + continue; + } + if (strLine.StartsWith("")) + { + i = 2; + continue; + } + if (strLine.StartsWith("") && i==1) + { + string[] add_shift_line = strLine.Substring(strLine.IndexOf("[") + 1, strLine.LastIndexOf("]") - strLine.IndexOf("[") - 1).Split(" "); + means_list = add_shift_line.Where(x => !string.IsNullOrEmpty(x)).Select(x => float.Parse(x.Trim())).ToList(); + continue; + } + if (strLine.StartsWith("") && i==2) + { + string[] rescale_line = strLine.Substring(strLine.IndexOf("[") + 1, strLine.LastIndexOf("]") - strLine.IndexOf("[") - 1).Split(" "); + vars_list = rescale_line.Where(x => !string.IsNullOrEmpty(x)).Select(x => float.Parse(x.Trim())).ToList(); + continue; + } + } + } + CmvnEntity cmvnEntity = new CmvnEntity(); + cmvnEntity.Means = means_list; + cmvnEntity.Vars = vars_list; + return cmvnEntity; + } + + } +} diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs b/funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs new file mode 100644 index 000000000..785af32b2 --- /dev/null +++ b/funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs @@ -0,0 +1,156 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace AliFsmnVadSharp +{ + public enum FrameState + { + kFrameStateInvalid = -1, + kFrameStateSpeech = 1, + kFrameStateSil = 0 + } + + /// + /// final voice/unvoice state per frame + /// + public enum AudioChangeState + { + kChangeStateSpeech2Speech = 0, + kChangeStateSpeech2Sil = 1, + kChangeStateSil2Sil = 2, + kChangeStateSil2Speech = 3, + kChangeStateNoBegin = 4, + kChangeStateInvalid = 5 + } + + + internal class WindowDetector + { + private int _window_size_ms = 0; //window_size_ms; + private int _sil_to_speech_time = 0; //sil_to_speech_time; + private int _speech_to_sil_time = 0; //speech_to_sil_time; + private int _frame_size_ms = 0; 
//frame_size_ms; + + private int _win_size_frame = 0; + private int _win_sum = 0; + private int[] _win_state = new int[0];// * _win_size_frame; // 初始化窗 + + private int _cur_win_pos = 0; + private int _pre_frame_state = (int)FrameState.kFrameStateSil; + private int _cur_frame_state = (int)FrameState.kFrameStateSil; + private int _sil_to_speech_frmcnt_thres = 0; //int(sil_to_speech_time / frame_size_ms); + private int _speech_to_sil_frmcnt_thres = 0; //int(speech_to_sil_time / frame_size_ms); + + private int _voice_last_frame_count = 0; + private int _noise_last_frame_count = 0; + private int _hydre_frame_count = 0; + + public WindowDetector() + { + + } + + public WindowDetector(int window_size_ms, int sil_to_speech_time, int speech_to_sil_time, int frame_size_ms) + { + _window_size_ms = window_size_ms; + _sil_to_speech_time = sil_to_speech_time; + _speech_to_sil_time = speech_to_sil_time; + _frame_size_ms = frame_size_ms; + + _win_size_frame = (int)(window_size_ms / frame_size_ms); + _win_sum = 0; + _win_state = new int[_win_size_frame];//[0] * _win_size_frame; // 初始化窗 + + _cur_win_pos = 0; + _pre_frame_state = (int)FrameState.kFrameStateSil; + _cur_frame_state = (int)FrameState.kFrameStateSil; + _sil_to_speech_frmcnt_thres = (int)(sil_to_speech_time / frame_size_ms); + _speech_to_sil_frmcnt_thres = (int)(speech_to_sil_time / frame_size_ms); + + _voice_last_frame_count = 0; + _noise_last_frame_count = 0; + _hydre_frame_count = 0; + } + + public void Reset() + { + _cur_win_pos = 0; + _win_sum = 0; + _win_state = new int[_win_size_frame]; + _pre_frame_state = (int)FrameState.kFrameStateSil; + _cur_frame_state = (int)FrameState.kFrameStateSil; + _voice_last_frame_count = 0; + _noise_last_frame_count = 0; + _hydre_frame_count = 0; + } + + + public int GetWinSize() + { + return _win_size_frame; + } + + public AudioChangeState DetectOneFrame(FrameState frameState, int frame_count) + { + + + _cur_frame_state = (int)FrameState.kFrameStateSil; + if (frameState == 
FrameState.kFrameStateSpeech) + { + _cur_frame_state = 1; + } + + else if (frameState == FrameState.kFrameStateSil) + { + _cur_frame_state = 0; + } + + else + { + return AudioChangeState.kChangeStateInvalid; + } + + _win_sum -= _win_state[_cur_win_pos]; + _win_sum += _cur_frame_state; + _win_state[_cur_win_pos] = _cur_frame_state; + _cur_win_pos = (_cur_win_pos + 1) % _win_size_frame; + + if (_pre_frame_state == (int)FrameState.kFrameStateSil && _win_sum >= _sil_to_speech_frmcnt_thres) + { + _pre_frame_state = (int)FrameState.kFrameStateSpeech; + return AudioChangeState.kChangeStateSil2Speech; + } + + + if (_pre_frame_state == (int)FrameState.kFrameStateSpeech && _win_sum <= _speech_to_sil_frmcnt_thres) + { + _pre_frame_state = (int)FrameState.kFrameStateSil; + return AudioChangeState.kChangeStateSpeech2Sil; + } + + + if (_pre_frame_state == (int)FrameState.kFrameStateSil) + { + return AudioChangeState.kChangeStateSil2Sil; + } + + if (_pre_frame_state == (int)FrameState.kFrameStateSpeech) + { + return AudioChangeState.kChangeStateSpeech2Speech; + } + + return AudioChangeState.kChangeStateInvalid; + } + + private int FrameSizeMs() + { + return _frame_size_ms; + } + + + + } +} diff --git a/funasr/runtime/csharp/README.md b/funasr/runtime/csharp/README.md new file mode 100644 index 000000000..68175cd5a --- /dev/null +++ b/funasr/runtime/csharp/README.md @@ -0,0 +1,59 @@ +# AliFsmnVadSharp +##### 简介: +项目中使用的VAD模型是阿里巴巴达摩院提供的FSMN-Monophone VAD模型。 +**项目基于Net 6.0,使用C#编写,调用Microsoft.ML.OnnxRuntime对onnx模型进行解码,支持跨平台编译。项目以库的形式进行调用,部署非常方便。** +VAD整体流程的rtf在0.008左右。 + +##### 用途: +16k中文通用VAD模型:可用于检测长语音片段中有效语音的起止时间点. 
+FSMN-Monophone VAD是达摩院语音团队提出的高效语音端点检测模型,用于检测输入音频中有效语音的起止时间点信息,并将检测出来的有效音频片段输入识别引擎进行识别,减少无效语音带来的识别错误。 + +##### VAD常用参数调整说明(参考:vad.yaml文件): +max_end_silence_time:尾部连续检测到多长时间静音进行尾点判停,参数范围500ms~6000ms,默认值800ms(该值过低容易出现语音提前截断的情况)。 +speech_noise_thres:speech的得分减去noise的得分大于此值则判断为speech,参数范围:(-1,1) +取值越趋于-1,噪音被误判定为语音的概率越大,FA越高 +取值越趋于+1,语音被误判定为噪音的概率越大,Pmiss越高 +通常情况下,该值会根据当前模型在长语音测试集上的效果取balance + +##### 模型获取 + +##### 调用方式: +###### 1.添加项目引用 +using AliFsmnVadSharp; + +###### 2.初始化模型和配置 +```csharp +string applicationBase = AppDomain.CurrentDomain.BaseDirectory; +string modelFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"; +string configFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.yaml"; +string mvnFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.mvn"; +int batchSize = 2;//批量解码 +AliFsmnVad aliFsmnVad = new AliFsmnVad(modelFilePath, configFilePath, mvnFilePath, batchSize); +``` +###### 3.调用 +方法一(适用于小文件): +```csharp +SegmentEntity[] segments_duration = aliFsmnVad.GetSegments(samples); +``` +方法二(适用于大文件): +```csharp +SegmentEntity[] segments_duration = aliFsmnVad.GetSegmentsByStep(samples); +``` +###### 4.输出结果: +``` +load model and init config elapsed_milliseconds:463.5390625 +vad infer result: +[[70,2340][2620,6200][6480,23670][23950,26250][26780,28990][29950,31430][31750,37600][38210,46900][47310,49630][49910,56460][56740,59540][59820,70450]] +elapsed_milliseconds:662.796875 +total_duration:70470.625 +rtf:0.009405292985552491 +``` +输出的数据,例如:[70,2340],是以毫秒为单位的segement的起止时间,可以以此为依据对音频进行分片。其中静音噪音部分已被去除。 + +其他说明: +测试用例:AliFsmnVadSharp.Examples。 +测试环境:windows11。 +测试用例中samples的计算,使用的是NAudio库。 + +通过以下链接了解更多: +https://www.modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary diff --git a/funasr/runtime/html5/readme.md b/funasr/runtime/html5/readme.md index 1e9031eec..0c1eba0ea 100644 --- a/funasr/runtime/html5/readme.md +++ b/funasr/runtime/html5/readme.md @@ -51,8 
+51,18 @@ https://127.0.0.1:1337/static/index.html # https://30.220.136.139:1337/static/index.html ``` -### modify asr address in html according to your environment -asr address in index.html must be wss +### open browser to open html5 file directly without h5Server +you can run html5 client by just clicking the index.html file directly in your computer. +1) lauch asr service without ssl, it must be in ws mode as ssl protocol will prohibit such access. +2) copy whole directory /funasr/runtime/html5/static to your computer +3) open /funasr/runtime/html5/static/index.html by browser +4) enter asr service ws address and connect + + +```shell + +``` + ## Acknowledge diff --git a/funasr/runtime/html5/static/index.html b/funasr/runtime/html5/static/index.html index 99aa9b477..23d6fece3 100644 --- a/funasr/runtime/html5/static/index.html +++ b/funasr/runtime/html5/static/index.html @@ -21,14 +21,35 @@

+
+
+ 选择录音模式:
+ +    + + +
+
-
+
选择asr模型模式:
+       - + + +
+ +
+ + + 语音识别结果显示:
diff --git a/funasr/runtime/html5/static/main.js b/funasr/runtime/html5/static/main.js index 22f53c152..35e533a43 100644 --- a/funasr/runtime/html5/static/main.js +++ b/funasr/runtime/html5/static/main.js @@ -32,15 +32,121 @@ btnStart.disabled = true; btnConnect= document.getElementById('btnConnect'); btnConnect.onclick = start; -var rec_text=""; -var offline_text=""; +var rec_text=""; // for online rec asr result +var offline_text=""; // for offline rec asr result var info_div = document.getElementById('info_div'); -//var now_ipaddress=window.location.href; -//now_ipaddress=now_ipaddress.replace("https://","wss://"); -//now_ipaddress=now_ipaddress.replace("static/index.html",""); -//document.getElementById('wssip').value=now_ipaddress; +var upfile = document.getElementById('upfile'); + + +var isfilemode=false; // if it is in file mode +var file_data_array; // array to save file data +var isconnected=0; // for file rec, 0 is not begin, 1 is connected, -1 is error +var totalsend=0; + +upfile.onchange = function () { +      var len = this.files.length; + for(let i = 0; i < len; i++) { + let fileAudio = new FileReader(); + fileAudio.readAsArrayBuffer(this.files[i]); + fileAudio.onload = function() { + var audioblob= fileAudio.result; + file_data_array=audioblob; + console.log(audioblob); + btnConnect.disabled = false; + info_div.innerHTML='请点击连接进行识别'; + + } +          fileAudio.onerror = function(e) { +            console.log('error' + e); +          } + } + } + +function play_file() +{ + var audioblob=new Blob( [ new Uint8Array(file_data_array)] , {type :"audio/wav"}); + var audio_record = document.getElementById('audio_record'); + audio_record.src = (window.URL||webkitURL).createObjectURL(audioblob); + audio_record.controls=true; + audio_record.play(); +} +function start_file_send() +{ + sampleBuf=new Int16Array( file_data_array ); + + var chunk_size=960; // for asr chunk_size [5, 10, 5] + + + + + + while(sampleBuf.length>=chunk_size){ + + 
sendBuf=sampleBuf.slice(0,chunk_size); + totalsend=totalsend+sampleBuf.length; + sampleBuf=sampleBuf.slice(chunk_size,sampleBuf.length); + wsconnecter.wsSend(sendBuf,false); + + + } + + stop(); + + + +} +function start_file_offline() +{ + console.log("start_file_offline",isconnected); + if(isconnected==-1) + { + return; + } + if(isconnected==0){ + + setTimeout(start_file_offline, 1000); + return; + } + start_file_send(); + + + + +} + +function on_recoder_mode_change() +{ + var item = null; + var obj = document.getElementsByName("recoder_mode"); + for (var i = 0; i < obj.length; i++) { //遍历Radio + if (obj[i].checked) { + item = obj[i].value; + break; + } + + + } + if(item=="mic") + { + document.getElementById("mic_mode_div").style.display = 'block'; + document.getElementById("rec_mode_div").style.display = 'none'; + + btnConnect.disabled=false; + isfilemode=false; + } + else + { + document.getElementById("mic_mode_div").style.display = 'none'; + document.getElementById("rec_mode_div").style.display = 'block'; + btnConnect.disabled = true; + isfilemode=true; + info_div.innerHTML='请点击选择文件'; + + + } +} function getAsrMode(){ var item = null; @@ -53,7 +159,12 @@ function getAsrMode(){ } + if(isfilemode) + { + item= "offline"; + } console.log("asr mode"+item); + return item; } @@ -78,6 +189,18 @@ function getJsonMessage( jsonMsg ) { varArea.value=rec_text; console.log( "offline_text: " + asrmodel+","+offline_text); console.log( "rec_text: " + rec_text); + if (isfilemode==true){ + console.log("call stop ws!"); + play_file(); + wsconnecter.wsStop(); + + info_div.innerHTML="请点击连接"; + isconnected=0; + btnStart.disabled = true; + btnStop.disabled = true; + btnConnect.disabled=false; + } + } @@ -86,14 +209,11 @@ function getJsonMessage( jsonMsg ) { function getConnState( connState ) { if ( connState === 0 ) { - //rec.open( function(){ - // rec.start(); - // console.log("开始录音"); - //}); - btnStart.disabled = false; - btnConnect.disabled = true; info_div.innerHTML='连接成功!请点击开始'; 
+ if (isfilemode==true){ + info_div.innerHTML='请耐心等待,大文件等待时间更长'; + } } else if ( connState === 1 ) { //stop(); } else if ( connState === 2 ) { @@ -102,36 +222,52 @@ function getConnState( connState ) { alert("连接地址"+document.getElementById('wssip').value+"失败,请检查asr地址和端口,并确保h5服务和asr服务在同一个域内。或换个浏览器试试。"); btnStart.disabled = true; - + isconnected=0; + info_div.innerHTML='请点击连接'; } } function record() { + rec.open( function(){ rec.start(); console.log("开始"); btnStart.disabled = true; }); + } + + + // 识别启动、停止、清空操作 function start() { // 清除显示 clear(); //控件状态更新 - + console.log("isfilemode"+isfilemode+","+isconnected); info_div.innerHTML="正在连接asr服务器,请等待..."; //启动连接 var ret=wsconnecter.wsStart(); if(ret==1){ isRec = true; - btnStart.disabled = true; + btnStart.disabled = false; btnStop.disabled = false; btnConnect.disabled=true; - + if (isfilemode) + { + console.log("start file now"); + start_file_offline(); + + btnStart.disabled = true; + btnStop.disabled = true; + btnConnect.disabled = true; + } + return 1; } + return 0; } @@ -152,21 +288,26 @@ function stop() { } wsconnecter.wsSend( JSON.stringify(request) ,false); + + - - - + //isconnected=0; // 控件状态更新 + isRec = false; - info_div.innerHTML="请等候..."; - btnStop.disabled = true; - setTimeout(function(){ - console.log("call stop ws!"); - wsconnecter.wsStop(); + info_div.innerHTML="发送完数据,请等候,正在识别..."; + + if(isfilemode==false){ + btnStop.disabled = true; btnStart.disabled = true; btnConnect.disabled=false; + setTimeout(function(){ + console.log("call stop ws!"); + wsconnecter.wsStop(); + isconnected=0; info_div.innerHTML="请点击连接";}, 3000 ); + rec.stop(function(blob,duration){ console.log(blob); @@ -189,8 +330,9 @@ function stop() { },function(errMsg){ console.log("errMsg: " + errMsg); }); + } // 停止连接 - + } diff --git a/funasr/runtime/html5/static/wsconnecter.js b/funasr/runtime/html5/static/wsconnecter.js index 676a94ae5..b9098bb5a 100644 --- a/funasr/runtime/html5/static/wsconnecter.js +++ 
b/funasr/runtime/html5/static/wsconnecter.js @@ -15,8 +15,7 @@ function WebSocketConnectMethod( config ) { //定义socket连接方法类 this.wsStart = function () { var Uri = document.getElementById('wssip').value; //"wss://111.205.137.58:5821/wss/" //设置wss asr online接口地址 如 wss://X.X.X.X:port/wss/ - - if(Uri.match(/wss:\S*/)) + if(Uri.match(/wss:\S*|ws:\S*/)) { console.log("Uri"+Uri); } @@ -25,6 +24,7 @@ function WebSocketConnectMethod( config ) { //定义socket连接方法类 alert("请检查wss地址正确性"); return 0; } + if ( 'WebSocket' in window ) { speechSokt = new WebSocket( Uri ); // 定义socket连接对象 speechSokt.onopen = function(e){onOpen(e);}; // 定义响应函数 @@ -80,6 +80,7 @@ function WebSocketConnectMethod( config ) { //定义socket连接方法类 speechSokt.send( JSON.stringify(request) ); console.log("连接成功"); stateHandle(0); + isconnected=1; } function onClose( e ) { @@ -92,9 +93,11 @@ function WebSocketConnectMethod( config ) { //定义socket连接方法类 } function onError( e ) { + isconnected=-1; info_div.innerHTML="连接"+e; console.log(e); stateHandle(2); + } diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp index 8b59000fc..53301253c 100644 --- a/funasr/runtime/websocket/funasr-wss-client.cpp +++ b/funasr/runtime/websocket/funasr-wss-client.cpp @@ -277,13 +277,14 @@ class WebsocketClient { }; int main(int argc, char* argv[]) { + google::InitGoogleLogging(argv[0]); FLAGS_logtostderr = true; TCLAP::CmdLine cmd("funasr-ws-client", ' ', "1.0"); TCLAP::ValueArg server_ip_("", "server-ip", "server-ip", true, "127.0.0.1", "string"); - TCLAP::ValueArg port_("", "port", "port", true, "8889", "string"); + TCLAP::ValueArg port_("", "port", "port", true, "10095", "string"); TCLAP::ValueArg wav_path_("", "wav-path", "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); diff --git a/funasr/runtime/websocket/funasr-wss-server.cpp 
b/funasr/runtime/websocket/funasr-wss-server.cpp index 5f2af5ce0..a6fa2e137 100644 --- a/funasr/runtime/websocket/funasr-wss-server.cpp +++ b/funasr/runtime/websocket/funasr-wss-server.cpp @@ -11,6 +11,7 @@ // [--vad-quant ] [--vad-dir ] [--quantize // ] --model-dir [--] [--version] [-h] #include "websocket-server.h" +#include using namespace std; void GetValue(TCLAP::ValueArg& value_arg, string key, @@ -20,10 +21,15 @@ void GetValue(TCLAP::ValueArg& value_arg, string key, } int main(int argc, char* argv[]) { try { + google::InitGoogleLogging(argv[0]); FLAGS_logtostderr = true; TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0"); + TCLAP::ValueArg download_model_dir( + "", "download-model-dir", + "Download model from Modelscope to download_model_dir", + false, "", "string"); TCLAP::ValueArg model_dir( "", MODEL_DIR, "default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn", @@ -53,15 +59,15 @@ int main(int argc, char* argv[]) { "true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); - TCLAP::ValueArg listen_ip("", "listen_ip", "listen_ip", false, + TCLAP::ValueArg listen_ip("", "listen-ip", "listen ip", false, "0.0.0.0", "string"); - TCLAP::ValueArg port("", "port", "port", false, 8889, "int"); - TCLAP::ValueArg io_thread_num("", "io_thread_num", "io_thread_num", + TCLAP::ValueArg port("", "port", "port", false, 10095, "int"); + TCLAP::ValueArg io_thread_num("", "io-thread-num", "io thread num", false, 8, "int"); TCLAP::ValueArg decoder_thread_num( - "", "decoder_thread_num", "decoder_thread_num", false, 8, "int"); - TCLAP::ValueArg model_thread_num("", "model_thread_num", - "model_thread_num", false, 1, "int"); + "", "decoder-thread-num", "decoder thread num", false, 8, "int"); + TCLAP::ValueArg model_thread_num("", "model-thread-num", + "model thread num", false, 1, "int"); TCLAP::ValueArg certfile("", "certfile", "default: ../../../ssl_key/server.crt, path of certficate for WSS connection. 
if it is empty, it will be in WS mode.", @@ -73,6 +79,7 @@ int main(int argc, char* argv[]) { cmd.add(certfile); cmd.add(keyfile); + cmd.add(download_model_dir); cmd.add(model_dir); cmd.add(quantize); cmd.add(vad_dir); @@ -95,6 +102,71 @@ int main(int argc, char* argv[]) { GetValue(punc_dir, PUNC_DIR, model_path); GetValue(punc_quant, PUNC_QUANT, model_path); + // Download model form Modelscope + try{ + std::string s_download_model_dir = download_model_dir.getValue(); + if(download_model_dir.isSet() && !s_download_model_dir.empty()){ + if (access(s_download_model_dir.c_str(), F_OK) != 0){ + LOG(ERROR) << s_download_model_dir << " do not exists."; + exit(-1); + } + std::string s_vad_path = model_path[VAD_DIR]; + std::string s_asr_path = model_path[MODEL_DIR]; + std::string s_punc_path = model_path[PUNC_DIR]; + std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True "; + if(vad_dir.isSet() && !s_vad_path.empty()){ + std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir; + LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: "; + system(python_cmd_vad.c_str()); + std::string down_vad_path = s_download_model_dir+"/"+s_vad_path; + std::string down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx"; + if (access(down_vad_model.c_str(), F_OK) != 0){ + LOG(ERROR) << down_vad_model << " do not exists."; + exit(-1); + }else{ + model_path[VAD_DIR]=down_vad_path; + LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR]; + } + }else{ + LOG(INFO) << "VAD model is not set, use default."; + } + if(model_dir.isSet() && !s_asr_path.empty()){ + std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir; + LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: "; + system(python_cmd_asr.c_str()); + std::string down_asr_path = s_download_model_dir+"/"+s_asr_path; + std::string 
down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx"; + if (access(down_asr_model.c_str(), F_OK) != 0){ + LOG(ERROR) << down_asr_model << " do not exists."; + exit(-1); + }else{ + model_path[MODEL_DIR]=down_asr_path; + LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR]; + } + }else{ + LOG(INFO) << "ASR model is not set, use default."; + } + if(punc_dir.isSet() && !s_punc_path.empty()){ + std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir; + LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: "; + system(python_cmd_punc.c_str()); + std::string down_punc_path = s_download_model_dir+"/"+s_punc_path; + std::string down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx"; + if (access(down_punc_model.c_str(), F_OK) != 0){ + LOG(ERROR) << down_punc_model << " do not exists."; + exit(-1); + }else{ + model_path[PUNC_DIR]=down_punc_path; + LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR]; + } + }else{ + LOG(INFO) << "PUNC model is not set, use default."; + } + } + } catch (std::exception const& e) { + LOG(ERROR) << "Error: " << e.what(); + } + std::string s_listen_ip = listen_ip.getValue(); int s_port = port.getValue(); int s_io_thread_num = io_thread_num.getValue();