mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
rnnt bug fix
This commit is contained in:
parent
931cac99c1
commit
490531ed51
@ -174,7 +174,7 @@ class Speech2Text:
|
|||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self.simu_streaming = simu_streaming
|
self.simu_streaming = simu_streaming
|
||||||
self.chunk_size = max(chunk_size, 0)
|
self.chunk_size = max(chunk_size, 0)
|
||||||
self.left_context = max(left_context, 0)
|
self.left_context = left_context
|
||||||
self.right_context = max(right_context, 0)
|
self.right_context = max(right_context, 0)
|
||||||
|
|
||||||
if not streaming or chunk_size == 0:
|
if not streaming or chunk_size == 0:
|
||||||
|
|||||||
@ -531,8 +531,8 @@ class UnifiedTransducerModel(AbsESPnetModel):
|
|||||||
sym_blank: str = "<blank>",
|
sym_blank: str = "<blank>",
|
||||||
report_cer: bool = True,
|
report_cer: bool = True,
|
||||||
report_wer: bool = True,
|
report_wer: bool = True,
|
||||||
sym_sos: str = "<sos/eos>",
|
sym_sos: str = "<s>",
|
||||||
sym_eos: str = "<sos/eos>",
|
sym_eos: str = "</s>",
|
||||||
extract_feats_in_collect_stats: bool = True,
|
extract_feats_in_collect_stats: bool = True,
|
||||||
lsm_weight: float = 0.0,
|
lsm_weight: float = 0.0,
|
||||||
length_normalized_loss: bool = False,
|
length_normalized_loss: bool = False,
|
||||||
|
|||||||
@ -595,7 +595,7 @@ def make_chunk_mask(
|
|||||||
mask = torch.zeros(size, size, device=device, dtype=torch.bool)
|
mask = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||||
|
|
||||||
for i in range(size):
|
for i in range(size):
|
||||||
if left_chunk_size <= 0:
|
if left_chunk_size < 0:
|
||||||
start = 0
|
start = 0
|
||||||
else:
|
else:
|
||||||
start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
|
start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user