mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #37 from alibaba-damo-academy/dev_lyb
fix ctc module
This commit is contained in:
commit
2de29621c5
@ -95,10 +95,13 @@ class Speech2Text:
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
if asr_model.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
|
||||
@ -98,10 +98,13 @@ class Speech2Text:
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
if asr_model.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
|
||||
@ -100,10 +100,13 @@ class Speech2Text:
|
||||
# logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
if asr_model.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
@ -663,7 +666,7 @@ def inference_modelscope(
|
||||
time_stamp_postprocessed))
|
||||
|
||||
logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
|
||||
format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
|
||||
format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
|
||||
return asr_result_list
|
||||
return _forward
|
||||
|
||||
|
||||
@ -96,11 +96,14 @@ class Speech2Text:
|
||||
else:
|
||||
decoder = asr_model.decoder2
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
if asr_model.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user