update asr with speaker

shixian.shi 2024-01-11 17:03:00 +08:00
parent 78ffd04ac9
commit 7037971392
7 changed files with 747 additions and 671 deletions


@ -6,12 +6,28 @@
from funasr import AutoModel

model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.0",
                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.0",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.0",
                  spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)

'''try ASR with speaker labels:
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.0",
                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.0",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.0",
                  spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                  spk_mode='punc_segment',
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
'''
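# Editorial sketch (hedged): reading the speaker-labelled output back.
# The 'sentence_info', 'start', 'end', 'text'/'sentence' and 'timestamp' fields
# follow the auto_model.py changes below; the speaker-id key written by
# distribute_spk is not visible in this diff, so "spk" is an assumption.
for seg in res[0].get("sentence_info", []):
    print(seg.get("spk"), seg["start"], seg["end"], seg.get("text") or seg.get("sentence"))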


@ -1,453 +1,501 @@
import json
import time
import torch
import hydra
import random
import string
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
:param input:
:param input_len:
:param data_type:
:param frontend:
:return:
"""
data_list = []
key_list = []
filelist = [".scp", ".txt", ".json", ".jsonl"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
with open(data_in, encoding='utf-8') as fin:
for line in fin:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
lines = json.loads(line.strip())
data = lines["source"]
key = data["key"] if "key" in data else key
else: # filelist, wav.scp, text.txt: id \t data or data
lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines)>1 else lines[0]
key = lines[0] if len(lines)>1 else key
data_list.append(data)
key_list.append(key)
else:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
elif isinstance(data_in, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)): # multiple inputs
data_list_tmp = []
for data_in_i, data_type_i in zip(data_in, data_type):
key_list, data_list_i = prepare_data_iterator(data_in=data_in_i, data_type=data_type_i)
data_list_tmp.append(data_list_i)
data_list = []
for item in zip(*data_list_tmp):
data_list.append(item)
else:
# [audio sample point, fbank, text]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
if key is None:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
return key_list, data_list
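# Editorial sketch (hedged): the list-file inputs this helper branches on.
#   wav.scp / text.txt : "<key> <data>" per line, or just "<data>"
#   file.jsonl         : one json.dumps({"source": ...}) object per line
# Example calls (paths are placeholders, not files from this repo):
#   key_list, data_list = prepare_data_iterator("data/wav.scp")
#   key_list, data_list = prepare_data_iterator(["a.wav", "b.wav"])  # keys auto-generated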
"""
:param input:
:param input_len:
:param data_type:
:param frontend:
:return:
"""
data_list = []
key_list = []
filelist = [".scp", ".txt", ".json", ".jsonl"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
with open(data_in, encoding='utf-8') as fin:
for line in fin:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
lines = json.loads(line.strip())
data = lines["source"]
key = data["key"] if "key" in data else key
else: # filelist, wav.scp, text.txt: id \t data or data
lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines)>1 else lines[0]
key = lines[0] if len(lines)>1 else key
data_list.append(data)
key_list.append(key)
else:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
elif isinstance(data_in, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)): # mutiple inputs
data_list_tmp = []
for data_in_i, data_type_i in zip(data_in, data_type):
key_list, data_list_i = prepare_data_iterator(data_in=data_in_i, data_type=data_type_i)
data_list_tmp.append(data_list_i)
data_list = []
for item in zip(*data_list_tmp):
data_list.append(item)
else:
# [audio sample point, fbank, text]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
if key is None:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
return key_list, data_list
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
def to_plain_list(cfg_item):
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item
kwargs = to_plain_list(cfg)
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)
if kwargs.get("debug", False):
import pdb; pdb.set_trace()
model = AutoModel(**kwargs)
res = model(input=kwargs["input"])
print(res)
if kwargs.get("debug", False):
import pdb; pdb.set_trace()
model = AutoModel(**kwargs)
res = model(input=kwargs["input"])
print(res)
class AutoModel:
def __init__(self, **kwargs):
tables.print()
model, kwargs = self.build_model(**kwargs)
# if vad_model is not None, build vad model else None
vad_model = kwargs.get("vad_model", None)
vad_kwargs = kwargs.get("vad_model_revision", None)
if vad_model is not None:
print("build vad model")
vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
# if punc_model is not None, build punc model else None
punc_model = kwargs.get("punc_model", None)
punc_kwargs = kwargs.get("punc_model_revision", None)
if punc_model is not None:
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
# if spk_model is not None, build spk model else None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = kwargs.get("spk_model_revision", None)
if spk_model is not None:
spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend()
spk_mode = kwargs.get("spk_mode", 'punc_segment')
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
self.spk_mode = spk_mode
logging.warning("Many to print when using speaker model...")
self.kwargs = kwargs
self.model = model
self.vad_model = vad_model
self.vad_kwargs = vad_kwargs
self.punc_model = punc_model
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
def build_model(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
device = "cpu"
# kwargs["batch_size"] = 1
kwargs["device"] = device
if kwargs.get("ncpu", None):
torch.set_num_threads(kwargs.get("ncpu"))
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list
vocab_size = len(tokenizer.token_list)
else:
vocab_size = -1
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend.lower())
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# build model
model_class = tables.model_classes.get(kwargs["model"].lower())
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
model.eval()
model.to(device)
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
logging.info(f"Loading pretrained params from {init_param}")
load_pretrained_model(
model=model,
init_param=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
oss_bucket=kwargs.get("oss_bucket", None),
)
return model, kwargs
def __call__(self, input, input_len=None, **cfg):
if self.vad_model is None:
return self.generate(input, input_len=input_len, **cfg)
else:
return self.generate_with_vad(input, input_len=input_len, **cfg)
def generate(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
model = self.model if model is None else model
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
time_speech_total = 0.0
time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
time1 = time.perf_counter()
with torch.no_grad():
results, meta_data = model.generate(**batch, **kwargs)
time2 = time.perf_counter()
asr_result_list.extend(results)
pbar.update(1)
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
batch_data_time = meta_data.get("batch_data_time", -1)
time_escape = time2 - time1
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
speed_stats["forward"] = f"{time_escape:0.3f}"
speed_stats["batch_size"] = f"{len(results)}"
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
description = (
f"{speed_stats}, "
)
pbar.set_description(description)
time_speech_total += batch_data_time
time_escape_total += time_escape
pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
def generate_with_vad(self, input, input_len=None, **cfg):
# step.1: compute the vad model
model = self.vad_model
kwargs = self.vad_kwargs
kwargs.update(cfg)
beg_vad = time.time()
res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
vad_res = res
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
# step.2 compute asr model
model = self.model
kwargs = self.kwargs
kwargs.update(cfg)
batch_size = int(kwargs.get("batch_size_s", 300))*1000
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
kwargs["batch_size"] = batch_size
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
results_ret_list = []
time_speech_total_all_samples = 0.0
beg_total = time.time()
pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
for i in range(len(res)):
key = res[i]["key"]
vadsegments = res[i]["value"]
input_i = data_list[i]
speech = load_audio_text_image_video(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
speech_lengths = len(speech)
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
# if kwargs["device"] == "cpu":
# batch_size = 0
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
batch_size_ms_cum = 0
beg_idx = 0
beg_asr_total = time.time()
time_speech_total_per_sample = speech_lengths/16000
time_speech_total_all_samples += time_speech_total_per_sample
# if kwargs["device"] == "cpu":
# batch_size = 0
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
batch_size_ms_cum = 0
beg_idx = 0
beg_asr_total = time.time()
time_speech_total_per_sample = speech_lengths/16000
time_speech_total_all_samples += time_speech_total_per_sample
for j, _ in enumerate(range(0, n)):
batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
if j < n - 1 and (
batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
continue
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
beg_idx = end_idx
results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if len(results) < 1:
continue
results_sorted.extend(results)
for j, _ in enumerate(range(0, n)):
batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
if j < n - 1 and (
batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
continue
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
all_segments = []
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
speech_j[_b]]]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
continue
results_sorted.extend(results)
pbar_total.update(1)
end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
restored_data[index] = results_sorted[j]
result = {}
# results combine for texts, timestamps, speaker embeddings and others
# TODO: rewrite for clean code
for j in range(n):
for k, v in restored_data[j].items():
if k.startswith("timestamp"):
if k not in result:
result[k] = []
for t in restored_data[j][k]:
t[0] += vadsegments[j][0]
t[1] += vadsegments[j][0]
result[k].extend(restored_data[j][k])
elif k == 'spk_embedding':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
elif k == 'text':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += " " + restored_data[j][k]
else:
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += restored_data[j][k]
# step.3 compute punc model
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
result["text_with_punc"] = punc_res[0]["text"]
# cluster speaker embeddings after re-sorting the results
if self.spk_model is not None:
all_segments = sorted(all_segments, key=lambda x: x[0])
spk_embedding = result['spk_embedding']
labels = self.cb_model(spk_embedding)
del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding)
if self.spk_mode == 'vad_segment':
sentence_list = []
for res, vadsegment in zip(restored_data, vadsegments):
sentence_list.append({"start": vadsegment[0],\
"end": vadsegment[1],
"sentence": res['text'],
"timestamp": res['timestamp']})
else: # punc_segment
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['text'])
distribute_spk(sentence_list, sv_output)
result['sentence_info'] = sentence_list
result["key"] = key
results_ret_list.append(result)
pbar_total.update(1)
pbar_total.update(1)
end_total = time.time()
time_escape_total_all_samples = end_total - beg_total
pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
return results_ret_list
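# Editorial sketch (hedged): the duration-based batching used above, reduced to
# a stand-alone helper. Segments are (start_ms, end_ms) tuples sorted by length;
# a batch closes once adding the next segment would exceed batch_size_ms, or the
# next segment alone exceeds batch_size_threshold_ms. The helper name is ours.
def _group_segments_sketch(segments, batch_size_ms=300_000, batch_size_threshold_ms=60_000):
    batches, cum, beg = [], 0, 0
    for j, (s, e) in enumerate(segments):
        cum += e - s
        if j < len(segments) - 1:
            nxt = segments[j + 1][1] - segments[j + 1][0]
            if cum + nxt < batch_size_ms and nxt < batch_size_threshold_ms:
                continue
        batches.append(segments[beg:j + 1])
        cum, beg = 0, j + 1
    return batches
# _group_segments_sketch([(0, 2000), (2500, 4000), (5000, 70000)])
# -> [[(0, 2000), (2500, 4000)], [(5000, 70000)]]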
class AutoFrontend:
def __init__(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(**kwargs)
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend.lower())
frontend = frontend_class(**kwargs["frontend_conf"])
self.frontend = frontend
if "frontend" in kwargs:
del kwargs["frontend"]
self.kwargs = kwargs
def __call__(self, input, input_len=None, kwargs=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
key_list, data_list = prepare_data_iterator(input, input_len=input_len)
batch_size = kwargs.get("batch_size", 1)
device = kwargs.get("device", "cpu")
if device == "cpu":
batch_size = 1
meta_data = {}
result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
time0 = time.perf_counter()
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=self.frontend, **kwargs)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
result_list.append(batch)
pbar.update(1)
description = (
f"{meta_data}, "
)
pbar.set_description(description)
time_end = time.perf_counter()
pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
return result_list
if __name__ == '__main__':
main_hydra()


@ -0,0 +1,191 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union
import hdbscan
import numpy as np
import scipy
import sklearn
import umap
from sklearn.cluster._kmeans import k_means
from torch import nn
class SpectralCluster:
r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix.
This implementation is adapted from https://github.com/speechbrain/speechbrain.
"""
def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
self.min_num_spks = min_num_spks
self.max_num_spks = max_num_spks
self.pval = pval
def __call__(self, X, oracle_num=None):
# Similarity matrix computation
sim_mat = self.get_sim_mat(X)
# Refining similarity matrix with pval
pruned_sim_mat = self.p_pruning(sim_mat)
# Symmetrization
sym_pruned_sim_mat = 0.5 * (pruned_sim_mat + pruned_sim_mat.T)
# Laplacian calculation
laplacian = self.get_laplacian(sym_pruned_sim_mat)
# Get Spectral Embeddings
emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
# Perform clustering
labels = self.cluster_embs(emb, num_of_spk)
return labels
def get_sim_mat(self, X):
# Cosine similarities
M = sklearn.metrics.pairwise.cosine_similarity(X, X)
return M
def p_pruning(self, A):
if A.shape[0] * self.pval < 6:
pval = 6. / A.shape[0]
else:
pval = self.pval
n_elems = int((1 - pval) * A.shape[0])
# For each row in the affinity matrix
for i in range(A.shape[0]):
low_indexes = np.argsort(A[i, :])
low_indexes = low_indexes[0:n_elems]
# Replace smaller similarity values by 0s
A[i, low_indexes] = 0
return A
def get_laplacian(self, M):
M[np.diag_indices(M.shape[0])] = 0
D = np.sum(np.abs(M), axis=1)
D = np.diag(D)
L = D - M
return L
def get_spec_embs(self, L, k_oracle=None):
lambdas, eig_vecs = scipy.linalg.eigh(L)
if k_oracle is not None:
num_of_spk = k_oracle
else:
lambda_gap_list = self.getEigenGaps(
lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
emb = eig_vecs[:, :num_of_spk]
return emb, num_of_spk
def cluster_embs(self, emb, k):
_, labels, _ = k_means(emb, k)
return labels
def getEigenGaps(self, eig_vals):
eig_vals_gap_list = []
for i in range(len(eig_vals) - 1):
gap = float(eig_vals[i + 1]) - float(eig_vals[i])
eig_vals_gap_list.append(gap)
return eig_vals_gap_list
class UmapHdbscan:
r"""
Reference:
- Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
Emphasis On Topological Structure. ICASSP2022
"""
def __init__(self,
n_neighbors=20,
n_components=60,
min_samples=10,
min_cluster_size=10,
metric='cosine'):
self.n_neighbors = n_neighbors
self.n_components = n_components
self.min_samples = min_samples
self.min_cluster_size = min_cluster_size
self.metric = metric
def __call__(self, X):
umap_X = umap.UMAP(
n_neighbors=self.n_neighbors,
min_dist=0.0,
n_components=min(self.n_components, X.shape[0] - 2),
metric=self.metric,
).fit_transform(X)
labels = hdbscan.HDBSCAN(
min_samples=self.min_samples,
min_cluster_size=self.min_cluster_size,
allow_single_cluster=True).fit_predict(umap_X)
return labels
class ClusterBackend(nn.Module):
r"""Perfom clustering for input embeddings and output the labels.
Args:
model_dir: A model dir.
model_config: The model config.
"""
def __init__(self):
super().__init__()
self.model_config = {'merge_thr':0.78}
# self.other_config = kwargs
self.spectral_cluster = SpectralCluster()
self.umap_hdbscan_cluster = UmapHdbscan()
def forward(self, X, **params):
# clustering and return the labels
k = params['oracle_num'] if 'oracle_num' in params else None
assert len(
X.shape
) == 2, 'modelscope error: the shape of input should be [N, C]'
if X.shape[0] < 20:
return np.zeros(X.shape[0], dtype='int')
if X.shape[0] < 2048 or k is not None:
labels = self.spectral_cluster(X, k)
else:
labels = self.umap_hdbscan_cluster(X)
if k is None and 'merge_thr' in self.model_config:
labels = self.merge_by_cos(labels, X,
self.model_config['merge_thr'])
return labels
def merge_by_cos(self, labels, embs, cos_thr):
# merge the similar speakers by cosine similarity
assert cos_thr > 0 and cos_thr <= 1
while True:
spk_num = labels.max() + 1
if spk_num == 1:
break
spk_center = []
for i in range(spk_num):
spk_emb = embs[labels == i].mean(0)
spk_center.append(spk_emb)
assert len(spk_center) > 0
spk_center = np.stack(spk_center, axis=0)
norm_spk_center = spk_center / np.linalg.norm(
spk_center, axis=1, keepdims=True)
affinity = np.matmul(norm_spk_center, norm_spk_center.T)
affinity = np.triu(affinity, 1)
spks = np.unravel_index(np.argmax(affinity), affinity.shape)
if affinity[spks] < cos_thr:
break
for i in range(len(labels)):
if labels[i] == spks[1]:
labels[i] = spks[0]
elif labels[i] > spks[1]:
labels[i] -= 1
return labels
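# Editorial usage sketch (hedged): any [N, C] float matrix works as input; the
# 192-dimensional embedding size below is an assumption, not read from a config.
if __name__ == '__main__':
    cb = ClusterBackend()
    embeddings = np.random.randn(40, 192).astype('float32')  # 40 chunk embeddings
    print(cb(embeddings))                # spectral path (N < 2048), speaker count estimated
    print(cb(embeddings, oracle_num=2))  # fix the speaker count when it is known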


@ -109,13 +109,9 @@ class CAMPPlus(nn.Module):
audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
# import pdb; pdb.set_trace()
results = []
embeddings = self.forward(speech)
for embedding in embeddings:
results.append({"spk_embedding":embedding})
meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
results = [{"spk_embedding": self.forward(speech)}]
return results, meta_data


@ -2,23 +2,19 @@
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import io
import os
import torch
import requests
import tempfile
import contextlib
import numpy as np
import librosa as sf
import torch.nn.functional as F
from torch import nn
from pathlib import Path
from typing import Generator, Union
from abc import ABCMeta, abstractmethod
import torchaudio.compliance.kaldi as Kaldi
from funasr.models.transformer.utils.nets_utils import pad_list
def check_audio_list(audio: list):
@ -40,31 +36,31 @@ def check_audio_list(audio: list):
def sv_preprocess(inputs: Union[np.ndarray, list]):
output = []
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
data = data.squeeze(0)
elif isinstance(inputs[i], np.ndarray):
assert len(
inputs[i].shape
) == 1, 'modelscope error: Input array should be [N, T]'
data = inputs[i]
if data.dtype in ['int16', 'int32', 'int64']:
data = (data / (1 << 15)).astype('float32')
else:
data = data.astype('float32')
data = torch.from_numpy(data)
else:
raise ValueError(
'modelscope error: The input type is restricted to audio address and numpy array.'
)
output.append(data)
return output
def sv_chunk(vad_segments: list, fs = 16000) -> list:
@ -105,15 +101,19 @@ def sv_chunk(vad_segments: list, fs = 16000) -> list:
def extract_feature(audio):
features = []
feature_times = []
feature_lengths = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=80)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature)
feature_times.append(au.shape[0])
feature_lengths.append(feature.shape[0])
# padding for batch inference
features_padded = pad_list(features, pad_value=0)
return features_padded, feature_lengths, feature_times
def postprocess(segments: list, vad_segments: list,
@ -195,8 +195,8 @@ def smooth(res, mindur=1):
def distribute_spk(sentence_list, sd_time_list):
sd_sentence_list = []
for d in sentence_list:
sentence_start = d['start']
sentence_end = d['end']
sentence_spk = 0
max_overlap = 0
for sd_time in sd_time_list:
@ -213,8 +213,6 @@ def distribute_spk(sentence_list, sd_time_list):
return sd_sentence_list
class Storage(metaclass=ABCMeta):
"""Abstract class of storage.


@ -239,6 +239,7 @@ class CTTransformer(nn.Module):
cache_pop_trigger_limit = 200
results = []
meta_data = {}
punc_array = None
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
@ -320,8 +321,13 @@ class CTTransformer(nn.Module):
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
result_i = {"key": key[0], "text": new_mini_sentence_out}
# keep a punctuations array for punc segment
if punc_array is None:
punc_array = punctuations
else:
punc_array = torch.cat([punc_array, punctuations], dim=0)
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i)
return results, meta_data
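# Editorial note: the accumulated punc_array is what AutoModel.generate_with_vad
# later passes into timestamp_sentence() as punc_id_list (via
# punc_res[0]['punc_array']) when spk_mode == 'punc_segment'; see the
# auto_model.py hunk above.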


@ -98,14 +98,14 @@ def ts_prediction_lfr6_standard(us_alphas,
return res_txt, res
def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
punc_list = ['，', '。', '？', '、']
res = []
if text_postprocessed is None:
return res
if timestamp_postprocessed is None:
return res
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
@ -113,23 +113,22 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
if punc_id_list is None or len(punc_id_list) == 0:
res.append({
'text': text_postprocessed.split(),
"start": time_stamp_postprocessed[0][0],
"end": time_stamp_postprocessed[-1][1],
'text_seg': text_postprocessed.split(),
"ts_list": time_stamp_postprocessed,
"start": timestamp_postprocessed[0][0],
"end": timestamp_postprocessed[-1][1],
"timestamp": timestamp_postprocessed,
})
return res
if len(punc_id_list) != len(timestamp_postprocessed):
logging.warning("length mismatch between punc and timestamp")
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = timestamp_postprocessed[0][0]
sentence_end = timestamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None))
for punc_stamp_text in punc_stamp_text_list:
punc_id, timestamp, text = punc_stamp_text
# sentence_text += text if text is not None else ''
if text is not None:
if 'a' <= text[0] <= 'z' or 'A' <= text[0] <= 'Z':
@ -139,10 +138,10 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
else:
sentence_text += text
sentence_text_seg += text + ' '
ts_list.append(timestamp)
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = timestamp[1] if timestamp is not None else sentence_end
if punc_id > 1:
sentence_text += punc_list[punc_id - 2]
@ -150,8 +149,7 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
'text': sentence_text,
"start": sentence_start,
"end": sentence_end,
"text_seg": sentence_text_seg,
"ts_list": ts_list
"timestamp": ts_list
})
sentence_text = ''
sentence_text_seg = ''
@ -160,181 +158,4 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
return res
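# Editorial note (hedged): each sentence entry now carries a single 'timestamp'
# field; the old 'text_seg' field is dropped and 'ts_list' is renamed. An
# illustrative, hand-written entry:
#   {"text": "你好，", "start": 380, "end": 1040,
#    "timestamp": [[380, 540], [540, 780], [780, 1040]]}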