update asr with speaker

shixian.shi 2024-01-11 17:03:00 +08:00
parent 78ffd04ac9
commit 7037971392
7 changed files with 747 additions and 671 deletions

View File

@@ -6,12 +6,28 @@
from funasr import AutoModel

model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.0",
                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.0",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.0",
                  spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)

'''try asr with speaker label with
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.0",
                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.0",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.0",
                  spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                  spk_mode='punc_segment',
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
'''
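
For reference, a minimal sketch of reading the speaker-labelled result of the second call; "text_with_punc" and "sentence_info" come from this commit's auto_model changes, while the "spk" field on each sentence entry is an assumption not visible in this diff:

# sketch only: the "spk" key is assumed, not shown in this commit
for result in res:
    print(result["key"], result.get("text_with_punc", result.get("text", "")))
    for sent in result.get("sentence_info", []):
        print(sent.get("spk"), sent.get("start"), sent.get("end"), sent.get("text"))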

View File

@@ -1,453 +1,501 @@
import json
import time
import torch
import hydra
import random
import string
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig

from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend

def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
    :param input:
    :param input_len:
    :param data_type:
    :param frontend:
    :return:
    """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl"]

    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith('http'): # url
        data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
        if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
            with open(data_in, encoding='utf-8') as fin:
                for line in fin:
                    key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
                    if data_in.endswith(".jsonl"): # file.jsonl: json.dumps({"source": data})
                        lines = json.loads(line.strip())
                        data = lines["source"]
                        key = data["key"] if "key" in data else key
                    else: # filelist, wav.scp, text.txt: id \t data or data
                        lines = line.strip().split(maxsplit=1)
                        data = lines[1] if len(lines) > 1 else lines[0]
                        key = lines[0] if len(lines) > 1 else key
                    data_list.append(data)
                    key_list.append(key)
        else:
            key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)): # multiple inputs
            data_list_tmp = []
            for data_in_i, data_type_i in zip(data_in, data_type):
                key_list, data_list_i = prepare_data_iterator(data_in=data_in_i, data_type=data_type_i)
                data_list_tmp.append(data_list_i)
            data_list = []
            for item in zip(*data_list_tmp):
                data_list.append(item)
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
    else: # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes): # audio bytes
            data_in = load_bytes(data_in)
        if key is None:
            key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
        data_list = [data_in]
        key_list = [key]

    return key_list, data_list

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()

    model = AutoModel(**kwargs)
    res = model(input=kwargs["input"])
    print(res)

class AutoModel:

    def __init__(self, **kwargs):
        tables.print()

        model, kwargs = self.build_model(**kwargs)

        # if vad_model is not None, build vad model else None
        vad_model = kwargs.get("vad_model", None)
        vad_kwargs = kwargs.get("vad_model_revision", None)
        if vad_model is not None:
            print("build vad model")
            vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
            vad_model, vad_kwargs = self.build_model(**vad_kwargs)

        # if punc_model is not None, build punc model else None
        punc_model = kwargs.get("punc_model", None)
        punc_kwargs = kwargs.get("punc_model_revision", None)
        if punc_model is not None:
            punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)

        # if spk_model is not None, build spk model else None
        spk_model = kwargs.get("spk_model", None)
        spk_kwargs = kwargs.get("spk_model_revision", None)
        if spk_model is not None:
            spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            self.cb_model = ClusterBackend()
            spk_mode = kwargs.get("spk_mode", 'punc_segment')
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
            self.spk_mode = spk_mode
            logging.warning("Many to print when using speaker model...")

        self.kwargs = kwargs
        self.model = model
        self.vad_model = vad_model
        self.vad_kwargs = vad_kwargs
        self.punc_model = punc_model
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs

    def build_model(self, **kwargs):
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
            device = "cpu"
            # kwargs["batch_size"] = 1
        kwargs["device"] = device

        if kwargs.get("ncpu", None):
            torch.set_num_threads(kwargs.get("ncpu"))

        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
            kwargs["tokenizer"] = tokenizer
            kwargs["token_list"] = tokenizer.token_list
            vocab_size = len(tokenizer.token_list)
        else:
            vocab_size = -1

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend.lower())
            frontend = frontend_class(**kwargs["frontend_conf"])
            kwargs["frontend"] = frontend
            kwargs["input_size"] = frontend.output_size()

        # build model
        model_class = tables.model_classes.get(kwargs["model"].lower())
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
        model.eval()
        model.to(device)

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            logging.info(f"Loading pretrained params from {init_param}")
            load_pretrained_model(
                model=model,
                init_param=init_param,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                oss_bucket=kwargs.get("oss_bucket", None),
            )

        return model, kwargs

    def __call__(self, input, input_len=None, **cfg):
        if self.vad_model is None:
            return self.generate(input, input_len=input_len, **cfg)
        else:
            return self.generate_with_vad(input, input_len=input_len, **cfg)

    def generate(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}
            if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len

            time1 = time.perf_counter()
            with torch.no_grad():
                results, meta_data = model.generate(**batch, **kwargs)
            time2 = time.perf_counter()

            asr_result_list.extend(results)
            pbar.update(1)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
            description = (
                f"{speed_stats}, "
            )
            pbar.set_description(description)
            time_speech_total += batch_data_time
            time_escape_total += time_escape

        pbar.update(1)
        pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return asr_result_list

    def generate_with_vad(self, input, input_len=None, **cfg):
        # step.1: compute the vad model
        model = self.vad_model
        kwargs = self.vad_kwargs
        kwargs.update(cfg)
        beg_vad = time.time()
        res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
        vad_res = res
        end_vad = time.time()
        print(f"time cost vad: {end_vad - beg_vad:0.3f}")

        # step.2 compute asr model
        model = self.model
        kwargs = self.kwargs
        kwargs.update(cfg)
        batch_size = int(kwargs.get("batch_size_s", 300))*1000
        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
        kwargs["batch_size"] = batch_size

        key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
        results_ret_list = []
        time_speech_total_all_samples = 0.0

        beg_total = time.time()
        pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
        for i in range(len(res)):
            key = res[i]["key"]
            vadsegments = res[i]["value"]
            input_i = data_list[i]
            speech = load_audio_text_image_video(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
            speech_lengths = len(speech)
            n = len(vadsegments)
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []

            if not len(sorted_data):
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue

            # if kwargs["device"] == "cpu":
            #     batch_size = 0
            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])

            batch_size_ms_cum = 0
            beg_idx = 0
            beg_asr_total = time.time()
            time_speech_total_per_sample = speech_lengths/16000
            time_speech_total_all_samples += time_speech_total_per_sample

            for j, _ in enumerate(range(0, n)):
                batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (
                    batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
                    sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
                    continue
                batch_size_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
                if self.spk_model is not None:
                    all_segments = []
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
                                         sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
                                         speech_j[_b]]]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
                        results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
                beg_idx = end_idx
                if len(results) < 1:
                    continue
                results_sorted.extend(results)

            pbar_total.update(1)
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                       f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                       f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")

            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                restored_data[index] = results_sorted[j]
            result = {}

            # results combine for texts, timestamps, speaker embeddings and others
            # TODO: rewrite for clean code
            for j in range(n):
                for k, v in restored_data[j].items():
                    if k.startswith("timestamp"):
                        if k not in result:
                            result[k] = []
                        for t in restored_data[j][k]:
                            t[0] += vadsegments[j][0]
                            t[1] += vadsegments[j][0]
                        result[k].extend(restored_data[j][k])
                    elif k == 'spk_embedding':
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
                    elif k == 'text':
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += " " + restored_data[j][k]
                    else:
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += restored_data[j][k]

            # step.3 compute punc model
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                result["text_with_punc"] = punc_res[0]["text"]

            # speaker embedding cluster after resorted
            if self.spk_model is not None:
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result['spk_embedding']
                labels = self.cb_model(spk_embedding)
                del result['spk_embedding']
                sv_output = postprocess(all_segments, None, labels, spk_embedding)
                if self.spk_mode == 'vad_segment':
                    sentence_list = []
                    for res, vadsegment in zip(restored_data, vadsegments):
                        sentence_list.append({"start": vadsegment[0],
                                              "end": vadsegment[1],
                                              "sentence": res['text'],
                                              "timestamp": res['timestamp']})
                else:  # punc_segment
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                                       result['timestamp'],
                                                       result['text'])
                distribute_spk(sentence_list, sv_output)
                result['sentence_info'] = sentence_list

            result["key"] = key
            results_ret_list.append(result)
            pbar_total.update(1)

        pbar_total.update(1)
        end_total = time.time()
        time_escape_total_all_samples = end_total - beg_total
        pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
                                   f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
                                   f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
        return results_ret_list

class AutoFrontend:
    def __init__(self, **kwargs):
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
            kwargs = download_model(**kwargs)

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend.lower())
            frontend = frontend_class(**kwargs["frontend_conf"])
        self.frontend = frontend
        if "frontend" in kwargs:
            del kwargs["frontend"]

        self.kwargs = kwargs

    def __call__(self, input, input_len=None, kwargs=None, **cfg):
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        if device == "cpu":
            batch_size = 1

        meta_data = {}
        result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        time0 = time.perf_counter()
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]

            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=self.frontend, **kwargs)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000

            speech.to(device=device), speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            result_list.append(batch)

            pbar.update(1)
            description = (
                f"{meta_data}, "
            )
            pbar.set_description(description)

        time_end = time.perf_counter()
        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        return result_list


if __name__ == '__main__':
    main_hydra()

View File

@@ -0,0 +1,191 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union
import hdbscan
import numpy as np
import scipy
import sklearn
import umap
from sklearn.cluster._kmeans import k_means
from torch import nn
class SpectralCluster:
r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix.
This implementation is adapted from https://github.com/speechbrain/speechbrain.
"""
def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
self.min_num_spks = min_num_spks
self.max_num_spks = max_num_spks
self.pval = pval
def __call__(self, X, oracle_num=None):
# Similarity matrix computation
sim_mat = self.get_sim_mat(X)
# Refining similarity matrix with pval
prunned_sim_mat = self.p_pruning(sim_mat)
# Symmetrization
sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
# Laplacian calculation
laplacian = self.get_laplacian(sym_prund_sim_mat)
# Get Spectral Embeddings
emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
# Perform clustering
labels = self.cluster_embs(emb, num_of_spk)
return labels
def get_sim_mat(self, X):
# Cosine similarities
M = sklearn.metrics.pairwise.cosine_similarity(X, X)
return M
def p_pruning(self, A):
if A.shape[0] * self.pval < 6:
pval = 6. / A.shape[0]
else:
pval = self.pval
n_elems = int((1 - pval) * A.shape[0])
# For each row in an affinity matrix
for i in range(A.shape[0]):
low_indexes = np.argsort(A[i, :])
low_indexes = low_indexes[0:n_elems]
# Replace smaller similarity values by 0s
A[i, low_indexes] = 0
return A
def get_laplacian(self, M):
M[np.diag_indices(M.shape[0])] = 0
D = np.sum(np.abs(M), axis=1)
D = np.diag(D)
L = D - M
return L
def get_spec_embs(self, L, k_oracle=None):
lambdas, eig_vecs = scipy.linalg.eigh(L)
if k_oracle is not None:
num_of_spk = k_oracle
else:
lambda_gap_list = self.getEigenGaps(
lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
emb = eig_vecs[:, :num_of_spk]
return emb, num_of_spk
def cluster_embs(self, emb, k):
_, labels, _ = k_means(emb, k)
return labels
def getEigenGaps(self, eig_vals):
eig_vals_gap_list = []
for i in range(len(eig_vals) - 1):
gap = float(eig_vals[i + 1]) - float(eig_vals[i])
eig_vals_gap_list.append(gap)
return eig_vals_gap_list
class UmapHdbscan:
r"""
Reference:
- Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
Emphasis On Topological Structure. ICASSP2022
"""
def __init__(self,
n_neighbors=20,
n_components=60,
min_samples=10,
min_cluster_size=10,
metric='cosine'):
self.n_neighbors = n_neighbors
self.n_components = n_components
self.min_samples = min_samples
self.min_cluster_size = min_cluster_size
self.metric = metric
def __call__(self, X):
umap_X = umap.UMAP(
n_neighbors=self.n_neighbors,
min_dist=0.0,
n_components=min(self.n_components, X.shape[0] - 2),
metric=self.metric,
).fit_transform(X)
labels = hdbscan.HDBSCAN(
min_samples=self.min_samples,
min_cluster_size=self.min_cluster_size,
allow_single_cluster=True).fit_predict(umap_X)
return labels
class ClusterBackend(nn.Module):
r"""Perfom clustering for input embeddings and output the labels.
Args:
model_dir: A model dir.
model_config: The model config.
"""
def __init__(self):
super().__init__()
self.model_config = {'merge_thr':0.78}
# self.other_config = kwargs
self.spectral_cluster = SpectralCluster()
self.umap_hdbscan_cluster = UmapHdbscan()
def forward(self, X, **params):
# clustering and return the labels
k = params['oracle_num'] if 'oracle_num' in params else None
assert len(
X.shape
) == 2, 'modelscope error: the shape of input should be [N, C]'
if X.shape[0] < 20:
return np.zeros(X.shape[0], dtype='int')
if X.shape[0] < 2048 or k is not None:
labels = self.spectral_cluster(X, k)
else:
labels = self.umap_hdbscan_cluster(X)
if k is None and 'merge_thr' in self.model_config:
labels = self.merge_by_cos(labels, X,
self.model_config['merge_thr'])
return labels
def merge_by_cos(self, labels, embs, cos_thr):
# merge the similar speakers by cosine similarity
assert cos_thr > 0 and cos_thr <= 1
while True:
spk_num = labels.max() + 1
if spk_num == 1:
break
spk_center = []
for i in range(spk_num):
spk_emb = embs[labels == i].mean(0)
spk_center.append(spk_emb)
assert len(spk_center) > 0
spk_center = np.stack(spk_center, axis=0)
norm_spk_center = spk_center / np.linalg.norm(
spk_center, axis=1, keepdims=True)
affinity = np.matmul(norm_spk_center, norm_spk_center.T)
affinity = np.triu(affinity, 1)
spks = np.unravel_index(np.argmax(affinity), affinity.shape)
if affinity[spks] < cos_thr:
break
for i in range(len(labels)):
if labels[i] == spks[1]:
labels[i] = spks[0]
elif labels[i] > spks[1]:
labels[i] -= 1
return labels
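
A quick, self-contained way to exercise the clustering backend above (a sketch only; the embedding dimension of 192 and the number of segments are assumptions, not taken from this diff):

import numpy as np
from funasr.models.campplus.cluster_backend import ClusterBackend

cb = ClusterBackend()
embeddings = np.random.randn(64, 192).astype('float32')  # [N, C] speaker embeddings, dummy values
labels = cb(embeddings)  # forward(): spectral clustering here (N < 2048), then cosine-based speaker merging
print(np.unique(labels))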

View File

@@ -109,13 +109,9 @@ class CAMPPlus(nn.Module):
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0

        results = [{"spk_embedding": self.forward(speech)}]
        return results, meta_data

View File

@@ -2,23 +2,19 @@
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import io
import os
import torch
import requests
import tempfile
import contextlib
import numpy as np
import librosa as sf
from typing import Union
from pathlib import Path
from typing import Generator, Union
from abc import ABCMeta, abstractmethod
import torchaudio.compliance.kaldi as Kaldi
from funasr.models.transformer.utils.nets_utils import pad_list

def check_audio_list(audio: list):
@@ -40,31 +36,31 @@ def check_audio_list(audio: list):
def sv_preprocess(inputs: Union[np.ndarray, list]):
    output = []
    for i in range(len(inputs)):
        if isinstance(inputs[i], str):
            file_bytes = File.read(inputs[i])
            data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
            if len(data.shape) == 2:
                data = data[:, 0]
            data = torch.from_numpy(data).unsqueeze(0)
            data = data.squeeze(0)
        elif isinstance(inputs[i], np.ndarray):
            assert len(
                inputs[i].shape
            ) == 1, 'modelscope error: Input array should be [N, T]'
            data = inputs[i]
            if data.dtype in ['int16', 'int32', 'int64']:
                data = (data / (1 << 15)).astype('float32')
            else:
                data = data.astype('float32')
            data = torch.from_numpy(data)
        else:
            raise ValueError(
                'modelscope error: The input type is restricted to audio address and numpy array.'
            )
        output.append(data)
    return output

def sv_chunk(vad_segments: list, fs = 16000) -> list:
@@ -105,15 +101,19 @@ def sv_chunk(vad_segments: list, fs = 16000) -> list:
def extract_feature(audio):
    features = []
    feature_times = []
    feature_lengths = []
    for au in audio:
        feature = Kaldi.fbank(
            au.unsqueeze(0), num_mel_bins=80)
        feature = feature - feature.mean(dim=0, keepdim=True)
        features.append(feature)
        feature_times.append(au.shape[0])
        feature_lengths.append(feature.shape[0])
    # padding for batch inference
    features_padded = pad_list(features, pad_value=0)
    # features = torch.cat(features)
    return features_padded, feature_lengths, feature_times
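
The two helpers above chain together roughly as follows (a sketch with a dummy waveform; the 5-second duration and tensor shapes are assumptions, mirroring how auto_model calls them):

import torch
from funasr.models.campplus.utils import sv_chunk, extract_feature

wav = torch.randn(16000 * 5)  # 5 s of 16 kHz mono audio, dummy samples
segments = sv_chunk([[0.0, 5.0, wav]])  # [[start_sec, end_sec, waveform]] -> speaker-verification chunks
feats, feat_lens, feat_times = extract_feature([seg[2] for seg in segments])
print(feats.shape, feat_lens, feat_times)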
def postprocess(segments: list, vad_segments: list,
@@ -195,8 +195,8 @@ def smooth(res, mindur=1):
def distribute_spk(sentence_list, sd_time_list):
    sd_sentence_list = []
    for d in sentence_list:
        sentence_start = d['start']
        sentence_end = d['end']
        sentence_spk = 0
        max_overlap = 0
        for sd_time in sd_time_list:
@@ -213,8 +213,6 @@ def distribute_spk(sentence_list, sd_time_list):
    return sd_sentence_list


class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

View File

@@ -239,6 +239,7 @@ class CTTransformer(nn.Module):
        cache_pop_trigger_limit = 200
        results = []
        meta_data = {}
        punc_array = None
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -320,8 +321,13 @@ class CTTransformer(nn.Module):
            elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                new_mini_sentence_out = new_mini_sentence + "."
                new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
            # keep a punctuations array for punc segment
            if punc_array is None:
                punc_array = punctuations
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        results.append(result_i)
        return results, meta_data

View File

@@ -98,14 +98,14 @@ def ts_prediction_lfr6_standard(us_alphas,
    return res_txt, res


def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
    punc_list = ['，', '。', '？', '、']
    res = []
    if text_postprocessed is None:
        return res
    if timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0:
        return res
    if len(text_postprocessed) == 0:
        return res
@@ -113,23 +113,22 @@ time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
    if punc_id_list is None or len(punc_id_list) == 0:
        res.append({
            'text': text_postprocessed.split(),
            "start": timestamp_postprocessed[0][0],
            "end": timestamp_postprocessed[-1][1],
            "timestamp": timestamp_postprocessed,
        })
        return res
    if len(punc_id_list) != len(timestamp_postprocessed):
        logging.warning("length mismatch between punc and timestamp")
    sentence_text = ""
    sentence_text_seg = ""
    ts_list = []
    sentence_start = timestamp_postprocessed[0][0]
    sentence_end = timestamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    punc_stamp_text_list = list(zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None))
    for punc_stamp_text in punc_stamp_text_list:
        punc_id, timestamp, text = punc_stamp_text
        # sentence_text += text if text is not None else ''
        if text is not None:
            if 'a' <= text[0] <= 'z' or 'A' <= text[0] <= 'Z':
@@ -139,10 +138,10 @@ time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
            else:
                sentence_text += text
            sentence_text_seg += text + ' '
        ts_list.append(timestamp)

        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = timestamp[1] if timestamp is not None else sentence_end
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
@@ -150,8 +149,7 @@ time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "timestamp": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
@@ -160,181 +158,4 @@ time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
    return res
# class AverageShiftCalculator():
# def __init__(self):
# logging.warning("Calculating average shift.")
# def __call__(self, file1, file2):
# uttid_list1, ts_dict1 = self.read_timestamps(file1)
# uttid_list2, ts_dict2 = self.read_timestamps(file2)
# uttid_intersection = self._intersection(uttid_list1, uttid_list2)
# res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
# logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
# logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
#
# def _intersection(self, list1, list2):
# set1 = set(list1)
# set2 = set(list2)
# if set1 == set2:
# logging.warning("Uttid same checked.")
# return set1
# itsc = list(set1 & set2)
# logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
# return itsc
#
# def read_timestamps(self, file):
# # read timestamps file in standard format
# uttid_list = []
# ts_dict = {}
# with codecs.open(file, 'r') as fin:
# for line in fin.readlines():
# text = ''
# ts_list = []
# line = line.rstrip()
# uttid = line.split()[0]
# uttid_list.append(uttid)
# body = " ".join(line.split()[1:])
# for pd in body.split(';'):
# if not len(pd): continue
# # pdb.set_trace()
# char, start, end = pd.lstrip(" ").split(' ')
# text += char + ','
# ts_list.append((float(start), float(end)))
# # ts_lists.append(ts_list)
# ts_dict[uttid] = (text[:-1], ts_list)
# logging.warning("File {} read done.".format(file))
# return uttid_list, ts_dict
#
# def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
# shift_time = 0
# for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
# shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
# num_tokens = len(filtered_timestamp_list1)
# return shift_time, num_tokens
#
# # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
# # # calculate average shift between timestamp1 and timestamp2
# # # when characters differ, use edit distance alignment
# # # and calculate the error between the same characters
# # self._accumlated_shift = 0
# # self._accumlated_tokens = 0
# # self.max_shift = 0
# # self.max_shift_uttid = None
# # for uttid in uttid_list:
# # (t1, ts1) = ts_dict1[uttid]
# # (t2, ts2) = ts_dict2[uttid]
# # _align, _align2, _align3 = [], [], []
# # fts1, fts2 = [], []
# # _t1, _t2 = [], []
# # sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
# # s = sm.get_opcodes()
# # for j in range(len(s)):
# # if s[j][0] == "replace" or s[j][0] == "insert":
# # _align.append(0)
# # if s[j][0] == "replace" or s[j][0] == "delete":
# # _align3.append(0)
# # elif s[j][0] == "equal":
# # _align.append(1)
# # _align3.append(1)
# # else:
# # continue
# # # use s to index t2
# # for a, ts , t in zip(_align, ts2, t2.split(',')):
# # if a:
# # fts2.append(ts)
# # _t2.append(t)
# # sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
# # s = sm2.get_opcodes()
# # for j in range(len(s)):
# # if s[j][0] == "replace" or s[j][0] == "insert":
# # _align2.append(0)
# # elif s[j][0] == "equal":
# # _align2.append(1)
# # else:
# # continue
# # # use s2 tp index t1
# # for a, ts, t in zip(_align3, ts1, t1.split(',')):
# # if a:
# # fts1.append(ts)
# # _t1.append(t)
# # if len(fts1) == len(fts2):
# # shift_time, num_tokens = self._shift(fts1, fts2)
# # self._accumlated_shift += shift_time
# # self._accumlated_tokens += num_tokens
# # if shift_time/num_tokens > self.max_shift:
# # self.max_shift = shift_time/num_tokens
# # self.max_shift_uttid = uttid
# # else:
# # logging.warning("length mismatch")
# # return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
from funasr.models.paraformer.cif_predictor import cif_wo_hidden
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
for line1, line2 in zip(f1.readlines(), f2.readlines()):
line1 = line1.rstrip()
line2 = line2.rstrip()
assert line1.split()[0] == line2.split()[0]
uttid = line1.split()[0]
alphas = [float(i) for i in line1.split()[1:]]
new_alphas = np.array(remove_chunk_padding(alphas))
new_alphas[-1] += 1e-4
text = line2.split()[1:]
if len(text) + 1 != int(new_alphas.sum()):
# force resize
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
if " " in text:
text = text.split()
else:
text = [i for i in text]
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
force_time_shift=-7.0,
sil_in_str=False)
f3.write("{} {}\n".format(uttid, res_str))
def remove_chunk_padding(alphas):
# remove the padding part in alphas if using chunk paraformer for GPU
START_ZERO = 45
MID_ZERO = 75
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
alphas = alphas[START_ZERO:] # remove the padding at beginning
new_alphas = []
while True:
new_alphas = new_alphas + alphas[:REAL_FRAMES]
alphas = alphas[REAL_FRAMES+MID_ZERO:]
if len(alphas) < REAL_FRAMES: break
return new_alphas
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
# if args.mode == 'cal_aas':
# asc = AverageShiftCalculator()
# asc(args.input, args.input2)
if args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='timestamp tools')
parser.add_argument('--mode',
default=None,
type=str,
choices=SUPPORTED_MODES,
help='timestamp related toolbox')
parser.add_argument('--input', default=None, type=str, help='input file path')
parser.add_argument('--output', default=None, type=str, help='output file name')
parser.add_argument('--input2', default=None, type=str, help='input2 file path')
parser.add_argument('--kaldi-ts-type',
default='v2',
type=str,
choices=['v0', 'v1', 'v2'],
help='kaldi timestamp to write')
args = parser.parse_args()
main(args)