Mirror of https://github.com/modelscope/FunASR
rnnt bug fix
parent e33bb15d26
commit 96bae0153c
@@ -31,7 +31,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.types import str2bool, str2triple_str, str_or_none
 from funasr.utils.cli_utils import get_commandline_args

+from funasr.models.frontend.wav_frontend import WavFrontend

 class Speech2Text:
     """Speech2Text class for Transducer models.
@@ -62,6 +62,7 @@ class Speech2Text:
         self,
         asr_train_config: Union[Path, str] = None,
         asr_model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
         beam_search_config: Dict[str, Any] = None,
         lm_train_config: Union[Path, str] = None,
         lm_file: Union[Path, str] = None,
@@ -86,11 +87,14 @@ class Speech2Text:
         super().__init__()

         assert check_argument_types()

         asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
-            asr_train_config, asr_model_file, device
+            asr_train_config, asr_model_file, cmvn_file, device
         )

+        frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)

         if quantize_asr_model:
             if quantize_modules is not None:
                 if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
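
Note: this hunk is the heart of the fix. `cmvn_file` is now threaded into `ASRTransducerTask.build_model_from_file` and into a `WavFrontend`, so global CMVN statistics are actually applied at inference time. For readers unfamiliar with CMVN, here is a minimal NumPy sketch of the normalization itself (illustrative only; FunASR's `WavFrontend` loads the statistics from the file, typically a Kaldi-style `am.mvn`):

import numpy as np

# Global cepstral mean and variance normalization (CMVN), sketched:
# per-dimension statistics accumulated over the whole training set are
# applied identically to every utterance at inference time.
def apply_global_cmvn(feats, mean, var, eps=1e-10):
    """feats: (T, F) fbank features; mean/var: (F,) global statistics."""
    return (feats - mean) / np.sqrt(np.maximum(var, eps))
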
@@ -156,7 +160,7 @@ class Speech2Text:
         tokenizer = build_tokenizer(token_type=token_type)
         converter = TokenIDConverter(token_list=token_list)
         logging.info(f"Text tokenizer: {tokenizer}")

         self.asr_model = asr_model
         self.asr_train_args = asr_train_args
         self.device = device
@@ -181,23 +185,13 @@ class Speech2Text:
             self.simu_streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False

-        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
-        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
-
-        if asr_train_args.frontend_conf.get("win_length", None) is not None:
-            self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
-        else:
-            self.frontend_window_size = self.n_fft

+        self.frontend = frontend
         self.window_size = self.chunk_size + self.right_context
-        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
-            self.window_size, self.hop_length
-        )
-
         self._ctx = self.asr_model.encoder.get_encoder_input_size(
             self.window_size
         )

         # self.last_chunk_length = (
         #     self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
         # ) * self.hop_length
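
Note: the deleted block kept raw STFT geometry (`n_fft`, `hop_length`, `win_length`) inside `Speech2Text`; with the frontend now owned by `WavFrontend`, that bookkeeping goes away. A back-of-envelope check of what it computed, using the standard frames-to-samples relation (all values assumed, and the detail may differ from `get_encoder_input_raw_size`):

# Raw samples needed for one encoder window, per the removed bookkeeping
# (assumed config: n_fft=512, hop_length=128, win_length unset):
n_fft, hop_length = 512, 128
frontend_window_size = n_fft              # fallback when win_length is None
window_frames = 10                        # chunk_size + right_context, assumed
raw_samples = (window_frames - 1) * hop_length + frontend_window_size
print(raw_samples)                        # 1664 samples, ~104 ms at 16 kHz
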
@@ -218,112 +212,6 @@ class Speech2Text:

         self.num_processed_frames = torch.tensor([[0]], device=self.device)

-    def apply_frontend(
-        self, speech: torch.Tensor, is_final: bool = False
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward frontend.
-        Args:
-            speech: Speech data. (S)
-            is_final: Whether speech corresponds to the final (or only) chunk of data.
-        Returns:
-            feats: Features sequence. (1, T_in, F)
-            feats_lengths: Features sequence length. (1, T_in, F)
-        """
-        if self.frontend_cache is not None:
-            speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
-
-        if is_final:
-            if self.streaming and speech.size(0) < self.last_chunk_length:
-                pad = torch.zeros(
-                    self.last_chunk_length - speech.size(0), dtype=speech.dtype
-                )
-                speech = torch.cat([speech, pad], dim=0)
-
-            speech_to_process = speech
-            waveform_buffer = None
-        else:
-            n_frames = (
-                speech.size(0) - (self.frontend_window_size - self.hop_length)
-            ) // self.hop_length
-
-            n_residual = (
-                speech.size(0) - (self.frontend_window_size - self.hop_length)
-            ) % self.hop_length
-
-            speech_to_process = speech.narrow(
-                0,
-                0,
-                (self.frontend_window_size - self.hop_length)
-                + n_frames * self.hop_length,
-            )
-
-            waveform_buffer = speech.narrow(
-                0,
-                speech.size(0)
-                - (self.frontend_window_size - self.hop_length)
-                - n_residual,
-                (self.frontend_window_size - self.hop_length) + n_residual,
-            ).clone()
-
-        speech_to_process = speech_to_process.unsqueeze(0).to(
-            getattr(torch, self.dtype)
-        )
-        lengths = speech_to_process.new_full(
-            [1], dtype=torch.long, fill_value=speech_to_process.size(1)
-        )
-        batch = {"speech": speech_to_process, "speech_lengths": lengths}
-        batch = to_device(batch, device=self.device)
-
-        feats, feats_lengths = self.asr_model._extract_feats(**batch)
-        if self.asr_model.normalize is not None:
-            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
-        if is_final:
-            if self.frontend_cache is None:
-                pass
-            else:
-                feats = feats.narrow(
-                    1,
-                    math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                    feats.size(1)
-                    - math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-        else:
-            if self.frontend_cache is None:
-                feats = feats.narrow(
-                    1,
-                    0,
-                    feats.size(1)
-                    - math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-            else:
-                feats = feats.narrow(
-                    1,
-                    math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                    feats.size(1)
-                    - 2
-                    * math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        if is_final:
-            self.frontend_cache = None
-        else:
-            self.frontend_cache = {"waveform_buffer": waveform_buffer}
-
-        return feats, feats_lengths
-
     @torch.no_grad()
     def streaming_decode(
         self,
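
Note: the removed `apply_frontend` carried the streaming state; it buffered the tail of each waveform chunk so that STFT frames stay aligned across calls. A worked example of its `n_frames` / `n_residual` arithmetic (values assumed, illustrative only):

# Chunk arithmetic from the removed apply_frontend, with assumed values
# win=512 (frontend_window_size), hop=128, and a 1000-sample chunk:
win, hop = 512, 128
chunk = 1000
usable = chunk - (win - hop)   # 616 samples can start a complete frame
n_frames = usable // hop       # 4 full STFT frames in this chunk
n_residual = usable % hop      # 104 leftover samples
# (win - hop) + n_residual = 488 samples are buffered and re-prepended to
# the next chunk, so no frame straddles a chunk boundary unseen.
print(n_frames, n_residual)    # 4 104
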
@@ -410,14 +298,9 @@ class Speech2Text:
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)

         # lengths: (1,)
         # feats, feats_length = self.apply_frontend(speech)
         feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
         # lengths: (1,)
         feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

         # print(feats.shape)
         # print(feats_lengths)
         if self.asr_model.normalize is not None:
             feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -495,6 +378,7 @@ def inference(
     data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
     asr_train_config: Optional[str],
     asr_model_file: Optional[str],
+    cmvn_file: Optional[str],
     beam_search_config: Optional[dict],
     lm_train_config: Optional[str],
     lm_file: Optional[str],
@@ -562,7 +446,6 @@ def inference(
         device = "cuda"
     else:
         device = "cpu"

     # 1. Set random-seed
     set_all_random_seed(seed)
@@ -570,6 +453,7 @@ def inference(
     speech2text_kwargs = dict(
         asr_train_config=asr_train_config,
         asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
         beam_search_config=beam_search_config,
         lm_train_config=lm_train_config,
         lm_file=lm_file,
@@ -719,6 +603,11 @@ def get_parser():
         type=str,
         help="ASR model parameter file",
     )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
     group.add_argument(
         "--lm_train_config",
         type=str,
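
Note: the new `--cmvn_file` flag completes the plumbing from the command line down to `WavFrontend`. A minimal argparse sketch of the option in isolation (illustrative; file names are placeholders, not the full FunASR parser):

import argparse

parser = argparse.ArgumentParser(description="Transducer inference (sketch)")
group = parser.add_argument_group("Model configuration")
group.add_argument("--asr_model_file", type=str, help="ASR model parameter file")
group.add_argument("--cmvn_file", type=str, help="Global cmvn file")

args = parser.parse_args(["--asr_model_file", "model.pb", "--cmvn_file", "am.mvn"])
print(args.cmvn_file)  # am.mvn
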
@@ -120,7 +120,7 @@ class ConvInput(torch.nn.Module):
             self.create_new_mask = self.create_new_conv2d_mask

         self.vgg_like = vgg_like
-        self.min_frame_length = 2
+        self.min_frame_length = 7

         if output_size is not None:
             self.output = torch.nn.Linear(output_proj, output_size)
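
Note: raising `min_frame_length` from 2 to 7 matches the receptive field of the common 4x Conv2d subsampling front, assuming two kernel-3, stride-2 convolutions without padding: 7 input frames is the smallest count that yields at least one output frame. A quick check under that assumption:

# Why 7 frames is the floor, assuming two Conv2d layers with kernel 3,
# stride 2, no padding (the common 4x subsampling front):
def conv_out(t: int, k: int = 3, s: int = 2) -> int:
    return (t - k) // s + 1

print(conv_out(conv_out(7)))  # 1 -> first length that survives both convs
print(conv_out(conv_out(6)))  # 0 -> shorter inputs produce no encoder frames
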
@@ -218,9 +218,4 @@ class ConvInput(torch.nn.Module):
             : Number of frames before subsampling.

         """
-        if self.subsampling_factor > 1:
-            if self.vgg_like:
-                return ((size * 2) * self.stride_1) + 1
-
-            return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
-        return size
+        return size * self.subsampling_factor
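
Note: `get_size` (frames before subsampling) collapses to a single multiplication, replacing the branchy per-architecture formulas. A numeric comparison of the old and new estimates (stride and kernel values assumed, illustrative only):

# Old vs. new frame-count estimates for subsampling_factor = 4
# (assumed stride_1 = 2, kernel_2 = 3, stride_2 = 2):
size = 10                                   # frames after subsampling
old_vgg = ((size * 2) * 2) + 1              # 41
old_conv = ((size + 2) * 2) + (3 - 1) * 2   # 28
new = size * 4                              # 40: uniform across frontends
print(old_vgg, old_conv, new)
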
@@ -1576,7 +1576,7 @@ class AbsTask(ABC):
             preprocess=iter_options.preprocess_fn,
             max_cache_size=iter_options.max_cache_size,
             max_cache_fd=iter_options.max_cache_fd,
-            dest_sample_rate=args.frontend_conf["fs"],
+            dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000,
         )
         cls.check_task_requirements(
             dataset, args.allow_variable_data_keys, train=iter_options.train
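
Note: the last hunk hardens dataset construction for tasks without a frontend, since subscripting a `None` `frontend_conf` raises `TypeError`. The guard in isolation, with the 16 kHz fallback the commit hard-codes:

# dest_sample_rate resolution with and without a frontend_conf:
frontend_conf = None                      # e.g. a task configured without a frontend
fs = frontend_conf["fs"] if frontend_conf else 16000
print(fs)                                 # 16000 (previously: TypeError)

frontend_conf = {"fs": 8000}
fs = frontend_conf["fs"] if frontend_conf else 16000
print(fs)                                 # 8000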