update asr with speaker

shixian.shi 2024-01-11 17:03:00 +08:00
parent 78ffd04ac9
commit 7037971392
7 changed files with 747 additions and 671 deletions

View File

@ -11,7 +11,23 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
vad_model_revision="v2.0.0",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.0",
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
'''try ASR with speaker labels using spk_mode='punc_segment':
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.0",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.0",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.0",
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
spk_mode='punc_segment',
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
'''
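For reference, a minimal sketch of how the speaker-labelled result could be consumed; the sentence_info, start, end, sentence and spk field names are assumptions inferred from this commit, not a documented API:
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav",
            batch_size_s=300, batch_size_threshold_s=60)
for sent in res[0].get("sentence_info", []):
    # each entry is assumed to carry sentence-level timing, text and a speaker id
    print(sent["start"], sent["end"], sent.get("spk"), sent["sentence"])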

View File

@ -1,26 +1,26 @@
import os.path
import torch
import numpy as np
import hydra
import json
from omegaconf import DictConfig, OmegaConf, ListConfig
import logging
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.load_utils import load_bytes
from funasr.train_utils.device_funcs import to_device
from tqdm import tqdm
from funasr.train_utils.load_pretrained_model import load_pretrained_model
import time
import torch
import hydra
import random
import string
from funasr.register import tables
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.timestamp_tools import time_stamp_sentence
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
@ -126,13 +126,27 @@ class AutoModel:
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
# if spk_model is specified, build the speaker model; otherwise leave it as None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = kwargs.get("spk_model_revision", None)
if spk_model is not None:
spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend()
spk_mode = kwargs.get("spk_mode", 'punc_segment')
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
self.spk_mode = spk_mode
logging.warning("Many to print when using speaker model...")
self.kwargs = kwargs
self.model = model
self.vad_model = vad_model
self.vad_kwargs = vad_kwargs
self.punc_model = punc_model
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
def build_model(self, **kwargs):
@ -198,7 +212,6 @@ class AutoModel:
return self.generate_with_vad(input, input_len=input_len, **cfg)
def generate(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
# import pdb; pdb.set_trace()
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
model = self.model if model is None else model
@ -260,6 +273,7 @@ class AutoModel:
kwargs.update(cfg)
beg_vad = time.time()
res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
vad_res = res
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
@ -314,10 +328,20 @@ class AutoModel:
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
beg_idx = end_idx
results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
all_segments = []
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
speech_j[_b]]]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
continue
results_sorted.extend(results)
@ -336,40 +360,64 @@ class AutoModel:
restored_data[index] = results_sorted[j]
result = {}
# combine per-segment results: text, timestamps, speaker embeddings and other fields
# TODO: rewrite for clean code
for j in range(n):
for k, v in restored_data[j].items():
if not k.startswith("timestamp"):
if k.startswith("timestamp"):
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += restored_data[j][k]
else:
result[k] = []
for t in restored_data[j][k]:
t[0] += vadsegments[j][0]
t[1] += vadsegments[j][0]
result[k].extend(restored_data[j][k])
elif k == 'spk_embedding':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
elif k == 'text':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += " " + restored_data[j][k]
else:
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += restored_data[j][k]
# step.3 compute punc model
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
result["text_with_punc"] = punc_res[0]["text"]
# speaker embedding cluster after resorted
if self.spk_model is not None:
all_segments = sorted(all_segments, key=lambda x: x[0])
spk_embedding = result['spk_embedding']
labels = self.cb_model(spk_embedding)
del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding)
if self.spk_mode == 'vad_segment':
sentence_list = []
for res, vadsegment in zip(restored_data, vadsegments):
sentence_list.append({"start": vadsegment[0],\
"end": vadsegment[1],
"sentence": res['text'],
"timestamp": res['timestamp']})
else: # punc_segment
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['text'])
distribute_spk(sentence_list, sv_output)
result['sentence_info'] = sentence_list
result["key"] = key
results_ret_list.append(result)
pbar_total.update(1)
# step.3 compute punc model
model = self.punc_model
kwargs = self.punc_kwargs
kwargs.update(cfg)
for i, result in enumerate(results_ret_list):
beg_punc = time.time()
res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
end_punc = time.time()
print(f"time punc: {end_punc - beg_punc:0.3f}")
# sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
# results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
# results_ret_list[i]["sentences"] = sentences
results_ret_list[i]["text_with_punc"] = res[i]["text"]
pbar_total.update(1)
end_total = time.time()
time_escape_total_all_samples = end_total - beg_total
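The per-key merging in the loop above can be hard to follow in diff form; the following standalone sketch is a simplified, hypothetical restatement of the new behaviour: timestamps are shifted from segment-relative to global time, texts are joined with spaces, speaker embeddings are concatenated into one matrix for clustering, and all other fields are summed.
import torch

def merge_segment_results(restored_data, vadsegments):
    # simplified sketch of the combine step in generate_with_vad
    result = {}
    for j, seg_res in enumerate(restored_data):
        for k, v in seg_res.items():
            if k.startswith("timestamp"):
                # shift token times by the VAD segment start so they become global
                shifted = [[t[0] + vadsegments[j][0], t[1] + vadsegments[j][0]] for t in v]
                result.setdefault(k, []).extend(shifted)
            elif k == "spk_embedding":
                result[k] = v if k not in result else torch.cat([result[k], v], dim=0)
            elif k == "text":
                result[k] = v if k not in result else result[k] + " " + v
            else:
                result[k] = v if k not in result else result[k] + v
    return result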

View File

@ -0,0 +1,191 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union
import hdbscan
import numpy as np
import scipy
import sklearn
import umap
from sklearn.cluster._kmeans import k_means
from torch import nn
class SpectralCluster:
r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix.
This implementation is adapted from https://github.com/speechbrain/speechbrain.
"""
def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
self.min_num_spks = min_num_spks
self.max_num_spks = max_num_spks
self.pval = pval
def __call__(self, X, oracle_num=None):
# Similarity matrix computation
sim_mat = self.get_sim_mat(X)
# Refining similarity matrix with pval
pruned_sim_mat = self.p_pruning(sim_mat)
# Symmetrization
sym_pruned_sim_mat = 0.5 * (pruned_sim_mat + pruned_sim_mat.T)
# Laplacian calculation
laplacian = self.get_laplacian(sym_pruned_sim_mat)
# Get Spectral Embeddings
emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
# Perform clustering
labels = self.cluster_embs(emb, num_of_spk)
return labels
def get_sim_mat(self, X):
# Cosine similarities
M = sklearn.metrics.pairwise.cosine_similarity(X, X)
return M
def p_pruning(self, A):
if A.shape[0] * self.pval < 6:
pval = 6. / A.shape[0]
else:
pval = self.pval
n_elems = int((1 - pval) * A.shape[0])
# For each row in the affinity matrix
for i in range(A.shape[0]):
low_indexes = np.argsort(A[i, :])
low_indexes = low_indexes[0:n_elems]
# Replace smaller similarity values by 0s
A[i, low_indexes] = 0
return A
def get_laplacian(self, M):
M[np.diag_indices(M.shape[0])] = 0
D = np.sum(np.abs(M), axis=1)
D = np.diag(D)
L = D - M
return L
def get_spec_embs(self, L, k_oracle=None):
lambdas, eig_vecs = scipy.linalg.eigh(L)
if k_oracle is not None:
num_of_spk = k_oracle
else:
lambda_gap_list = self.getEigenGaps(
lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
emb = eig_vecs[:, :num_of_spk]
return emb, num_of_spk
def cluster_embs(self, emb, k):
_, labels, _ = k_means(emb, k)
return labels
def getEigenGaps(self, eig_vals):
eig_vals_gap_list = []
for i in range(len(eig_vals) - 1):
gap = float(eig_vals[i + 1]) - float(eig_vals[i])
eig_vals_gap_list.append(gap)
return eig_vals_gap_list
class UmapHdbscan:
r"""
Reference:
- Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
Emphasis On Topological Structure. ICASSP2022
"""
def __init__(self,
n_neighbors=20,
n_components=60,
min_samples=10,
min_cluster_size=10,
metric='cosine'):
self.n_neighbors = n_neighbors
self.n_components = n_components
self.min_samples = min_samples
self.min_cluster_size = min_cluster_size
self.metric = metric
def __call__(self, X):
umap_X = umap.UMAP(
n_neighbors=self.n_neighbors,
min_dist=0.0,
n_components=min(self.n_components, X.shape[0] - 2),
metric=self.metric,
).fit_transform(X)
labels = hdbscan.HDBSCAN(
min_samples=self.min_samples,
min_cluster_size=self.min_cluster_size,
allow_single_cluster=True).fit_predict(umap_X)
return labels
class ClusterBackend(nn.Module):
r"""Perfom clustering for input embeddings and output the labels.
Args:
model_dir: A model dir.
model_config: The model config.
"""
def __init__(self):
super().__init__()
self.model_config = {'merge_thr':0.78}
# self.other_config = kwargs
self.spectral_cluster = SpectralCluster()
self.umap_hdbscan_cluster = UmapHdbscan()
def forward(self, X, **params):
# clustering and return the labels
k = params['oracle_num'] if 'oracle_num' in params else None
assert len(
X.shape
) == 2, 'modelscope error: the shape of input should be [N, C]'
if X.shape[0] < 20:
return np.zeros(X.shape[0], dtype='int')
if X.shape[0] < 2048 or k is not None:
labels = self.spectral_cluster(X, k)
else:
labels = self.umap_hdbscan_cluster(X)
if k is None and 'merge_thr' in self.model_config:
labels = self.merge_by_cos(labels, X,
self.model_config['merge_thr'])
return labels
def merge_by_cos(self, labels, embs, cos_thr):
# merge the similar speakers by cosine similarity
assert cos_thr > 0 and cos_thr <= 1
while True:
spk_num = labels.max() + 1
if spk_num == 1:
break
spk_center = []
for i in range(spk_num):
spk_emb = embs[labels == i].mean(0)
spk_center.append(spk_emb)
assert len(spk_center) > 0
spk_center = np.stack(spk_center, axis=0)
norm_spk_center = spk_center / np.linalg.norm(
spk_center, axis=1, keepdims=True)
affinity = np.matmul(norm_spk_center, norm_spk_center.T)
affinity = np.triu(affinity, 1)
spks = np.unravel_index(np.argmax(affinity), affinity.shape)
if affinity[spks] < cos_thr:
break
for i in range(len(labels)):
if labels[i] == spks[1]:
labels[i] = spks[0]
elif labels[i] > spks[1]:
labels[i] -= 1
return labels
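A short usage sketch for the backend above, on dummy data only; the 192-dimensional embedding size is illustrative, and the import path follows the one used in auto_model.py in this commit:
import numpy as np
from funasr.models.campplus.cluster_backend import ClusterBackend

cb = ClusterBackend()
embeddings = np.random.randn(64, 192).astype("float32")  # [N, C] chunk embeddings, dummy values
labels = cb(embeddings)                       # N < 2048 and no oracle_num -> spectral clustering path
labels_known = cb(embeddings, oracle_num=2)   # fix the speaker count when it is known
print(labels.shape, np.unique(labels_known))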

View File

@ -109,13 +109,9 @@ class CAMPPlus(nn.Module):
audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_feature(audio_sample_list)
speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
# import pdb; pdb.set_trace()
results = []
embeddings = self.forward(speech)
for embedding in embeddings:
results.append({"spk_embedding":embedding})
meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
results = [{"spk_embedding": self.forward(speech)}]
return results, meta_data
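A rough sketch of the batched call pattern implied by this change, using dummy waveforms; the final forward call is only schematic, since the surrounding CAMPPlus code is not part of this hunk:
import torch
from funasr.models.campplus.utils import extract_feature  # updated signature from this commit

# three mono 16 kHz chunks of different lengths (dummy data)
chunks = [torch.randn(16000), torch.randn(24000), torch.randn(8000)]
feats, feat_lengths, feat_times = extract_feature(chunks)
print(feats.shape)  # padded fbank batch, roughly [3, T_max, 80]
# embeddings = campplus.forward(feats)  # hypothetical: one stacked [3, D] spk_embedding matrix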

View File

@ -2,23 +2,19 @@
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import io
from typing import Union
import librosa as sf
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from torch import nn
import contextlib
import os
import torch
import requests
import tempfile
from abc import ABCMeta, abstractmethod
import contextlib
import numpy as np
import librosa as sf
from typing import Union
from pathlib import Path
from typing import Generator, Union
import requests
from abc import ABCMeta, abstractmethod
import torchaudio.compliance.kaldi as Kaldi
from funasr.models.transformer.utils.nets_utils import pad_list
def check_audio_list(audio: list):
@ -105,15 +101,19 @@ def sv_chunk(vad_segments: list, fs = 16000) -> list:
def extract_feature(audio):
features = []
feature_times = []
feature_lengths = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=80)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
feature_lengths.append(au.shape[0])
features = torch.cat(features)
return features, feature_lengths
features.append(feature)
feature_times.append(au.shape[0])
feature_lengths.append(feature.shape[0])
# padding for batch inference
features_padded = pad_list(features, pad_value=0)
# features = torch.cat(features)
return features_padded, feature_lengths, feature_times
def postprocess(segments: list, vad_segments: list,
@ -195,8 +195,8 @@ def smooth(res, mindur=1):
def distribute_spk(sentence_list, sd_time_list):
sd_sentence_list = []
for d in sentence_list:
sentence_start = d['ts_list'][0][0]
sentence_end = d['ts_list'][-1][1]
sentence_start = d['start']
sentence_end = d['end']
sentence_spk = 0
max_overlap = 0
for sd_time in sd_time_list:
@ -213,8 +213,6 @@ def distribute_spk(sentence_list, sd_time_list):
return sd_sentence_list
class Storage(metaclass=ABCMeta):
"""Abstract class of storage.

View File

@ -239,6 +239,7 @@ class CTTransformer(nn.Module):
cache_pop_trigger_limit = 200
results = []
meta_data = {}
punc_array = None
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
@ -320,8 +321,13 @@ class CTTransformer(nn.Module):
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
# keep an accumulated punctuation-id array for the punc_segment speaker mode
if punc_array is None:
punc_array = punctuations
else:
punc_array = torch.cat([punc_array, punctuations], dim=0)
result_i = {"key": key[0], "text": new_mini_sentence_out}
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i)
return results, meta_data
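The accumulated punc_array is what the punc_segment speaker mode consumes; a small sketch with dummy inputs that mimic the shapes produced above (the id values and millisecond timestamps are illustrative assumptions):
from funasr.utils.timestamp_tools import timestamp_sentence

# one punctuation id per token, one [beg_ms, end_ms] pair per token, space-separated transcript
punc_ids = [1, 1, 3, 1, 1, 3]  # 3 is assumed here to be a sentence-ending punctuation id
timestamps = [[0, 200], [200, 450], [450, 700], [700, 950], [950, 1200], [1200, 1500]]
text = "hello world today is sunny right"
sentence_list = timestamp_sentence(punc_ids, timestamps, text)
# each item: {"text": ..., "start": ..., "end": ..., "timestamp": [[beg, end], ...]}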

View File

@ -98,14 +98,14 @@ def ts_prediction_lfr6_standard(us_alphas,
return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
punc_list = ['，', '。', '？', '、']
res = []
if text_postprocessed is None:
return res
if time_stamp_postprocessed is None:
if timestamp_postprocessed is None:
return res
if len(time_stamp_postprocessed) == 0:
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
@ -113,23 +113,22 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
if punc_id_list is None or len(punc_id_list) == 0:
res.append({
'text': text_postprocessed.split(),
"start": time_stamp_postprocessed[0][0],
"end": time_stamp_postprocessed[-1][1],
'text_seg': text_postprocessed.split(),
"ts_list": time_stamp_postprocessed,
"start": timestamp_postprocessed[0][0],
"end": timestamp_postprocessed[-1][1],
"timestamp": timestamp_postprocessed,
})
return res
if len(punc_id_list) != len(time_stamp_postprocessed):
print(" warning length mistach!!!!!!")
if len(punc_id_list) != len(timestamp_postprocessed):
logging.warning("length mismatch between punc and timestamp")
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = time_stamp_postprocessed[0][0]
sentence_end = time_stamp_postprocessed[0][1]
sentence_start = timestamp_postprocessed[0][0]
sentence_end = timestamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None))
punc_stamp_text_list = list(zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None))
for punc_stamp_text in punc_stamp_text_list:
punc_id, time_stamp, text = punc_stamp_text
punc_id, timestamp, text = punc_stamp_text
# sentence_text += text if text is not None else ''
if text is not None:
if 'a' <= text[0] <= 'z' or 'A' <= text[0] <= 'Z':
@ -139,10 +138,10 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
else:
sentence_text += text
sentence_text_seg += text + ' '
ts_list.append(time_stamp)
ts_list.append(timestamp)
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
sentence_end = timestamp[1] if timestamp is not None else sentence_end
if punc_id > 1:
sentence_text += punc_list[punc_id - 2]
@ -150,8 +149,7 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
'text': sentence_text,
"start": sentence_start,
"end": sentence_end,
"text_seg": sentence_text_seg,
"ts_list": ts_list
"timestamp": ts_list
})
sentence_text = ''
sentence_text_seg = ''
@ -160,181 +158,4 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
return res
# class AverageShiftCalculator():
# def __init__(self):
# logging.warning("Calculating average shift.")
# def __call__(self, file1, file2):
# uttid_list1, ts_dict1 = self.read_timestamps(file1)
# uttid_list2, ts_dict2 = self.read_timestamps(file2)
# uttid_intersection = self._intersection(uttid_list1, uttid_list2)
# res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
# logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
# logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
#
# def _intersection(self, list1, list2):
# set1 = set(list1)
# set2 = set(list2)
# if set1 == set2:
# logging.warning("Uttid same checked.")
# return set1
# itsc = list(set1 & set2)
# logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
# return itsc
#
# def read_timestamps(self, file):
# # read timestamps file in standard format
# uttid_list = []
# ts_dict = {}
# with codecs.open(file, 'r') as fin:
# for line in fin.readlines():
# text = ''
# ts_list = []
# line = line.rstrip()
# uttid = line.split()[0]
# uttid_list.append(uttid)
# body = " ".join(line.split()[1:])
# for pd in body.split(';'):
# if not len(pd): continue
# # pdb.set_trace()
# char, start, end = pd.lstrip(" ").split(' ')
# text += char + ','
# ts_list.append((float(start), float(end)))
# # ts_lists.append(ts_list)
# ts_dict[uttid] = (text[:-1], ts_list)
# logging.warning("File {} read done.".format(file))
# return uttid_list, ts_dict
#
# def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
# shift_time = 0
# for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
# shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
# num_tokens = len(filtered_timestamp_list1)
# return shift_time, num_tokens
#
# # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
# # # calculate average shift between timestamp1 and timestamp2
# # # when characters differ, use edit distance alignment
# # # and calculate the error between the same characters
# # self._accumlated_shift = 0
# # self._accumlated_tokens = 0
# # self.max_shift = 0
# # self.max_shift_uttid = None
# # for uttid in uttid_list:
# # (t1, ts1) = ts_dict1[uttid]
# # (t2, ts2) = ts_dict2[uttid]
# # _align, _align2, _align3 = [], [], []
# # fts1, fts2 = [], []
# # _t1, _t2 = [], []
# # sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
# # s = sm.get_opcodes()
# # for j in range(len(s)):
# # if s[j][0] == "replace" or s[j][0] == "insert":
# # _align.append(0)
# # if s[j][0] == "replace" or s[j][0] == "delete":
# # _align3.append(0)
# # elif s[j][0] == "equal":
# # _align.append(1)
# # _align3.append(1)
# # else:
# # continue
# # # use s to index t2
# # for a, ts , t in zip(_align, ts2, t2.split(',')):
# # if a:
# # fts2.append(ts)
# # _t2.append(t)
# # sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
# # s = sm2.get_opcodes()
# # for j in range(len(s)):
# # if s[j][0] == "replace" or s[j][0] == "insert":
# # _align2.append(0)
# # elif s[j][0] == "equal":
# # _align2.append(1)
# # else:
# # continue
# # # use s2 tp index t1
# # for a, ts, t in zip(_align3, ts1, t1.split(',')):
# # if a:
# # fts1.append(ts)
# # _t1.append(t)
# # if len(fts1) == len(fts2):
# # shift_time, num_tokens = self._shift(fts1, fts2)
# # self._accumlated_shift += shift_time
# # self._accumlated_tokens += num_tokens
# # if shift_time/num_tokens > self.max_shift:
# # self.max_shift = shift_time/num_tokens
# # self.max_shift_uttid = uttid
# # else:
# # logging.warning("length mismatch")
# # return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
from funasr.models.paraformer.cif_predictor import cif_wo_hidden
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
for line1, line2 in zip(f1.readlines(), f2.readlines()):
line1 = line1.rstrip()
line2 = line2.rstrip()
assert line1.split()[0] == line2.split()[0]
uttid = line1.split()[0]
alphas = [float(i) for i in line1.split()[1:]]
new_alphas = np.array(remove_chunk_padding(alphas))
new_alphas[-1] += 1e-4
text = line2.split()[1:]
if len(text) + 1 != int(new_alphas.sum()):
# force resize
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
if " " in text:
text = text.split()
else:
text = [i for i in text]
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
force_time_shift=-7.0,
sil_in_str=False)
f3.write("{} {}\n".format(uttid, res_str))
def remove_chunk_padding(alphas):
# remove the padding part in alphas if using chunk paraformer for GPU
START_ZERO = 45
MID_ZERO = 75
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
alphas = alphas[START_ZERO:] # remove the padding at beginning
new_alphas = []
while True:
new_alphas = new_alphas + alphas[:REAL_FRAMES]
alphas = alphas[REAL_FRAMES+MID_ZERO:]
if len(alphas) < REAL_FRAMES: break
return new_alphas
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
# if args.mode == 'cal_aas':
# asc = AverageShiftCalculator()
# asc(args.input, args.input2)
if args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='timestamp tools')
parser.add_argument('--mode',
default=None,
type=str,
choices=SUPPORTED_MODES,
help='timestamp related toolbox')
parser.add_argument('--input', default=None, type=str, help='input file path')
parser.add_argument('--output', default=None, type=str, help='output file name')
parser.add_argument('--input2', default=None, type=str, help='input2 file path')
parser.add_argument('--kaldi-ts-type',
default='v2',
type=str,
choices=['v0', 'v1', 'v2'],
help='kaldi timestamp to write')
args = parser.parse_args()
main(args)