Mirror of https://github.com/FunAudioLLM/SenseVoice.git, synced 2025-09-15 15:08:35 +08:00
add support for timestamp
parent de00f2b75c
commit 530b2e3ce5
README.md
@@ -44,6 +44,7 @@ Online Demo:

<a name="What's News"></a>
# What's New 🔥
- 2024/11: Added support for timestamps based on the CTC alignment.
- 2024/7: Added export features for [ONNX](./demo_onnx.py) and [libtorch](./demo_libtorch.py), as well as Python runtimes: [funasr-onnx-0.4.0](https://pypi.org/project/funasr-onnx/) and [funasr-torch-0.1.1](https://pypi.org/project/funasr-torch/).
- 2024/7: Open-sourced the [SenseVoice-Small](https://www.modelscope.cn/models/iic/SenseVoiceSmall) voice understanding model, which offers high-precision multilingual speech recognition, emotion recognition, and audio event detection for Mandarin, Cantonese, English, Japanese, and Korean, with exceptionally low inference latency.
- 2024/7: Released CosyVoice, a model for natural speech generation with multi-lingual, timbre, and emotion control. CosyVoice excels at multi-lingual voice generation, zero-shot voice generation, cross-lingual voice cloning, and instruction following. See the [CosyVoice repo](https://github.com/FunAudioLLM/CosyVoice) and [CosyVoice space](https://www.modelscope.cn/studios/iic/CosyVoice-300M).
demo2.py (14 lines changed)
@@ -21,3 +21,17 @@ res = m.inference(
text = rich_transcription_postprocess(res[0][0]["text"])
print(text)

res = m.inference(
    data_in=f"{kwargs['model_path']}/example/en.mp3",
    language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
    use_itn=False,
    ban_emo_unk=False,
    output_timestamp=True,
    **kwargs,
)

timestamp = res[0][0]["timestamp"]
text = rich_transcription_postprocess(res[0][0]["text"])
print(text)
print(timestamp)
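For reference, the model changes below append each timestamp entry as a [token, start_seconds, end_seconds] triple, so the list printed above can be unpacked like this (a minimal sketch, assuming the inference call above succeeded):

for token, start, end in timestamp:
    print(f"{token}\t{start:.2f}s - {end:.2f}s")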
model.py (44 lines changed)
@@ -13,7 +13,7 @@ from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
|
||||
from utils.ctc_alignment import ctc_forced_align
|
||||
|
||||
class SinusoidalPositionEncoder(torch.nn.Module):
|
||||
""" """
|
||||
@ -830,6 +830,8 @@ class SenseVoiceSmall(nn.Module):
|
||||
).repeat(speech.size(0), 1, 1)
|
||||
|
||||
use_itn = kwargs.get("use_itn", False)
|
||||
output_timestamp = kwargs.get("output_timestamp", False)
|
||||
|
||||
textnorm = kwargs.get("text_norm", None)
|
||||
if textnorm is None:
|
||||
textnorm = "withitn" if use_itn else "woitn"
|
||||
@ -878,13 +880,45 @@ class SenseVoiceSmall(nn.Module):
            # Change integer ids to tokens
            text = tokenizer.decode(token_int)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

            if output_timestamp:
                from itertools import groupby

                timestamp = []
                # Skip the first 4 tokens (the special language/emotion/event/itn tags).
                tokens = tokenizer.text2tokens(text)[4:]

                logits_speech = self.ctc.softmax(encoder_out)[i, 4 : encoder_out_lens[i].item(), :]

                # Zero the blank probability wherever greedy decoding predicts blank,
                # so the forced alignment stretches tokens across those frames.
                pred = logits_speech.argmax(-1).cpu()
                logits_speech[pred == self.blank_id, self.blank_id] = 0

                align = ctc_forced_align(
                    logits_speech.unsqueeze(0).float(),
                    torch.tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
                    (encoder_out_lens[i : i + 1] - 4).long(),  # this sample's length, shape (1,)
                    torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
                    ignore_id=self.ignore_id,
                )

                # Collapse consecutive frames with the same label into (label, run) groups.
                pred = groupby(align[0, : encoder_out_lens[i].item() - 4])
                _start = 0
                token_id = 0
                ts_max = encoder_out_lens[i].item() - 4
                for pred_token, pred_frame in pred:
                    _end = _start + len(list(pred_frame))
                    if pred_token != 0:  # skip blank runs
                        # Each encoder frame covers 60 ms; shift back half a frame (30 ms).
                        ts_left = max((_start * 60 - 30) / 1000, 0)
                        ts_right = min((_end * 60 - 30) / 1000, (ts_max * 60 - 30) / 1000)
                        timestamp.append([tokens[token_id], ts_left, ts_right])
                        token_id += 1
                    _start = _end

                result_i = {"key": key[i], "text": text, "timestamp": timestamp}
                results.append(result_i)
            else:
                result_i = {"key": key[i], "text": text}
                results.append(result_i)
        return results, meta_data

    def export(self, **kwargs):
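The ts_left / ts_right arithmetic above assumes each encoder output frame covers 60 ms, with boundaries shifted back by half a frame (30 ms) and clamped to the utterance. A standalone sketch of the same conversion (the helper name frame_to_seconds is mine, not part of the commit):

def frame_to_seconds(frame: int, max_frames: int) -> float:
    # 60 ms per encoder frame, minus a half-frame (30 ms) offset,
    # clamped to [0, end of utterance].
    seconds = (frame * 60 - 30) / 1000
    return min(max(seconds, 0.0), (max_frames * 60 - 30) / 1000)

# A token aligned to frames [10, 14) spans roughly 0.57 s to 0.81 s:
print(frame_to_seconds(10, 100), frame_to_seconds(14, 100))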
utils/ctc_alignment.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import torch


def ctc_forced_align(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    blank: int = 0,
    ignore_id: int = -1,
) -> torch.Tensor:
    """Align a CTC label sequence to an emission.

    Args:
        log_probs (Tensor): Log probabilities of the CTC emission output,
            of shape `(B, T, C)`, where `B` is the batch size, `T` is the input
            length, and `C` is the number of characters in the alphabet,
            including blank.
        targets (Tensor): Target sequence, of shape `(B, L)`,
            where `L` is the target length.
        input_lengths (Tensor):
            Lengths of the inputs (each must be <= `T`). 1-D Tensor of shape `(B,)`.
        target_lengths (Tensor):
            Lengths of the targets. 1-D Tensor of shape `(B,)`.
        blank (int, optional): Index of the blank symbol in the CTC emission. (Default: 0)
        ignore_id (int, optional): Index used for padding in `targets`; these entries
            are replaced with `blank` before alignment. (Default: -1)

    Returns:
        Tensor: Frame-level alignment of shape `(B, T)`, giving the aligned label id
        (or `blank`) for every frame.
    """
    # Replace padding with blank without mutating the caller's tensor.
    targets = targets.masked_fill(targets == ignore_id, blank)

    batch_size, input_time_size, _ = log_probs.size()
    bsz_indices = torch.arange(batch_size, device=input_lengths.device)

    # Interleave targets with blanks: (blank, y1, blank, y2, ..., blank, yL, blank).
    _t_a_r_g_e_t_s_ = torch.cat(
        (
            torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
            torch.full_like(targets[:, :1], blank),
        ),
        dim=-1,
    )
    # A diagonal transition (skipping a blank) is only allowed between distinct labels.
    diff_labels = torch.cat(
        (
            torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
            _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
        ),
        dim=1,
    )

    neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
    padding_num = 2
    padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
    best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
    best_score[:, padding_num + 0] = log_probs[:, 0, blank]
    best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]

    backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)

    # Viterbi forward pass over the CTC trellis (stay / advance one / skip a blank).
    for t in range(1, input_time_size):
        prev = torch.stack(
            (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
        )
        prev_max_value, prev_max_idx = prev.max(dim=0)
        best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
        backpointers[:, t, padding_num:] = prev_max_idx

    # The path may end on the last label or on the trailing blank; pick the better one.
    l1l2 = best_score.gather(
        -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
    )

    path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
    path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)

    # Backtrack through the stored pointers to recover the full state sequence.
    for t in range(input_time_size - 1, 0, -1):
        target_indices = path[:, t]
        prev_max_idx = backpointers[bsz_indices, t, target_indices]
        path[:, t - 1] += target_indices - prev_max_idx

    # Map trellis states back to label ids (blanks included).
    alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
    return alignments
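A minimal smoke test for the new aligner (toy shapes and values of my choosing, not from the commit; blank id 0 as in the default):

import torch
from utils.ctc_alignment import ctc_forced_align

# Toy emission: batch of 1, 6 frames, 4-symbol vocabulary with 0 as blank.
torch.manual_seed(0)
log_probs = torch.randn(1, 6, 4).log_softmax(dim=-1)
targets = torch.tensor([[1, 2, 3]])
input_lengths = torch.tensor([6])
target_lengths = torch.tensor([3])

align = ctc_forced_align(log_probs, targets, input_lengths, target_lengths)
print(align)  # shape (1, 6): the aligned label id (or 0 for blank) per frame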