FunASR (mirror of https://github.com/modelscope/FunASR)

Commit 31eed1834f (parent dae8f7472d): in_cache & support soundfile read
@@ -43,6 +43,7 @@ from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
+from funasr.bin.vad_inference import Speech2VadSegment
 from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
@@ -364,101 +365,6 @@ class Speech2Text:
            hotword_list = None
        return hotword_list
 
 
-class Speech2VadSegment:
-    """Speech2VadSegment class
-
-    Examples:
-        >>> import soundfile
-        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-
-    """
-
-    def __init__(
-            self,
-            vad_infer_config: Union[Path, str] = None,
-            vad_model_file: Union[Path, str] = None,
-            vad_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-        assert check_argument_types()
-
-        # 1. Build vad model
-        vad_model, vad_infer_args = VADTask.build_model_from_file(
-            vad_infer_config, vad_model_file, device
-        )
-        frontend = None
-        if vad_infer_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-
-        # logging.info("vad_model: {}".format(vad_model))
-        # logging.info("vad_infer_args: {}".format(vad_infer_args))
-        vad_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.vad_model = vad_model
-        self.vad_infer_args = vad_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.batch_size = batch_size
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-        assert check_argument_types()
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            self.frontend.filter_length_max = math.inf
-            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
-            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
-            fbanks = to_device(fbanks, device=self.device)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-        else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-
-        # b. Forward Encoder streaming
-        t_offset = 0
-        step = min(feats_len, 6000)
-        segments = [[]] * self.batch_size
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final_send = True
-            else:
-                is_final_send = False
-            batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
-            }
-            # a. To device
-            batch = to_device(batch, device=self.device)
-            segments_part = self.vad_model(**batch)
-            if segments_part:
-                for batch_num in range(0, self.batch_size):
-                    segments[batch_num] += segments_part[batch_num]
-
-        return fbanks, segments
-
-
 def inference(
         maxlenratio: float,
@@ -11,6 +11,7 @@ from typing import Tuple
 from typing import Union
 from typing import Dict
 
+import math
 import numpy as np
 import torch
 from typeguard import check_argument_types
@@ -86,7 +87,7 @@ class Speech2VadSegment:
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
+    ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
        """Inference

        Args:
@@ -102,7 +103,10 @@ class Speech2VadSegment:
            speech = torch.tensor(speech)
 
        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            self.frontend.filter_length_max = math.inf
+            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
+            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
+            fbanks = to_device(fbanks, device=self.device)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
        else:
@@ -110,18 +114,18 @@ class Speech2VadSegment:
 
        # b. Forward Encoder streaming
        t_offset = 0
-        step = min(feats_len, 6000)
+        step = min(feats_len.max(), 6000)
        segments = [[]] * self.batch_size
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
-                is_final_send = True
+                is_final = True
            else:
-                is_final_send = False
+                is_final = False
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
+                "is_final": is_final
            }
            # a. To device
            batch = to_device(batch, device=self.device)
@@ -129,7 +133,7 @@ class Speech2VadSegment:
            if segments_part:
                for batch_num in range(0, self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
-        return segments
+        return fbanks, segments
 
 
 def inference(
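
A minimal usage sketch (file paths are placeholders, not part of this commit): with this change __call__ returns the fbank features together with the VAD segments, so callers now unpack two values, mirroring the docstring example that was removed in the hunk above.

    import soundfile
    from funasr.bin.vad_inference import Speech2VadSegment

    # Placeholder config/model/cmvn paths; in practice they come from a downloaded VAD model.
    speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt", "vad.mvn")
    audio, rate = soundfile.read("speech.wav")
    fbanks, segments = speech2segment(audio)
    print(segments)  # VAD segments in ms, e.g. [[10, 230], [245, 450], ...]
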
@@ -219,9 +223,13 @@ def inference_modelscope(
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
-        param_dict: dict = None,
+        param_dict: dict = None
 ):
    # 3. Build data-iterator
+    if data_path_and_name_and_type is None and raw_inputs is not None:
+        if isinstance(raw_inputs, torch.Tensor):
+            raw_inputs = raw_inputs.numpy()
+        data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
    loader = VADTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
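
This is the "support soundfile read" half of the commit: when no data_path_and_name_and_type is given, a raw waveform array is accepted directly and wrapped as [raw_inputs, "speech", "waveform"] for the loader. A minimal calling sketch, assuming vad_inference_pipeline is the callable returned by inference_modelscope (that name and the wav path are placeholders):

    import soundfile
    import numpy as np

    audio, rate = soundfile.read("speech.wav")       # read waveform with soundfile
    results = vad_inference_pipeline(
        data_path_and_name_and_type=None,            # no scp/list input
        raw_inputs=audio.astype(np.float32),         # numpy array (or torch.Tensor) goes in directly
    )
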
@@ -254,7 +262,7 @@ def inference_modelscope(
        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
 
        # do vad segment
-        results = speech2vadsegment(**batch)
+        _, results = speech2vadsegment(**batch)
        for i, _ in enumerate(keys):
            results[i] = json.dumps(results[i])
            item = {'key': keys[i], 'value': results[i]}
@@ -201,7 +201,7 @@ class E2EVadModel(nn.Module):
                                               self.vad_opts.frame_in_ms)
        self.encoder = encoder
        # init variables
-        self.is_final_send = False
+        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
@@ -230,8 +230,7 @@ class E2EVadModel(nn.Module):
        self.ResetDetection()
 
    def AllResetDetection(self):
-        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
-        self.is_final_send = False
+        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
@@ -283,8 +282,8 @@ class E2EVadModel(nn.Module):
                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
 
-    def ComputeScores(self, feats: torch.Tensor) -> None:
-        scores = self.encoder(feats)  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, in_cache)  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
@@ -306,7 +305,7 @@ class E2EVadModel(nn.Module):
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(self.data_buf))
@@ -443,11 +442,13 @@ class E2EVadModel(nn.Module):
 
        return frame_state
 
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+                is_final: bool = False
+                ) -> List[List[List[int]]]:
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
-        self.ComputeScores(feats)
-        if not is_final_send:
+        self.ComputeScores(feats, in_cache)
+        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
@@ -456,15 +457,18 @@ class E2EVadModel(nn.Module):
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
+                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                        i].contain_seg_end_point:
-                        segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
-                        segment_batch.append(segment)
-                        self.output_data_buf_offset += 1  # need update this parameter
+                        continue
+                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                    segment_batch.append(segment)
+                    self.output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
-        if is_final_send:
-            self.AllResetDetection()
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+            in_cache.clear()
        return segments
 
    def DetectCommonFrames(self) -> int:
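
The forward signature above defines the new streaming contract: the caller owns a single in_cache dict, passes it with every chunk, and marks the last chunk with is_final=True, after which AllResetDetection() runs and the dict is cleared. A minimal sketch, assuming a built E2EVadModel instance vad_model and pre-chunked feats/waveform pairs (chunks and vad_model are placeholders):

    in_cache = {}                      # persists across chunks; the FSMN layers fill it lazily
    segments = [[]]                    # batch size 1

    for idx, (feats_chunk, wav_chunk) in enumerate(chunks):
        is_final = (idx == len(chunks) - 1)
        part = vad_model(feats=feats_chunk, waveform=wav_chunk,
                         in_cache=in_cache, is_final=is_final)
        if part:
            segments[0] += part[0]
    # after the final chunk the detector is reset and in_cache is empty again
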
@@ -79,14 +79,12 @@ class FSMNBlock(nn.Module):
        else:
            self.conv_right = None
 
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, cache: torch.Tensor):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C
-        if in_cache is None:  # offline
-            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-        else:
-            y_left = torch.cat((in_cache, x_per), dim=2)
-            in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        y_left = torch.cat((cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
        y_left = self.conv_left(y_left)
        out = x_per + y_left
 
@@ -100,7 +98,7 @@ class FSMNBlock(nn.Module):
        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)
 
-        return output, in_cache
+        return output, cache
 
 
 class BasicBlock(nn.Sequential):
@@ -124,28 +122,25 @@ class BasicBlock(nn.Sequential):
        self.affine = AffineTransform(proj_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
 
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
        x1 = self.linear(input)  # B T D
-        if in_cache is not None:  # Dict[str, tensor.Tensor]
-            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
-            if cache_layer_name not in in_cache:
-                in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
-            x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
-        else:
-            x2, _ = self.fsmn_block(x1)
+        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        if cache_layer_name not in in_cache:
+            in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
        x3 = self.affine(x2)
        x4 = self.relu(x3)
-        return x4, in_cache
+        return x4
 
 
 class FsmnStack(nn.Sequential):
    def __init__(self, *args):
        super(FsmnStack, self).__init__(*args)
 
-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
        x = input
        for module in self._modules.values():
-            x, in_cache = module(x, in_cache)
+            x = module(x, in_cache)
        return x
 
 
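
The cache entries created in BasicBlock.forward hold the last (lorder - 1) * lstride frames per layer in shape (B, D, T_cache, 1). A self-contained sketch of that update rule, with placeholder sizes:

    import torch

    B, D, lorder, lstride = 1, 128, 20, 1
    cache = torch.zeros(B, D, (lorder - 1) * lstride, 1)   # as lazily created per layer

    for _ in range(3):                                      # three streaming chunks
        x_per = torch.randn(B, D, 10, 1)                    # new chunk of 10 frames (placeholder)
        y_left = torch.cat((cache, x_per), dim=2)           # prepend cached history on the time axis
        cache = y_left[:, :, -(lorder - 1) * lstride:, :]   # keep only the newest history frames
        # y_left is what conv_left consumes in FSMNBlock.forward
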
@@ -174,8 +169,7 @@ class FSMN(nn.Module):
            lstride: int,
            rstride: int,
            output_affine_dim: int,
-            output_dim: int,
-            streaming=False
+            output_dim: int
    ):
        super(FSMN, self).__init__()
 
@@ -186,8 +180,6 @@ class FSMN(nn.Module):
        self.proj_dim = proj_dim
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim
-        self.in_cache_original = dict() if streaming else None
-        self.in_cache = copy.deepcopy(self.in_cache_original)
 
        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@@ -201,12 +193,10 @@ class FSMN(nn.Module):
    def fuse_modules(self):
        pass
 
-    def cache_reset(self):
-        self.in_cache = copy.deepcopy(self.in_cache_original)
-
    def forward(
            self,
            input: torch.Tensor,
+            in_cache: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
@@ -218,7 +208,7 @@ class FSMN(nn.Module):
        x1 = self.in_linear1(input)
        x2 = self.in_linear2(x1)
        x3 = self.relu(x2)
-        x4 = self.fsmn(x3, self.in_cache)  # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
+        x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        x7 = self.softmax(x6)
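
With the streaming flag and cache_reset() removed, cache lifetime is owned entirely by the caller of FSMN.forward: a fresh dict starts a fresh stream, and reusing the same dict continues one. A minimal sketch, assuming an already-constructed FSMN instance named encoder and placeholder sizes for the input chunk:

    import torch

    in_cache = {}                           # new dict => new utterance / new stream
    chunk = torch.randn(1, 10, 400)         # (B, T, input_dim); placeholder dimensions
    scores = encoder(chunk, in_cache)       # per-layer caches are created and updated in place,
                                            # used as the frame-level scores in ComputeScores above
    # ... feed later chunks of the same stream with the same dict ...
    in_cache.clear()                        # done with the stream; replaces the removed cache_reset()
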
@@ -307,4 +297,4 @@ if __name__ == '__main__':
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))
 
    print(fsmn.to_kaldi_net())