mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Fixed a bug that caused a runtime error when running the model on CPU, which was introduced in commit c1e365fea0. The error was related to incorrect handling of device placement.
This commit is contained in:
parent
c1e365fea0
commit
6ebf6e48eb
@ -366,8 +366,11 @@ class AutoModel:
|
||||
if pbar:
|
||||
# pbar.update(1)
|
||||
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
|
||||
with torch.cuda.device(next(model.parameters()).device):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
device = next(model.parameters()).device
|
||||
if device.type == 'cuda':
|
||||
with torch.cuda.device():
|
||||
torch.cuda.empty_cache()
|
||||
return asr_result_list
|
||||
|
||||
def inference_with_vad(self, input, input_len=None, **cfg):
|
||||
|
||||
@ -221,8 +221,10 @@ def main(**kwargs):
|
||||
)
|
||||
trainer.start_step = 0
|
||||
|
||||
with torch.cuda.device(kwargs["device"]):
|
||||
torch.cuda.empty_cache()
|
||||
device = next(model.parameters()).device
|
||||
if device.type == 'cuda':
|
||||
with torch.cuda.device():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
|
||||
logging.info(
|
||||
|
||||
@ -184,8 +184,10 @@ def main(**kwargs):
|
||||
)
|
||||
trainer.start_step = 0
|
||||
|
||||
with torch.cuda.device(kwargs["device"]):
|
||||
torch.cuda.empty_cache()
|
||||
device = next(model.parameters()).device
|
||||
if device.type == 'cuda':
|
||||
with torch.cuda.device():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
|
||||
logging.info(
|
||||
|
||||
@ -873,8 +873,10 @@ class Decoder(torch.nn.Module, ScorerInterface):
|
||||
ctc_state[idx], accum_best_ids
|
||||
)
|
||||
|
||||
with torch.cuda.device(vscores.device):
|
||||
torch.cuda.empty_cache()
|
||||
device = vscores.device
|
||||
if device.type == 'cuda':
|
||||
with torch.cuda.device():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
|
||||
ended_hyps = [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user