mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update cif onnx
This commit is contained in:
parent
4e50630527
commit
508e518b12
4
.gitignore
vendored
4
.gitignore
vendored
@ -7,4 +7,6 @@
|
||||
init_model/
|
||||
*.tar.gz
|
||||
test_local/
|
||||
RapidASR
|
||||
RapidASR
|
||||
export/*
|
||||
*.pyc
|
||||
@ -48,11 +48,11 @@ class CifPredictorV2(nn.Module):
|
||||
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
alphas = alphas * mask
|
||||
|
||||
alphas = alphas.squeeze(-1)
|
||||
|
||||
token_num = alphas.sum(-1)
|
||||
|
||||
mask = mask.squeeze(-1)
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
@ -63,12 +63,14 @@ class CifPredictorV2(nn.Module):
|
||||
|
||||
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
|
||||
ones_t = torch.ones_like(zeros_t)
|
||||
|
||||
mask_1 = torch.cat([mask, zeros_t], dim=1)
|
||||
mask_2 = torch.cat([ones_t, mask], dim=1)
|
||||
mask = mask_2 - mask_1
|
||||
tail_threshold = mask * tail_threshold
|
||||
alphas = torch.cat([alphas, tail_threshold], dim=1)
|
||||
|
||||
alphas = torch.cat([alphas, zeros_t], dim=1)
|
||||
alphas = torch.add(alphas, tail_threshold)
|
||||
|
||||
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
|
||||
hidden = torch.cat([hidden, zeros], dim=1)
|
||||
token_num = alphas.sum(dim=-1)
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
|
||||
from rapid_paraformer import Paraformer
|
||||
|
||||
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model = Paraformer(model_dir, batch_size=1)
|
||||
|
||||
wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
|
||||
wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav']
|
||||
|
||||
result = model(wav_path)
|
||||
print(result)
|
||||
Loading…
Reference in New Issue
Block a user