Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Merge pull request #37 from alibaba-damo-academy/dev_lyb

fix ctc module

Commit 2de29621c5
@@ -95,10 +95,13 @@ class Speech2Text:
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
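This hunk, and the three similar Speech2Text hunks below, apply the same fix: the CTC prefix scorer is built and registered only when the loaded model actually has a CTC branch (asr_model.ctc can be None for attention-only models), and the unconditional ctc=ctc entry is dropped from the later scorers.update(...) call. A minimal sketch of the resulting pattern, not repo code: build_scorers and its injected classes are illustrative stand-ins for the ESPnet-style CTCPrefixScorer and LengthBonus, and the idiomatic "is not None" is used here where the diff writes "!= None".

# Sketch only: CTCPrefixScorer / LengthBonus are passed in as stand-ins for
# the real scorer classes used by the surrounding inference code.
def build_scorers(asr_model, CTCPrefixScorer, LengthBonus):
    scorers = {}
    # Register a CTC scorer only if the model was trained with a CTC head.
    if asr_model.ctc is not None:
        scorers.update(
            ctc=CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        )
    token_list = asr_model.token_list
    scorers.update(
        length_bonus=LengthBonus(len(token_list)),
    )
    return scorers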
@@ -98,10 +98,13 @@ class Speech2Text:
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
@@ -100,10 +100,13 @@ class Speech2Text:
         # logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
@@ -663,7 +666,7 @@ def inference_modelscope(
                                  time_stamp_postprocessed))
 
         logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
-                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
+                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
         return asr_result_list
     return _forward
 
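In the rtf hunk above, length_total appears to count LFR-reduced feature frames, so length_total * lfr_factor recovers the original frame count, and the factor 100 converts frames to seconds assuming the usual 10 ms frame shift (100 frames per second). With the 1e-6 term dropped, the denominator is exactly that recovered frame count. A hedged reading of the logged value, not repo code:

# Assumes a 10 ms frame shift; rtf() is an illustrative helper, not a FunASR API.
def rtf(forward_time_total, length_total, lfr_factor):
    audio_seconds = (length_total * lfr_factor) / 100.0  # frames -> seconds of audio
    return forward_time_total / audio_seconds            # processing time per second of audio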
@@ -96,11 +96,14 @@ class Speech2Text:
         else:
             decoder = asr_model.decoder2
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
             decoder=decoder,
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
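The scorers dict assembled in these hunks is typically paired with a weights dict and combined per hypothesis during beam search; when no "ctc" scorer is registered, the CTC term simply contributes nothing. A toy combiner to illustrate the idea, not the FunASR/ESPnet beam-search implementation:

# Illustrative only: combine per-scorer log scores for one hypothesis.
def combined_score(scores_per_key, weights):
    # scores_per_key: e.g. {"decoder": -4.2, "length_bonus": 3.0}
    return sum(weights.get(key, 0.0) * score for key, score in scores_per_key.items())

weights = dict(decoder=0.7, ctc=0.3, length_bonus=0.1)
# With no CTC scorer registered, the "ctc" weight is never applied.
print(combined_score({"decoder": -4.2, "length_bonus": 3.0}, weights))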