rnnt bug fix

This commit is contained in:
aky15 2023-03-15 17:34:34 +08:00
parent e33bb15d26
commit 96bae0153c
3 changed files with 20 additions and 136 deletions

View File

@ -31,7 +31,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool, str2triple_str, str_or_none
from funasr.utils.cli_utils import get_commandline_args
from funasr.models.frontend.wav_frontend import WavFrontend
class Speech2Text:
"""Speech2Text class for Transducer models.
@ -62,6 +62,7 @@ class Speech2Text:
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
beam_search_config: Dict[str, Any] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
@ -86,11 +87,14 @@ class Speech2Text:
super().__init__()
assert check_argument_types()
asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
asr_train_config, asr_model_file, device
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
if quantize_asr_model:
if quantize_modules is not None:
if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
@ -156,7 +160,7 @@ class Speech2Text:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.device = device
@ -181,23 +185,13 @@ class Speech2Text:
self.simu_streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
if asr_train_args.frontend_conf.get("win_length", None) is not None:
self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
else:
self.frontend_window_size = self.n_fft
self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
self.window_size, self.hop_length
)
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
#self.last_chunk_length = (
# self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
#) * self.hop_length
@ -218,112 +212,6 @@ class Speech2Text:
self.num_processed_frames = torch.tensor([[0]], device=self.device)
def apply_frontend(
self, speech: torch.Tensor, is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward frontend.
Args:
speech: Speech data. (S)
is_final: Whether speech corresponds to the final (or only) chunk of data.
Returns:
feats: Features sequence. (1, T_in, F)
feats_lengths: Features sequence length. (1, T_in, F)
"""
if self.frontend_cache is not None:
speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
if is_final:
if self.streaming and speech.size(0) < self.last_chunk_length:
pad = torch.zeros(
self.last_chunk_length - speech.size(0), dtype=speech.dtype
)
speech = torch.cat([speech, pad], dim=0)
speech_to_process = speech
waveform_buffer = None
else:
n_frames = (
speech.size(0) - (self.frontend_window_size - self.hop_length)
) // self.hop_length
n_residual = (
speech.size(0) - (self.frontend_window_size - self.hop_length)
) % self.hop_length
speech_to_process = speech.narrow(
0,
0,
(self.frontend_window_size - self.hop_length)
+ n_frames * self.hop_length,
)
waveform_buffer = speech.narrow(
0,
speech.size(0)
- (self.frontend_window_size - self.hop_length)
- n_residual,
(self.frontend_window_size - self.hop_length) + n_residual,
).clone()
speech_to_process = speech_to_process.unsqueeze(0).to(
getattr(torch, self.dtype)
)
lengths = speech_to_process.new_full(
[1], dtype=torch.long, fill_value=speech_to_process.size(1)
)
batch = {"speech": speech_to_process, "speech_lengths": lengths}
batch = to_device(batch, device=self.device)
feats, feats_lengths = self.asr_model._extract_feats(**batch)
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
if is_final:
if self.frontend_cache is None:
pass
else:
feats = feats.narrow(
1,
math.ceil(
math.ceil(self.frontend_window_size / self.hop_length) / 2
),
feats.size(1)
- math.ceil(
math.ceil(self.frontend_window_size / self.hop_length) / 2
),
)
else:
if self.frontend_cache is None:
feats = feats.narrow(
1,
0,
feats.size(1)
- math.ceil(
math.ceil(self.frontend_window_size / self.hop_length) / 2
),
)
else:
feats = feats.narrow(
1,
math.ceil(
math.ceil(self.frontend_window_size / self.hop_length) / 2
),
feats.size(1)
- 2
* math.ceil(
math.ceil(self.frontend_window_size / self.hop_length) / 2
),
)
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
if is_final:
self.frontend_cache = None
else:
self.frontend_cache = {"waveform_buffer": waveform_buffer}
return feats, feats_lengths
@torch.no_grad()
def streaming_decode(
self,
@ -410,14 +298,9 @@ class Speech2Text:
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
# lengths: (1,)
# feats, feats_length = self.apply_frontend(speech)
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
# lengths: (1,)
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
# print(feats.shape)
# print(feats_lengths)
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@ -495,6 +378,7 @@ def inference(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str],
beam_search_config: Optional[dict],
lm_train_config: Optional[str],
lm_file: Optional[str],
@ -562,7 +446,6 @@ def inference(
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
@ -570,6 +453,7 @@ def inference(
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
beam_search_config=beam_search_config,
lm_train_config=lm_train_config,
lm_file=lm_file,
@ -719,6 +603,11 @@ def get_parser():
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,

View File

@ -120,7 +120,7 @@ class ConvInput(torch.nn.Module):
self.create_new_mask = self.create_new_conv2d_mask
self.vgg_like = vgg_like
self.min_frame_length = 2
self.min_frame_length = 7
if output_size is not None:
self.output = torch.nn.Linear(output_proj, output_size)
@ -218,9 +218,4 @@ class ConvInput(torch.nn.Module):
: Number of frames before subsampling.
"""
if self.subsampling_factor > 1:
if self.vgg_like:
return ((size * 2) * self.stride_1) + 1
return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
return size
return size * self.subsampling_factor

View File

@ -1576,7 +1576,7 @@ class AbsTask(ABC):
preprocess=iter_options.preprocess_fn,
max_cache_size=iter_options.max_cache_size,
max_cache_fd=iter_options.max_cache_fd,
dest_sample_rate=args.frontend_conf["fs"],
dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000,
)
cls.check_task_requirements(
dataset, args.allow_variable_data_keys, train=iter_options.train