diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py old mode 100644 new mode 100755 diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh old mode 100644 new mode 100755 index 8fc66f34f..1bd4f7f5b --- a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh +++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh @@ -2,7 +2,7 @@ model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" model_revision="v2.0.4" -python funasr/bin/inference.py \ +python ../../../funasr/bin/inference.py \ +model=${model} \ +model_revision=${model_revision} \ +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \ diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh new file mode 100755 index 000000000..282f4f1f2 --- /dev/null +++ b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh @@ -0,0 +1,9 @@ +python -m funasr.bin.inference \ +--config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \ +--config-name="config.yaml" \ +++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \ +++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \ +++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \ +++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \ +++output_dir="./outputs/debug2" \ +++device="" \ diff --git a/examples/industrial_data_pretraining/contextual_paraformer/path.sh b/examples/industrial_data_pretraining/contextual_paraformer/path.sh new file mode 100755 index 000000000..1a6d67e08 --- /dev/null +++ b/examples/industrial_data_pretraining/contextual_paraformer/path.sh @@ -0,0 +1,6 @@ +export FUNASR_DIR=$PWD/../../../ + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:$PATH +export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py new file mode 100755 index 000000000..e72d87155 --- /dev/null +++ b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py @@ -0,0 +1,702 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +from enum import Enum +import re, sys, unicodedata +import codecs +import argparse +from tqdm import tqdm +import os +import pdb +remove_tag = False +spacelist = [" ", "\t", "\r", "\n"] +puncts = [ + "!", + ",", + "?", + "、", + "。", + "!", + ",", + ";", + "?", + ":", + "「", + "」", + "︰", + "『", + "』", + "《", + "》", +] + + +class Code(Enum): + match = 1 + substitution = 2 + insertion = 3 + deletion = 4 + + +class WordError(object): + def __init__(self): + self.errors = { + Code.substitution: 0, + 
Code.insertion: 0, + Code.deletion: 0, + } + self.ref_words = 0 + + def get_wer(self): + assert self.ref_words != 0 + errors = ( + self.errors[Code.substitution] + + self.errors[Code.insertion] + + self.errors[Code.deletion] + ) + return 100.0 * errors / self.ref_words + + def get_result_string(self): + return ( + f"error_rate={self.get_wer():.4f}, " + f"ref_words={self.ref_words}, " + f"subs={self.errors[Code.substitution]}, " + f"ins={self.errors[Code.insertion]}, " + f"dels={self.errors[Code.deletion]}" + ) + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + # https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == "Lo": # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = " " + if char == "<": + sep = ">" + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == ">": + j += 1 + res.append(string[i:j]) + i = j + return res + + +def stripoff_tags(x): + if not x: + return "" + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == "<": + while i < T and x[i] != ">": + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return "".join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """sentence, ignore_words are both in unicode""" + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + else: + new_sentence.append(x) + return new_sentence + + +class Calculator: + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost["cor"] = 0 + self.cost["sub"] = 1 + self.cost["del"] = 1 + self.cost["ins"] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, "") + rec.insert(0, "") + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element["dist"] = 0 + element["error"] = "non" + while len(row) < len(rec): + row.append({"dist": 0, "error": "non"}) + for i in range(len(lab)): + self.space[i][0]["dist"] = i + self.space[i][0]["error"] = "del" + for j in range(len(rec)): + self.space[0][j]["dist"] = j + self.space[0][j]["error"] = "ins" + self.space[0][0]["error"] = "non" + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = "none" + dist = self.space[i - 1][j]["dist"] + self.cost["del"] + error = "del" + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]["dist"] + self.cost["ins"] + error = "ins" + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token.replace("", ""): + dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"] + error = "cor" + else: + dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"] + 
error = "sub" + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]["dist"] = min_dist + self.space[i][j]["error"] = min_error + # Tracing back + result = { + "lab": [], + "rec": [], + "code": [], + "all": 0, + "cor": 0, + "sub": 0, + "ins": 0, + "del": 0, + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]["error"] == "cor": # correct + if len(lab[i]) > 0: + self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 + self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1 + result["all"] = result["all"] + 1 + result["cor"] = result["cor"] + 1 + result["lab"].insert(0, lab[i]) + result["rec"].insert(0, rec[j]) + result["code"].insert(0, Code.match) + i = i - 1 + j = j - 1 + elif self.space[i][j]["error"] == "sub": # substitution + if len(lab[i]) > 0: + self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 + self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1 + result["all"] = result["all"] + 1 + result["sub"] = result["sub"] + 1 + result["lab"].insert(0, lab[i]) + result["rec"].insert(0, rec[j]) + result["code"].insert(0, Code.substitution) + i = i - 1 + j = j - 1 + elif self.space[i][j]["error"] == "del": # deletion + if len(lab[i]) > 0: + self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 + self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1 + result["all"] = result["all"] + 1 + result["del"] = result["del"] + 1 + result["lab"].insert(0, lab[i]) + result["rec"].insert(0, "") + result["code"].insert(0, Code.deletion) + i = i - 1 + elif self.space[i][j]["error"] == "ins": # insertion + if len(rec[j]) > 0: + self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1 + result["ins"] = result["ins"] + 1 + result["lab"].insert(0, "") + result["rec"].insert(0, rec[j]) + result["code"].insert(0, Code.insertion) + j = j - 1 + elif self.space[i][j]["error"] == "non": # starting point + break + else: # shouldn't reach here + print( + "this should not happen , i = {i} , j = {j} , error = {error}".format( + i=i, j=j, error=self.space[i][j]["error"] + ) + ) + return result + + def overall(self): + result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} + for token in self.data: + result["all"] = result["all"] + self.data[token]["all"] + result["cor"] = result["cor"] + self.data[token]["cor"] + result["sub"] = result["sub"] + self.data[token]["sub"] + result["ins"] = result["ins"] + self.data[token]["ins"] + result["del"] = result["del"] + self.data[token]["del"] + return result + + def cluster(self, data): + result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} + for token in data: + if token in self.data: + result["all"] = result["all"] + self.data[token]["all"] + result["cor"] = result["cor"] + self.data[token]["cor"] + result["sub"] = result["sub"] + self.data[token]["sub"] + result["ins"] = result["ins"] + self.data[token]["ins"] + result["del"] = result["del"] + self.data[token]["del"] + return result + + def keys(self): + return list(self.data.keys()) + + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + + +def default_cluster(word): + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith("DIGIT"): # 1 + unicode_names[i] = "Number" # 'DIGIT' + elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[ + i + ].startswith("CJK COMPATIBILITY IDEOGRAPH"): + # 明 / 郎 + unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH' + elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or 
unicode_names[ + i + ].startswith("LATIN SMALL LETTER"): + # A / a + unicode_names[i] = "English" # 'LATIN LETTER' + elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め + unicode_names[i] = "Japanese" # 'GANA LETTER' + elif ( + unicode_names[i].startswith("AMPERSAND") + or unicode_names[i].startswith("APOSTROPHE") + or unicode_names[i].startswith("COMMERCIAL AT") + or unicode_names[i].startswith("DEGREE CELSIUS") + or unicode_names[i].startswith("EQUALS SIGN") + or unicode_names[i].startswith("FULL STOP") + or unicode_names[i].startswith("HYPHEN-MINUS") + or unicode_names[i].startswith("LOW LINE") + or unicode_names[i].startswith("NUMBER SIGN") + or unicode_names[i].startswith("PLUS SIGN") + or unicode_names[i].startswith("SEMICOLON") + ): + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else: + return "Other" + if len(unicode_names) == 0: + return "Other" + if len(unicode_names) == 1: + return unicode_names[0] + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: + return "Other" + return unicode_names[0] + + +def get_args(): + parser = argparse.ArgumentParser(description="wer cal") + parser.add_argument("--ref", type=str, help="Text input path") + parser.add_argument("--ref_ocr", type=str, help="Text input path") + parser.add_argument("--rec_name", type=str, action="append", default=[]) + parser.add_argument("--rec_file", type=str, action="append", default=[]) + parser.add_argument("--verbose", type=int, default=1, help="show") + parser.add_argument("--char", type=bool, default=True, help="show") + args = parser.parse_args() + return args + + +def main(args): + cluster_file = "" + ignore_words = set() + tochar = args.char + verbose = args.verbose + padding_symbol = " " + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + ref_file = args.ref + ref_ocr = args.ref_ocr + rec_files = args.rec_file + rec_names = args.rec_name + assert len(rec_files) == len(rec_names) + + # load ocr + ref_ocr_dict = {} + with codecs.open(ref_ocr, "r", "utf-8") as fh: + for line in fh: + if "$" in line: + line = line.replace("$", " ") + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + rec_sets = {} + calculators_dict = dict() + ub_wer_dict = dict() + hotwords_related_dict = dict() # 记录recall相关的内容 + for i, hyp_file in enumerate(rec_files): + rec_sets[rec_names[i]] = dict() + with codecs.open(hyp_file, "r", "utf-8") as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + calculators_dict[rec_names[i]] = Calculator() + ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()} + hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} + # tp: 热词在label里,同时在rec里 + # tn: 热词不在label里,同时不在rec里 + # fp: 热词不在label里,但是在rec里 + # fn: 热词在label里,但是不在rec里 + + # record wrong label but in ocr + 
wrong_rec_but_in_ocr_dict = {} + for rec_name in rec_names: + wrong_rec_but_in_ocr_dict[rec_name] = 0 + + _file_total_len = 0 + with os.popen("cat {} | wc -l".format(ref_file)) as pipe: + _file_total_len = int(pipe.read().strip()) + + # compute error rate on the interaction of reference file and hyp file + for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len): + if tochar: + array = characterize(line) + else: + array = line.rstrip('\n').split() + if len(array) == 0: continue + fid = array[0] + lab = normalize(array[1:], ignore_words, case_sensitive, split) + + if verbose: + print('\nutt: %s' % fid) + + ocr_text = ref_ocr_dict[fid] + ocr_set = set(ocr_text) + print('ocr: {}'.format(" ".join(ocr_text))) + list_match = [] # 指label里面在ocr里面的内容 + list_not_mathch = [] + tmp_error = 0 + tmp_match = 0 + for index in range(len(lab)): + # text_list.append(uttlist[index+1]) + if lab[index] not in ocr_set: + tmp_error += 1 + list_not_mathch.append(lab[index]) + else: + tmp_match += 1 + list_match.append(lab[index]) + print('label in ocr: {}'.format(" ".join(list_match))) + + # for each reco file + base_wrong_ocr_wer = None + ocr_wrong_ocr_wer = None + + for rec_name in rec_names: + rec_set = rec_sets[rec_name] + if fid not in rec_set: + continue + rec = rec_set[fid] + + # print(rec) + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy()) + if verbose: + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + + + # print(result['rec']) + wrong_rec_but_in_ocr = [] + for idx in range(len(result['lab'])): + if result['lab'][idx] != "": + if result['lab'][idx] != result['rec'][idx].replace("", ""): + if result['lab'][idx] in list_match: + wrong_rec_but_in_ocr.append(result['lab'][idx]) + wrong_rec_but_in_ocr_dict[rec_name] += 1 + print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr))) + + if rec_name == "base": + base_wrong_ocr_wer = len(wrong_rec_but_in_ocr) + if "ocr" in rec_name or "hot" in rec_name: + ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr) + if ocr_wrong_ocr_wer < base_wrong_ocr_wer: + print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) + elif ocr_wrong_ocr_wer > base_wrong_ocr_wer: + print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) + + # recall = 0 + # false_alarm = 0 + # for idx in range(len(result['lab'])): + # if "" in result['rec'][idx]: + # if result['rec'][idx].replace("", "") in list_match: + # recall += 1 + # else: + # false_alarm += 1 + # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format( + # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0 + # )) + # tp: 热词在label里,同时在rec里 + # tn: 热词不在label里,同时不在rec里 + # fp: 热词不在label里,但是在rec里 + # fn: 热词在label里,但是不在rec里 + _rec_list = [word.replace("", "") for word in rec] + 
_label_list = [word for word in lab] + _tp = _tn = _fp = _fn = 0 + hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list] + hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list] + for badhotword in hot_bad_list: + count = len([word for word in _rec_list if word == badhotword]) + # print(f"bad {badhotword} count: {count}") + # for word in _rec_list: + # if badhotword == word: + # count += 1 + if count == 0: + hotwords_related_dict[rec_name]['tn'] += 1 + _tn += 1 + # fp: 0 + else: + hotwords_related_dict[rec_name]['fp'] += count + _fp += count + # tn: 0 + # if badhotword in _rec_list: + # hotwords_related_dict[rec_name]['fp'] += 1 + # else: + # hotwords_related_dict[rec_name]['tn'] += 1 + for hotword in hot_true_list: + true_count = len([word for word in _label_list if hotword == word]) + rec_count = len([word for word in _rec_list if hotword == word]) + # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}") + if rec_count == true_count: + hotwords_related_dict[rec_name]['tp'] += true_count + _tp += true_count + elif rec_count > true_count: + hotwords_related_dict[rec_name]['tp'] += true_count + # fp: 不在label里,但是在rec里 + hotwords_related_dict[rec_name]['fp'] += rec_count - true_count + _tp += true_count + _fp += rec_count - true_count + else: + hotwords_related_dict[rec_name]['tp'] += rec_count + # fn: 热词在label里,但是不在rec里 + hotwords_related_dict[rec_name]['fn'] += true_count - rec_count + _tp += rec_count + _fn += true_count - rec_count + print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( + _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0 + )) + + # if hotword in _rec_list: + # hotwords_related_dict[rec_name]['tp'] += 1 + # else: + # hotwords_related_dict[rec_name]['fn'] += 1 + # 计算uwer, bwer, wer + for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]): + if code == Code.match: + ub_wer_dict[rec_name]["wer"].ref_words += 1 + if lab_word in hot_true_list: + # tmp_ref.append(ref_tokens[ref_idx]) + ub_wer_dict[rec_name]["b_wer"].ref_words += 1 + else: + ub_wer_dict[rec_name]["u_wer"].ref_words += 1 + elif code == Code.substitution: + ub_wer_dict[rec_name]["wer"].ref_words += 1 + ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1 + if lab_word in hot_true_list: + # tmp_ref.append(ref_tokens[ref_idx]) + ub_wer_dict[rec_name]["b_wer"].ref_words += 1 + ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1 + else: + ub_wer_dict[rec_name]["u_wer"].ref_words += 1 + ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1 + elif code == Code.deletion: + ub_wer_dict[rec_name]["wer"].ref_words += 1 + ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1 + if lab_word in hot_true_list: + # tmp_ref.append(ref_tokens[ref_idx]) + ub_wer_dict[rec_name]["b_wer"].ref_words += 1 + ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1 + else: + ub_wer_dict[rec_name]["u_wer"].ref_words += 1 + ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1 + elif code == Code.insertion: + ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1 + if rec_word in hot_true_list: + ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1 + else: + ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1 + + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])): + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + 
space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + # print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + print('\n', end='\n') + # break + if verbose: + print('===========================================================================') + print() + + print(wrong_rec_but_in_ocr_dict) + for rec_name in rec_names: + result = calculators_dict[rec_name].overall() + + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}") + print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}") + print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}") + + print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format( + hotwords_related_dict[rec_name]['tp'], + hotwords_related_dict[rec_name]['tn'], + hotwords_related_dict[rec_name]['fp'], + hotwords_related_dict[rec_name]['fn'], + sum([v for k, v in hotwords_related_dict[rec_name].items()]), + hotwords_related_dict[rec_name]['tp'] / ( + hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] + ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0 + )) + + # tp: 热词在label里,同时在rec里 + # tn: 热词不在label里,同时不在rec里 + # fp: 热词不在label里,但是在rec里 + # fn: 热词在label里,但是不在rec里 + if not verbose: + print() + print() + + +if __name__ == "__main__": + args = get_args() + + # print("") + print(args) + main(args) + diff --git a/examples/industrial_data_pretraining/lcbnet/demo.py b/examples/industrial_data_pretraining/lcbnet/demo.py new file mode 100755 index 000000000..4ca52553f --- /dev/null +++ b/examples/industrial_data_pretraining/lcbnet/demo.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel(model="iic/LCB-NET", + model_revision="v1.0.0") + +res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text")) + +print(res) \ No newline at end of file diff --git a/examples/industrial_data_pretraining/lcbnet/demo.sh b/examples/industrial_data_pretraining/lcbnet/demo.sh new file mode 100755 index 000000000..2f226bc03 --- /dev/null +++ b/examples/industrial_data_pretraining/lcbnet/demo.sh @@ -0,0 +1,72 @@ +file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/" +CUDA_VISIBLE_DEVICES="0,1" +inference_device="cuda" + +if [ ${inference_device} == "cuda" ]; then + nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +else + inference_batch_size=1 + CUDA_VISIBLE_DEVICES="" + for JOB in $(seq ${nj}); do + CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1," + done +fi + +inference_dir="outputs/slidespeech_dev" +_logdir="${inference_dir}/logdir" +echo "inference_dir: ${inference_dir}" + +mkdir -p "${_logdir}" +key_file1=${file_dir}/dev/wav.scp +key_file2=${file_dir}/dev/ocr.txt +split_scps1= +split_scps2= +for JOB in $(seq "${nj}"); do + split_scps1+=" ${_logdir}/wav.${JOB}.scp" + split_scps2+=" ${_logdir}/ocr.${JOB}.txt" +done +utils/split_scp.pl "${key_file1}" ${split_scps1} +utils/split_scp.pl "${key_file2}" ${split_scps2} + +gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) +for JOB in $(seq ${nj}); do + { + id=$((JOB-1)) + gpuid=${gpuid_list_array[$id]} + + export CUDA_VISIBLE_DEVICES=${gpuid} + + python -m funasr.bin.inference \ + --config-path=${file_dir} \ + --config-name="config.yaml" \ + ++init_param=${file_dir}/model.pt \ + ++tokenizer_conf.token_list=${file_dir}/tokens.txt \ + ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \ + +data_type='["kaldi_ark", "text"]' \ + ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \ + ++normalize_conf.stats_file=${file_dir}/am.mvn \ + ++output_dir="${inference_dir}/${JOB}" \ + ++device="${inference_device}" \ + ++ncpu=1 \ + ++disable_log=true &> ${_logdir}/log.${JOB}.txt + + }& +done +wait + + +mkdir -p ${inference_dir}/1best_recog + +for JOB in $(seq "${nj}"); do + cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token" +done + +echo "Computing WER ..." 
+sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc +cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref +cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list +python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer +tail -n 3 ${inference_dir}/1best_recog/token.cer + +./run_bwer_recall.sh ${inference_dir}/1best_recog/ +tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5 diff --git a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh new file mode 100755 index 000000000..7d6b6ff8b --- /dev/null +++ b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh @@ -0,0 +1,11 @@ +#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave +#hotword_type=ocr_1ngram_top10_hotwords_list +hot_exp_suf=$1 + + +python compute_wer_details.py --v 1 \ + --ref ${hot_exp_suf}/token.ref \ + --ref_ocr ${hot_exp_suf}/ocr.list \ + --rec_name base \ + --rec_file ${hot_exp_suf}/token.proc \ + > ${hot_exp_suf}/BWER-UWER.results diff --git a/examples/industrial_data_pretraining/lcbnet/utils b/examples/industrial_data_pretraining/lcbnet/utils new file mode 120000 index 000000000..be5e5a322 --- /dev/null +++ b/examples/industrial_data_pretraining/lcbnet/utils @@ -0,0 +1 @@ +../../aishell/paraformer/utils \ No newline at end of file diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index a44c649ae..551dd8bf8 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -7,10 +7,10 @@ from funasr import AutoModel model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4", - vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_model_revision="v2.0.4", - punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", - punc_model_revision="v2.0.4", + # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + # vad_model_revision="v2.0.4", + # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", + # punc_model_revision="v2.0.4", # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", # spk_model_revision="v2.0.2", ) @@ -43,4 +43,4 @@ import soundfile wav_file = os.path.join(model.model_path, "example/asr_example.wav") speech, sample_rate = soundfile.read(wav_file) res = model.generate(input=[speech], batch_size_s=300, is_final=True) -''' \ No newline at end of file +''' diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 921ede809..ec3c3f370 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -28,7 +28,7 @@ try: from funasr.models.campplus.cluster_backend import ClusterBackend except: print("If you want to use the speaker diarization, please `pip install hdbscan`") - +import pdb def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): """ @@ -46,6 +46,7 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): chars = string.ascii_letters + string.digits if isinstance(data_in, str) and data_in.startswith('http'): # url data_in = download_from_url(data_in) + if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: 
wav.scp, file.jsonl;text.txt; _, file_extension = os.path.splitext(data_in) file_extension = file_extension.lower() @@ -146,7 +147,7 @@ class AutoModel: kwargs = download_model(**kwargs) set_all_random_seed(kwargs.get("seed", 0)) - + device = kwargs.get("device", "cuda") if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: device = "cpu" @@ -168,7 +169,6 @@ class AutoModel: vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 else: vocab_size = -1 - # build frontend frontend = kwargs.get("frontend", None) kwargs["input_size"] = None @@ -181,7 +181,6 @@ class AutoModel: # build model model_class = tables.model_classes.get(kwargs["model"]) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) - model.to(device) # init_param @@ -224,9 +223,9 @@ class AutoModel: batch_size = kwargs.get("batch_size", 1) # if kwargs.get("device", "cpu") == "cpu": # batch_size = 1 - + key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key) - + speed_stats = {} asr_result_list = [] num_samples = len(data_list) @@ -239,6 +238,7 @@ class AutoModel: data_batch = data_list[beg_idx:end_idx] key_batch = key_list[beg_idx:end_idx] batch = {"data_in": data_batch, "key": key_batch} + if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank batch["data_in"] = data_batch[0] batch["data_lengths"] = input_len diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py index 8ac1ca853..c4bdbd774 100644 --- a/funasr/frontends/default.py +++ b/funasr/frontends/default.py @@ -3,7 +3,6 @@ from typing import Optional from typing import Tuple from typing import Union import logging -import humanfriendly import numpy as np import torch import torch.nn as nn @@ -16,8 +15,10 @@ from funasr.frontends.utils.log_mel import LogMel from funasr.frontends.utils.stft import Stft from funasr.frontends.utils.frontend import Frontend from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.register import tables +@tables.register("frontend_classes", "DefaultFrontend") class DefaultFrontend(nn.Module): """Conventional frontend structure for ASR. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN @@ -25,7 +26,7 @@ class DefaultFrontend(nn.Module): def __init__( self, - fs: Union[int, str] = 16000, + fs: int = 16000, n_fft: int = 512, win_length: int = None, hop_length: int = 128, @@ -40,14 +41,14 @@ class DefaultFrontend(nn.Module): frontend_conf: Optional[dict] = None, apply_stft: bool = True, use_channel: int = None, + **kwargs, ): super().__init__() - if isinstance(fs, str): - fs = humanfriendly.parse_size(fs) # Deepcopy (In general, dict shouldn't be used as default arg) frontend_conf = copy.deepcopy(frontend_conf) self.hop_length = hop_length + self.fs = fs if apply_stft: self.stft = Stft( @@ -84,8 +85,12 @@ class DefaultFrontend(nn.Module): return self.n_mels def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor + self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list] ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(input_lengths, list): + input_lengths = torch.tensor(input_lengths) + if input.dtype == torch.float64: + input = input.float() # 1. Domain-conversion: e.g. 
Stft: time -> time-freq if self.stft is not None: input_stft, feats_lens = self._compute_stft(input, input_lengths) @@ -145,7 +150,7 @@ class MultiChannelFrontend(nn.Module): def __init__( self, - fs: Union[int, str] = 16000, + fs: int = 16000, n_fft: int = 512, win_length: int = None, hop_length: int = None, @@ -168,9 +173,6 @@ class MultiChannelFrontend(nn.Module): mc: bool = True ): super().__init__() - if isinstance(fs, str): - fs = humanfriendly.parse_size(fs) - # Deepcopy (In general, dict shouldn't be used as default arg) frontend_conf = copy.deepcopy(frontend_conf) if win_length is None and hop_length is None: diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py index 1d252c206..be973c641 100644 --- a/funasr/models/conformer/encoder.py +++ b/funasr/models/conformer/encoder.py @@ -47,7 +47,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad from funasr.models.transformer.utils.subsampling import StreamingConvInput from funasr.register import tables - +import pdb class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model. diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py index 49868a8f4..7d6f729a8 100644 --- a/funasr/models/contextual_paraformer/model.py +++ b/funasr/models/contextual_paraformer/model.py @@ -29,7 +29,7 @@ from funasr.train_utils.device_funcs import force_gatherable from funasr.models.transformer.utils.add_sos_eos import add_sos_eos from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank - +import pdb if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -63,7 +63,6 @@ class ContextualParaformer(Paraformer): crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0) bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0) - if bias_encoder_type == 'lstm': self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate) self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim) @@ -103,17 +102,16 @@ class ContextualParaformer(Paraformer): text_lengths = text_lengths[:, 0] if len(speech_lengths.size()) > 1: speech_lengths = speech_lengths[:, 0] - + batch_size = speech.shape[0] hotword_pad = kwargs.get("hotword_pad") hotword_lengths = kwargs.get("hotword_lengths") dha_pad = kwargs.get("dha_pad") - + # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - loss_ctc, cer_ctc = None, None stats = dict() @@ -128,12 +126,11 @@ class ContextualParaformer(Paraformer): stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None stats["cer_ctc"] = cer_ctc - # 2b. Attention decoder branch loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss( encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths ) - + # 3. 
CTC-Att loss definition if self.ctc_weight == 0.0: loss = loss_att + loss_pre * self.predictor_weight @@ -171,22 +168,24 @@ class ContextualParaformer(Paraformer): ): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) + if self.predictor_bias == 1: _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_pad_lens = ys_pad_lens + self.predictor_bias + pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) - # -1. bias encoder if self.use_decoder_embedding: hw_embed = self.decoder.embed(hotword_pad) else: hw_embed = self.bias_embed(hotword_pad) + hw_embed, (_, _) = self.bias_encoder(hw_embed) _ind = np.arange(0, hotword_pad.shape[0]).tolist() selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]] contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device) - + # 0. sampler decoder_out_1st = None if self.sampling_ratio > 0.0: @@ -195,7 +194,7 @@ class ContextualParaformer(Paraformer): pre_acoustic_embeds, contextual_info) else: sematic_embeds = pre_acoustic_embeds - + # 1. Forward decoder decoder_outs = self.decoder( encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info @@ -211,7 +210,7 @@ class ContextualParaformer(Paraformer): loss_ideal = None ''' loss_ideal = None - + if decoder_out_1st is None: decoder_out_1st = decoder_out # 2. Compute attention loss @@ -288,10 +287,11 @@ class ContextualParaformer(Paraformer): enforce_sorted=False) _, (h_n, _) = self.bias_encoder(hw_embed) hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1) - + decoder_outs = self.decoder( encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale ) + decoder_out = decoder_outs[0] decoder_out = torch.log_softmax(decoder_out, dim=-1) return decoder_out, ys_pad_lens @@ -305,38 +305,42 @@ class ContextualParaformer(Paraformer): **kwargs, ): # init beamsearch + is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None if self.beam_search is None and (is_use_lm or is_use_ctc): logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - + meta_data = {} # extract fbank feats time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) + time2 = time.perf_counter() meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend) time3 = time.perf_counter() meta_data["extract_feat"] = f"{time3 - time2:0.3f}" meta_data[ "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - + speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) # hotword self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) - + # Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - + # predictor predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ @@ -344,8 +348,7 @@ class 
ContextualParaformer(Paraformer): pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] - - + decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, diff --git a/funasr/models/lcbnet/__init__.py b/funasr/models/lcbnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py new file mode 100644 index 000000000..8e8c5943a --- /dev/null +++ b/funasr/models/lcbnet/attention.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 yufan +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention Return Weight layer definition.""" + +import math + +import torch +from torch import nn + +class MultiHeadedAttentionReturnWeight(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttentionReturnWeight object.""" + super(MultiHeadedAttentionReturnWeight, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). 
+ + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = torch.finfo(scores.dtype).min + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x), self.attn # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py new file mode 100644 index 000000000..c65823cb0 --- /dev/null +++ b/funasr/models/lcbnet/encoder.py @@ -0,0 +1,392 @@ +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Transformer encoder definition.""" + +from typing import List +from typing import Optional +from typing import Tuple + +import torch +from torch import nn +import logging + +from funasr.models.transformer.attention import MultiHeadedAttention +from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight +from funasr.models.transformer.embedding import PositionalEncoding +from funasr.models.transformer.layer_norm import LayerNorm + +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward +from funasr.models.transformer.utils.repeat import repeat +from funasr.register import tables + +class EncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + stochastic_depth_rate (float): Proability to skip this layer. + During training, the layer may skip residual computation and return input + as-is with given probability. 
+ """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + + def forward(self, x, mask, cache=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn(x_q, x, x, mask) + ) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, mask + +@tables.register("encoder_classes", "TransformerTextEncoder") +class TransformerTextEncoder(nn.Module): + """Transformer text encoder module. + + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. 
x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + super().__init__() + self._output_size = output_size + + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size), + pos_enc_class(output_size, positional_dropout_rate), + ) + + self.normalize_before = normalize_before + + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + MultiHeadedAttention( + attention_heads, output_size, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + xs_pad = self.embed(xs_pad) + + xs_pad, masks = self.encoders(xs_pad, masks) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None + + + + +@tables.register("encoder_classes", "FusionSANEncoder") +class SelfSrcAttention(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + + """ + def __init__( + self, + size, + attention_heads, + attention_dim, + linear_units, + self_attention_dropout_rate, + src_attention_dropout_rate, + positional_dropout_rate, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an SelfSrcAttention object.""" + super(SelfSrcAttention, self).__init__() + self.size = size + self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate) + self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate) + self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate) + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
+ + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 + ) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x, score = self.src_attn(x, memory, memory, memory_mask) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask + + +@tables.register("encoder_classes", "ConvBiasPredictor") +class ConvPredictor(nn.Module): + def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048): + super().__init__() + self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate) + self.norm1 = LayerNorm(size) + self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate) + self.norm2 = LayerNorm(size) + self.pad = nn.ConstantPad1d((l_order, r_order), 0) + self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size) + self.output_linear = nn.Linear(size, 1) + + + def forward(self, text_enc, asr_enc): + # stage1 cross-attention + residual = text_enc + text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None) + + # stage2 FFN + residual = text_enc + text_enc = self.norm1(text_enc) + text_enc = residual + self.feed_forward(text_enc) + + # stage Conv predictor + text_enc = self.norm2(text_enc) + context = text_enc.transpose(1, 2) + queries = self.pad(context) + memory = self.conv1d(queries) + output = memory + context + output = output.transpose(1, 2) + output = torch.relu(output) + output = self.output_linear(output) + if output.dim()==3: + output = output.squeeze(2) + return output diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py new file mode 100644 index 000000000..3ac319c61 --- /dev/null +++ b/funasr/models/lcbnet/model.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +import logging +from typing import Union, Dict, List, Tuple, Optional + +import time +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.ctc.ctc import CTC +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.metrics.compute_acc import th_accuracy +# from funasr.models.e2e_asr_common import ErrorCalculator +from funasr.train_utils.device_funcs import force_gatherable +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.utils import postprocess_utils +from funasr.utils.datadir_writer import DatadirWriter +from funasr.register import tables + +import pdb +@tables.register("model_classes", "LCBNet") +class LCBNet(nn.Module): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION + https://arxiv.org/abs/2401.06390 + """ + + def __init__( + self, + specaug: str = None, + specaug_conf: dict = None, + normalize: str = None, + normalize_conf: dict = None, + encoder: str = None, + encoder_conf: dict = None, + decoder: str = None, + decoder_conf: dict = None, + text_encoder: str = None, + text_encoder_conf: dict = None, + bias_predictor: str = None, + bias_predictor_conf: dict = None, + fusion_encoder: str = None, + fusion_encoder_conf: dict = None, + ctc: str = None, + ctc_conf: dict = None, + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + select_num: int = 2, + select_length: int = 3, + insert_blank: bool = True, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, + **kwargs, + ): + + super().__init__() + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + # lcbnet modules: text encoder, fusion encoder and bias predictor + text_encoder_class = tables.encoder_classes.get(text_encoder) + text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf) + fusion_encoder_class = tables.encoder_classes.get(fusion_encoder) + fusion_encoder = fusion_encoder_class(**fusion_encoder_conf) + bias_predictor_class = tables.encoder_classes.get(bias_predictor) + bias_predictor = bias_predictor_class(**bias_predictor_conf) + + + if decoder is not None: + decoder_class = tables.decoder_classes.get(decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **decoder_conf, + ) + if ctc_weight > 0.0: + + if ctc_conf is None: + ctc_conf = {} + + ctc = CTC( + odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf + ) + + self.blank_id = blank_id + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = 
ignore_id + self.ctc_weight = ctc_weight + self.specaug = specaug + self.normalize = normalize + self.encoder = encoder + # lcbnet + self.text_encoder = text_encoder + self.fusion_encoder = fusion_encoder + self.bias_predictor = bias_predictor + self.select_num = select_num + self.select_length = select_length + self.insert_blank = insert_blank + + if not hasattr(self.encoder, "interctc_use_conditioning"): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size() + ) + self.interctc_weight = interctc_weight + + # self.error_calculator = None + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + # + # if report_cer or report_wer: + # self.error_calculator = ErrorCalculator( + # token_list, sym_space, sym_blank, report_cer, report_wer + # ) + # + self.error_calculator = None + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.share_embedding = share_embedding + if self.share_embedding: + self.decoder.embed = None + + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + stats = dict() + + # decoder: CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + # decoder: Attention decoder branch + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. 
CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + + # Collect total loss stats + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + # Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + speech, speech_lengths, ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + return encoder_out, encoder_out_lens + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def init_beam_search(self, + **kwargs, + ): + from funasr.models.transformer.search import BeamSearch + from funasr.models.transformer.scorers.ctc import CTCPrefixScorer + from funasr.models.transformer.scorers.length_bonus import LengthBonus + + # 1. 
Build ASR model + scorers = {} + + if self.ctc != None: + ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) + scorers.update( + ctc=ctc + ) + token_list = kwargs.get("token_list") + scorers.update( + decoder=self.decoder, + length_bonus=LengthBonus(len(token_list)), + ) + + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + weights = dict( + decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3), + ctc=kwargs.get("decoding_ctc_weight", 0.3), + lm=kwargs.get("lm_weight", 0.0), + ngram=kwargs.get("ngram_weight", 0.0), + length_bonus=kwargs.get("penalty", 0.0), + ) + beam_search = BeamSearch( + beam_size=kwargs.get("beam_size", 20), + weights=weights, + scorers=scorers, + sos=self.sos, + eos=self.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", + ) + + self.beam_search = beam_search + + def inference(self, + data_in, + data_lengths=None, + key: list=None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + # init beamsearch + if self.beam_search is None: + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + audio_sample_list = sample_list[0] + if len(sample_list) >1: + ocr_sample_list = sample_list[1] + else: + ocr_sample_list = [[294, 0]] + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + frame_shift = 10 + meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000 + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list] + ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"]) + ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"]) + ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths) + fusion_out, _, _, _ = self.fusion_encoder(encoder_out,None, ocr, None) + encoder_out = encoder_out + fusion_out + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + b, n, d = encoder_out.size() + for i in range(b): + + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + text = tokenizer.tokens2text(token) + + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text_postprocessed} + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["text"][key[i]] = text_postprocessed + + return results, meta_data + diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py index 20b0cc838..a8b1f1fb1 100644 --- a/funasr/models/seaco_paraformer/model.py +++ b/funasr/models/seaco_paraformer/model.py @@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank - +import pdb if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast else: @@ -128,7 +128,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): hotword_pad = kwargs.get("hotword_pad") hotword_lengths = kwargs.get("hotword_lengths") dha_pad = kwargs.get("dha_pad") - + batch_size = speech.shape[0] # for data-parallel text = text[:, : text_lengths.max()] @@ -209,17 +209,20 @@ class SeacoParaformer(BiCifParaformer, Paraformer): nfilter=50, seaco_weight=1.0): # decoder forward + decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) + decoder_pred = torch.log_softmax(decoder_out, dim=-1) if hw_list is not None: hw_lengths = [len(i) for i in hw_list] hw_list_ = [torch.Tensor(i).long() for i in hw_list] hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device) selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device)) + contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) num_hot_word = contextual_info.shape[1] _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) - + # ASF Core if nfilter > 0 and nfilter < num_hot_word: hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) @@ -239,7 +242,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) merged = 
self._merge(cif_attended, dec_attended) - + dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation dha_pred = torch.log_softmax(dha_output, dim=-1) def _merge_res(dec_output, dha_output): @@ -253,8 +256,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer): # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask) logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) return logits + merged_pred = _merge_res(decoder_pred, dha_pred) - # import pdb; pdb.set_trace() return merged_pred else: return decoder_pred @@ -304,7 +307,6 @@ class SeacoParaformer(BiCifParaformer, Paraformer): logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - meta_data = {} # extract fbank feats @@ -330,6 +332,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] + # predictor predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ @@ -338,15 +341,14 @@ class SeacoParaformer(BiCifParaformer, Paraformer): if torch.max(pre_token_length) < 1: return [] - decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) + # decoder_out, _ = decoder_outs[0], decoder_outs[1] _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, pre_token_length) - results = [] b, n, d = decoder_out.size() for i in range(b): diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index ea2372537..0c46449e4 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -7,7 +7,7 @@ import logging import torch import torch.nn import torch.optim - +import pdb def filter_state_dict( dst_state: Dict[str, Union[float, torch.Tensor]], @@ -63,6 +63,7 @@ def load_pretrained_model( dst_state = obj.state_dict() print(f"ckpt: {path}") + if oss_bucket is None: src_state = torch.load(path, map_location=map_location) else: diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py index 7748172f6..84c38f9b9 100644 --- a/funasr/utils/load_utils.py +++ b/funasr/utils/load_utils.py @@ -13,29 +13,25 @@ try: from funasr.download.file import download_from_url except: print("urllib is not installed, if you infer from url, please install it first.") - +import pdb def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs): if isinstance(data_or_path_or_list, (list, tuple)): if data_type is not None and isinstance(data_type, (list, tuple)): - data_types = [data_type] * len(data_or_path_or_list) data_or_path_or_list_ret = [[] for d in data_type] for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)): - for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)): - data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs) data_or_path_or_list_ret[j].append(data_or_path_or_list_j) return data_or_path_or_list_ret else: return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list] - if isinstance(data_or_path_or_list, str) and 
data_or_path_or_list.startswith('http'): # download url to local file data_or_path_or_list = download_from_url(data_or_path_or_list) - + if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file if data_type is None or data_type == "sound": data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list) @@ -56,10 +52,22 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: data_or_path_or_list = tokenizer.encode(data_or_path_or_list) elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,] + elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark": + data_mat = kaldiio.load_mat(data_or_path_or_list) + if isinstance(data_mat, tuple): + audio_fs, mat = data_mat + else: + mat = data_mat + if mat.dtype == 'int16' or mat.dtype == 'int32': + mat = mat.astype(np.float64) + mat = mat / 32768 + if mat.ndim ==2: + mat = mat[:,0] + data_or_path_or_list = mat else: pass # print(f"unsupport data type: {data_or_path_or_list}, return raw data") - + if audio_fs != fs and data_type != "text": resampler = torchaudio.transforms.Resample(audio_fs, fs) data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :] @@ -81,8 +89,6 @@ def load_bytes(input): return array def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs): - # import pdb; - # pdb.set_trace() if isinstance(data, np.ndarray): data = torch.from_numpy(data) if len(data.shape) < 2: @@ -100,9 +106,7 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, data_list.append(data_i) data_len.append(data_i.shape[0]) data = pad_sequence(data_list, batch_first=True) # data: [batch, N] - # import pdb; - # pdb.set_trace() - # if data_type == "sound": + data, data_len = frontend(data, data_len, **kwargs) if isinstance(data_len, (list, tuple)):
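
A minimal usage sketch for the ConvBiasPredictor registered earlier in this patch (illustrative only, not part of the diff): it scores each text-encoder token against the ASR encoder output and returns one bias score per token. The import path and the random tensors are assumptions; the shapes follow the forward() signature in the patch.

import torch
# Hypothetical import path: ConvPredictor is the class registered as "ConvBiasPredictor"
# in this patch; the actual module location may differ in the repository.
from funasr.models.lcbnet.encoder import ConvPredictor

predictor = ConvPredictor(size=256, l_order=3, r_order=3,
                          attention_heads=4, attention_dropout_rate=0.1,
                          linear_units=2048)
text_enc = torch.randn(2, 20, 256)     # (batch, num_ocr_tokens, size) from the text encoder
asr_enc = torch.randn(2, 120, 256)     # (batch, num_frames, size) from the ASR encoder
scores = predictor(text_enc, asr_enc)  # (batch, num_ocr_tokens): one bias score per token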
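
For reference, a self-contained sketch (not part of the patch) of how LCBNet.forward weights the two decoding branches: any intermediate-CTC loss is first folded into the CTC loss, then the CTC and attention losses are linearly interpolated by ctc_weight. The helper function itself is hypothetical.

def combine_losses(loss_ctc, loss_att, loss_interctc, ctc_weight, interctc_weight):
    # Fold the averaged intermediate-CTC loss into the CTC branch, as in LCBNet.forward.
    if interctc_weight != 0.0 and loss_interctc is not None:
        loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc
    # Interpolate the CTC and attention branches by ctc_weight.
    if ctc_weight == 0.0:
        return loss_att
    if ctc_weight == 1.0:
        return loss_ctc
    return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att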
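
The LCBNet inference path biases the acoustic representation with the OCR text before beam search. Below is a hedged sketch of just that step; the helper name is hypothetical, while the id shift, the text-encoder and fusion-encoder calls, and the residual add mirror LCBNet.inference in the patch.

import torch

def bias_encoder_output(model, encoder_out, ocr_sample_list, device):
    # OCR ids are shifted by +1 so that 0 stays reserved for the blank symbol.
    ocr_ids = [[x + 1 if x != 0 else x for x in seq] for seq in ocr_sample_list]
    ocr = torch.tensor(ocr_ids, device=device)
    ocr_lengths = ocr.new_full([ocr.size(0)], fill_value=ocr.size(1), dtype=torch.long)
    # Encode the OCR tokens, then fuse them with the acoustic encoder output.
    ocr_enc, _, _ = model.text_encoder(ocr, ocr_lengths)
    fusion_out, _, _, _ = model.fusion_encoder(encoder_out, None, ocr_enc, None)
    # Residual connection: the biased features are added back onto the ASR encoder output.
    return encoder_out + fusion_out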
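
The new "kaldi_ark" branch added to load_audio_text_image_video rescales integer PCM and keeps only the first channel. A standalone sketch of that logic follows; the function name is an assumption, while the kaldiio call, the /32768 rescaling, and the channel selection are taken from the patch.

import numpy as np
import kaldiio

def load_kaldi_ark(path):
    data_mat = kaldiio.load_mat(path)
    # kaldiio may return either the matrix alone or a (sample_rate, matrix) tuple.
    if isinstance(data_mat, tuple):
        audio_fs, mat = data_mat
    else:
        audio_fs, mat = None, data_mat
    # int16/int32 PCM is rescaled to roughly [-1, 1), matching the patch.
    if mat.dtype == "int16" or mat.dtype == "int32":
        mat = mat.astype(np.float64) / 32768
    # Multi-channel audio: keep channel 0 only.
    if mat.ndim == 2:
        mat = mat[:, 0]
    return mat, audio_fs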