Merge pull request #37 from alibaba-damo-academy/dev_lyb

fix ctc module
This commit is contained in:
zhifu gao 2023-01-18 13:52:01 +08:00 committed by GitHub
commit 2de29621c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 21 additions and 9 deletions

View File

@ -95,10 +95,13 @@ class Speech2Text:
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
if asr_model.ctc != None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)

View File

@ -98,10 +98,13 @@ class Speech2Text:
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
if asr_model.ctc != None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)

View File

@ -100,10 +100,13 @@ class Speech2Text:
# logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
if asr_model.ctc != None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
@ -663,7 +666,7 @@ def inference_modelscope(
time_stamp_postprocessed))
logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
return asr_result_list
return _forward

View File

@ -96,11 +96,14 @@ class Speech2Text:
else:
decoder = asr_model.decoder2
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
if asr_model.ctc != None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)