diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100755
new mode 100644
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100755
new mode 100644
index 1bd4f7f5b..8fc66f34f
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
@@ -2,7 +2,7 @@
 model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"
 
-python ../../../funasr/bin/inference.py \
+python funasr/bin/inference.py \
 +model=${model} \
 +model_revision=${model_revision} \
 +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
deleted file mode 100755
index 282f4f1f2..000000000
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-python -m funasr.bin.inference \
---config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
---config-name="config.yaml" \
-++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
-++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
-++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
-++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
-++output_dir="./outputs/debug2" \
-++device="" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/path.sh b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
deleted file mode 100755
index 1a6d67e08..000000000
--- a/examples/industrial_data_pretraining/contextual_paraformer/path.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-export FUNASR_DIR=$PWD/../../../
-
-# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
-export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
deleted file mode 100755
index e72d87155..000000000
--- a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
+++ /dev/null
@@ -1,702 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-
-from enum import Enum
-import re, sys, unicodedata
-import codecs
-import argparse
-from tqdm import tqdm
-import os
-import pdb
-remove_tag = False
-spacelist = [" ", "\t", "\r", "\n"]
-puncts = [
-    "!",
-    ",",
-    "?",
-    "、",
-    "。",
-    "!",
-    ",",
-    ";",
-    "?",
-    ":",
-    "「",
-    "」",
-    "︰",
-    "『",
-    "』",
-    "《",
-    "》",
-]
-
-
-class Code(Enum):
-    match = 1
-    substitution = 2
-    insertion = 3
-    deletion = 4
-
-
-class WordError(object):
-    def __init__(self):
-        self.errors = {
-            Code.substitution:
0, - Code.insertion: 0, - Code.deletion: 0, - } - self.ref_words = 0 - - def get_wer(self): - assert self.ref_words != 0 - errors = ( - self.errors[Code.substitution] - + self.errors[Code.insertion] - + self.errors[Code.deletion] - ) - return 100.0 * errors / self.ref_words - - def get_result_string(self): - return ( - f"error_rate={self.get_wer():.4f}, " - f"ref_words={self.ref_words}, " - f"subs={self.errors[Code.substitution]}, " - f"ins={self.errors[Code.insertion]}, " - f"dels={self.errors[Code.deletion]}" - ) - - -def characterize(string): - res = [] - i = 0 - while i < len(string): - char = string[i] - if char in puncts: - i += 1 - continue - cat1 = unicodedata.category(char) - # https://unicodebook.readthedocs.io/unicode.html#unicode-categories - if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned - i += 1 - continue - if cat1 == "Lo": # letter-other - res.append(char) - i += 1 - else: - # some input looks like: , we want to separate it to two words. - sep = " " - if char == "<": - sep = ">" - j = i + 1 - while j < len(string): - c = string[j] - if ord(c) >= 128 or (c in spacelist) or (c == sep): - break - j += 1 - if j < len(string) and string[j] == ">": - j += 1 - res.append(string[i:j]) - i = j - return res - - -def stripoff_tags(x): - if not x: - return "" - chars = [] - i = 0 - T = len(x) - while i < T: - if x[i] == "<": - while i < T and x[i] != ">": - i += 1 - i += 1 - else: - chars.append(x[i]) - i += 1 - return "".join(chars) - - -def normalize(sentence, ignore_words, cs, split=None): - """sentence, ignore_words are both in unicode""" - new_sentence = [] - for token in sentence: - x = token - if not cs: - x = x.upper() - if x in ignore_words: - continue - if remove_tag: - x = stripoff_tags(x) - if not x: - continue - if split and x in split: - new_sentence += split[x] - else: - new_sentence.append(x) - return new_sentence - - -class Calculator: - def __init__(self): - self.data = {} - self.space = [] - self.cost = {} - self.cost["cor"] = 0 - self.cost["sub"] = 1 - self.cost["del"] = 1 - self.cost["ins"] = 1 - - def calculate(self, lab, rec): - # Initialization - lab.insert(0, "") - rec.insert(0, "") - while len(self.space) < len(lab): - self.space.append([]) - for row in self.space: - for element in row: - element["dist"] = 0 - element["error"] = "non" - while len(row) < len(rec): - row.append({"dist": 0, "error": "non"}) - for i in range(len(lab)): - self.space[i][0]["dist"] = i - self.space[i][0]["error"] = "del" - for j in range(len(rec)): - self.space[0][j]["dist"] = j - self.space[0][j]["error"] = "ins" - self.space[0][0]["error"] = "non" - for token in lab: - if token not in self.data and len(token) > 0: - self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} - for token in rec: - if token not in self.data and len(token) > 0: - self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} - # Computing edit distance - for i, lab_token in enumerate(lab): - for j, rec_token in enumerate(rec): - if i == 0 or j == 0: - continue - min_dist = sys.maxsize - min_error = "none" - dist = self.space[i - 1][j]["dist"] + self.cost["del"] - error = "del" - if dist < min_dist: - min_dist = dist - min_error = error - dist = self.space[i][j - 1]["dist"] + self.cost["ins"] - error = "ins" - if dist < min_dist: - min_dist = dist - min_error = error - if lab_token == rec_token.replace("", ""): - dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"] - error = "cor" - else: - dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"] 
- error = "sub" - if dist < min_dist: - min_dist = dist - min_error = error - self.space[i][j]["dist"] = min_dist - self.space[i][j]["error"] = min_error - # Tracing back - result = { - "lab": [], - "rec": [], - "code": [], - "all": 0, - "cor": 0, - "sub": 0, - "ins": 0, - "del": 0, - } - i = len(lab) - 1 - j = len(rec) - 1 - while True: - if self.space[i][j]["error"] == "cor": # correct - if len(lab[i]) > 0: - self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 - self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1 - result["all"] = result["all"] + 1 - result["cor"] = result["cor"] + 1 - result["lab"].insert(0, lab[i]) - result["rec"].insert(0, rec[j]) - result["code"].insert(0, Code.match) - i = i - 1 - j = j - 1 - elif self.space[i][j]["error"] == "sub": # substitution - if len(lab[i]) > 0: - self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 - self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1 - result["all"] = result["all"] + 1 - result["sub"] = result["sub"] + 1 - result["lab"].insert(0, lab[i]) - result["rec"].insert(0, rec[j]) - result["code"].insert(0, Code.substitution) - i = i - 1 - j = j - 1 - elif self.space[i][j]["error"] == "del": # deletion - if len(lab[i]) > 0: - self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 - self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1 - result["all"] = result["all"] + 1 - result["del"] = result["del"] + 1 - result["lab"].insert(0, lab[i]) - result["rec"].insert(0, "") - result["code"].insert(0, Code.deletion) - i = i - 1 - elif self.space[i][j]["error"] == "ins": # insertion - if len(rec[j]) > 0: - self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1 - result["ins"] = result["ins"] + 1 - result["lab"].insert(0, "") - result["rec"].insert(0, rec[j]) - result["code"].insert(0, Code.insertion) - j = j - 1 - elif self.space[i][j]["error"] == "non": # starting point - break - else: # shouldn't reach here - print( - "this should not happen , i = {i} , j = {j} , error = {error}".format( - i=i, j=j, error=self.space[i][j]["error"] - ) - ) - return result - - def overall(self): - result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} - for token in self.data: - result["all"] = result["all"] + self.data[token]["all"] - result["cor"] = result["cor"] + self.data[token]["cor"] - result["sub"] = result["sub"] + self.data[token]["sub"] - result["ins"] = result["ins"] + self.data[token]["ins"] - result["del"] = result["del"] + self.data[token]["del"] - return result - - def cluster(self, data): - result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} - for token in data: - if token in self.data: - result["all"] = result["all"] + self.data[token]["all"] - result["cor"] = result["cor"] + self.data[token]["cor"] - result["sub"] = result["sub"] + self.data[token]["sub"] - result["ins"] = result["ins"] + self.data[token]["ins"] - result["del"] = result["del"] + self.data[token]["del"] - return result - - def keys(self): - return list(self.data.keys()) - - -def width(string): - return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) - - -def default_cluster(word): - unicode_names = [unicodedata.name(char) for char in word] - for i in reversed(range(len(unicode_names))): - if unicode_names[i].startswith("DIGIT"): # 1 - unicode_names[i] = "Number" # 'DIGIT' - elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[ - i - ].startswith("CJK COMPATIBILITY IDEOGRAPH"): - # 明 / 郎 - unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH' - elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or 
unicode_names[ - i - ].startswith("LATIN SMALL LETTER"): - # A / a - unicode_names[i] = "English" # 'LATIN LETTER' - elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め - unicode_names[i] = "Japanese" # 'GANA LETTER' - elif ( - unicode_names[i].startswith("AMPERSAND") - or unicode_names[i].startswith("APOSTROPHE") - or unicode_names[i].startswith("COMMERCIAL AT") - or unicode_names[i].startswith("DEGREE CELSIUS") - or unicode_names[i].startswith("EQUALS SIGN") - or unicode_names[i].startswith("FULL STOP") - or unicode_names[i].startswith("HYPHEN-MINUS") - or unicode_names[i].startswith("LOW LINE") - or unicode_names[i].startswith("NUMBER SIGN") - or unicode_names[i].startswith("PLUS SIGN") - or unicode_names[i].startswith("SEMICOLON") - ): - # & / ' / @ / ℃ / = / . / - / _ / # / + / ; - del unicode_names[i] - else: - return "Other" - if len(unicode_names) == 0: - return "Other" - if len(unicode_names) == 1: - return unicode_names[0] - for i in range(len(unicode_names) - 1): - if unicode_names[i] != unicode_names[i + 1]: - return "Other" - return unicode_names[0] - - -def get_args(): - parser = argparse.ArgumentParser(description="wer cal") - parser.add_argument("--ref", type=str, help="Text input path") - parser.add_argument("--ref_ocr", type=str, help="Text input path") - parser.add_argument("--rec_name", type=str, action="append", default=[]) - parser.add_argument("--rec_file", type=str, action="append", default=[]) - parser.add_argument("--verbose", type=int, default=1, help="show") - parser.add_argument("--char", type=bool, default=True, help="show") - args = parser.parse_args() - return args - - -def main(args): - cluster_file = "" - ignore_words = set() - tochar = args.char - verbose = args.verbose - padding_symbol = " " - case_sensitive = False - max_words_per_line = sys.maxsize - split = None - - if not case_sensitive: - ig = set([w.upper() for w in ignore_words]) - ignore_words = ig - - default_clusters = {} - default_words = {} - ref_file = args.ref - ref_ocr = args.ref_ocr - rec_files = args.rec_file - rec_names = args.rec_name - assert len(rec_files) == len(rec_names) - - # load ocr - ref_ocr_dict = {} - with codecs.open(ref_ocr, "r", "utf-8") as fh: - for line in fh: - if "$" in line: - line = line.replace("$", " ") - if tochar: - array = characterize(line) - else: - array = line.strip().split() - if len(array) == 0: - continue - fid = array[0] - ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split) - - if split and not case_sensitive: - newsplit = dict() - for w in split: - words = split[w] - for i in range(len(words)): - words[i] = words[i].upper() - newsplit[w.upper()] = words - split = newsplit - - rec_sets = {} - calculators_dict = dict() - ub_wer_dict = dict() - hotwords_related_dict = dict() # 记录recall相关的内容 - for i, hyp_file in enumerate(rec_files): - rec_sets[rec_names[i]] = dict() - with codecs.open(hyp_file, "r", "utf-8") as fh: - for line in fh: - if tochar: - array = characterize(line) - else: - array = line.strip().split() - if len(array) == 0: - continue - fid = array[0] - rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split) - - calculators_dict[rec_names[i]] = Calculator() - ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()} - hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} - # tp: 热词在label里,同时在rec里 - # tn: 热词不在label里,同时不在rec里 - # fp: 热词不在label里,但是在rec里 - # fn: 热词在label里,但是不在rec里 - - # record wrong label but in ocr - 
wrong_rec_but_in_ocr_dict = {} - for rec_name in rec_names: - wrong_rec_but_in_ocr_dict[rec_name] = 0 - - _file_total_len = 0 - with os.popen("cat {} | wc -l".format(ref_file)) as pipe: - _file_total_len = int(pipe.read().strip()) - - # compute error rate on the interaction of reference file and hyp file - for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len): - if tochar: - array = characterize(line) - else: - array = line.rstrip('\n').split() - if len(array) == 0: continue - fid = array[0] - lab = normalize(array[1:], ignore_words, case_sensitive, split) - - if verbose: - print('\nutt: %s' % fid) - - ocr_text = ref_ocr_dict[fid] - ocr_set = set(ocr_text) - print('ocr: {}'.format(" ".join(ocr_text))) - list_match = [] # 指label里面在ocr里面的内容 - list_not_mathch = [] - tmp_error = 0 - tmp_match = 0 - for index in range(len(lab)): - # text_list.append(uttlist[index+1]) - if lab[index] not in ocr_set: - tmp_error += 1 - list_not_mathch.append(lab[index]) - else: - tmp_match += 1 - list_match.append(lab[index]) - print('label in ocr: {}'.format(" ".join(list_match))) - - # for each reco file - base_wrong_ocr_wer = None - ocr_wrong_ocr_wer = None - - for rec_name in rec_names: - rec_set = rec_sets[rec_name] - if fid not in rec_set: - continue - rec = rec_set[fid] - - # print(rec) - for word in rec + lab: - if word not in default_words: - default_cluster_name = default_cluster(word) - if default_cluster_name not in default_clusters: - default_clusters[default_cluster_name] = {} - if word not in default_clusters[default_cluster_name]: - default_clusters[default_cluster_name][word] = 1 - default_words[word] = default_cluster_name - - result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy()) - if verbose: - if result['all'] != 0: - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else: - wer = 0.0 - print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - - - # print(result['rec']) - wrong_rec_but_in_ocr = [] - for idx in range(len(result['lab'])): - if result['lab'][idx] != "": - if result['lab'][idx] != result['rec'][idx].replace("", ""): - if result['lab'][idx] in list_match: - wrong_rec_but_in_ocr.append(result['lab'][idx]) - wrong_rec_but_in_ocr_dict[rec_name] += 1 - print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr))) - - if rec_name == "base": - base_wrong_ocr_wer = len(wrong_rec_but_in_ocr) - if "ocr" in rec_name or "hot" in rec_name: - ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr) - if ocr_wrong_ocr_wer < base_wrong_ocr_wer: - print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) - elif ocr_wrong_ocr_wer > base_wrong_ocr_wer: - print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) - - # recall = 0 - # false_alarm = 0 - # for idx in range(len(result['lab'])): - # if "" in result['rec'][idx]: - # if result['rec'][idx].replace("", "") in list_match: - # recall += 1 - # else: - # false_alarm += 1 - # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format( - # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0 - # )) - # tp: 热词在label里,同时在rec里 - # tn: 热词不在label里,同时不在rec里 - # fp: 热词不在label里,但是在rec里 - # fn: 热词在label里,但是不在rec里 - _rec_list = [word.replace("", "") for word in rec] - 
_label_list = [word for word in lab] - _tp = _tn = _fp = _fn = 0 - hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list] - hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list] - for badhotword in hot_bad_list: - count = len([word for word in _rec_list if word == badhotword]) - # print(f"bad {badhotword} count: {count}") - # for word in _rec_list: - # if badhotword == word: - # count += 1 - if count == 0: - hotwords_related_dict[rec_name]['tn'] += 1 - _tn += 1 - # fp: 0 - else: - hotwords_related_dict[rec_name]['fp'] += count - _fp += count - # tn: 0 - # if badhotword in _rec_list: - # hotwords_related_dict[rec_name]['fp'] += 1 - # else: - # hotwords_related_dict[rec_name]['tn'] += 1 - for hotword in hot_true_list: - true_count = len([word for word in _label_list if hotword == word]) - rec_count = len([word for word in _rec_list if hotword == word]) - # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}") - if rec_count == true_count: - hotwords_related_dict[rec_name]['tp'] += true_count - _tp += true_count - elif rec_count > true_count: - hotwords_related_dict[rec_name]['tp'] += true_count - # fp: 不在label里,但是在rec里 - hotwords_related_dict[rec_name]['fp'] += rec_count - true_count - _tp += true_count - _fp += rec_count - true_count - else: - hotwords_related_dict[rec_name]['tp'] += rec_count - # fn: 热词在label里,但是不在rec里 - hotwords_related_dict[rec_name]['fn'] += true_count - rec_count - _tp += rec_count - _fn += true_count - rec_count - print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( - _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0 - )) - - # if hotword in _rec_list: - # hotwords_related_dict[rec_name]['tp'] += 1 - # else: - # hotwords_related_dict[rec_name]['fn'] += 1 - # 计算uwer, bwer, wer - for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]): - if code == Code.match: - ub_wer_dict[rec_name]["wer"].ref_words += 1 - if lab_word in hot_true_list: - # tmp_ref.append(ref_tokens[ref_idx]) - ub_wer_dict[rec_name]["b_wer"].ref_words += 1 - else: - ub_wer_dict[rec_name]["u_wer"].ref_words += 1 - elif code == Code.substitution: - ub_wer_dict[rec_name]["wer"].ref_words += 1 - ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1 - if lab_word in hot_true_list: - # tmp_ref.append(ref_tokens[ref_idx]) - ub_wer_dict[rec_name]["b_wer"].ref_words += 1 - ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1 - else: - ub_wer_dict[rec_name]["u_wer"].ref_words += 1 - ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1 - elif code == Code.deletion: - ub_wer_dict[rec_name]["wer"].ref_words += 1 - ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1 - if lab_word in hot_true_list: - # tmp_ref.append(ref_tokens[ref_idx]) - ub_wer_dict[rec_name]["b_wer"].ref_words += 1 - ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1 - else: - ub_wer_dict[rec_name]["u_wer"].ref_words += 1 - ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1 - elif code == Code.insertion: - ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1 - if rec_word in hot_true_list: - ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1 - else: - ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1 - - space = {} - space['lab'] = [] - space['rec'] = [] - for idx in range(len(result['lab'])): - len_lab = width(result['lab'][idx]) - len_rec = width(result['rec'][idx]) - length = max(len_lab, len_rec) - 
space['lab'].append(length - len_lab) - space['rec'].append(length - len_rec) - upper_lab = len(result['lab']) - upper_rec = len(result['rec']) - lab1, rec1 = 0, 0 - while lab1 < upper_lab or rec1 < upper_rec: - if verbose > 1: - print('lab(%s):' % fid.encode('utf-8'), end=' ') - else: - print('lab:', end=' ') - lab2 = min(upper_lab, lab1 + max_words_per_line) - for idx in range(lab1, lab2): - token = result['lab'][idx] - print('{token}'.format(token=token), end='') - for n in range(space['lab'][idx]): - print(padding_symbol, end='') - print(' ', end='') - print() - if verbose > 1: - print('rec(%s):' % fid.encode('utf-8'), end=' ') - else: - print('rec:', end=' ') - - rec2 = min(upper_rec, rec1 + max_words_per_line) - for idx in range(rec1, rec2): - token = result['rec'][idx] - print('{token}'.format(token=token), end='') - for n in range(space['rec'][idx]): - print(padding_symbol, end='') - print(' ', end='') - print() - # print('\n', end='\n') - lab1 = lab2 - rec1 = rec2 - print('\n', end='\n') - # break - if verbose: - print('===========================================================================') - print() - - print(wrong_rec_but_in_ocr_dict) - for rec_name in rec_names: - result = calculators_dict[rec_name].overall() - - if result['all'] != 0: - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else: - wer = 0.0 - print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}") - print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}") - print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}") - - print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format( - hotwords_related_dict[rec_name]['tp'], - hotwords_related_dict[rec_name]['tn'], - hotwords_related_dict[rec_name]['fp'], - hotwords_related_dict[rec_name]['fn'], - sum([v for k, v in hotwords_related_dict[rec_name].items()]), - hotwords_related_dict[rec_name]['tp'] / ( - hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] - ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0 - )) - - # tp: 热词在label里,同时在rec里 - # tn: 热词不在label里,同时不在rec里 - # fp: 热词不在label里,但是在rec里 - # fn: 热词在label里,但是不在rec里 - if not verbose: - print() - print() - - -if __name__ == "__main__": - args = get_args() - - # print("") - print(args) - main(args) - diff --git a/examples/industrial_data_pretraining/lcbnet/demo.py b/examples/industrial_data_pretraining/lcbnet/demo.py deleted file mode 100755 index 4ca52553f..000000000 --- a/examples/industrial_data_pretraining/lcbnet/demo.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python3 -# -*- encoding: utf-8 -*- -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
-# MIT License (https://opensource.org/licenses/MIT) - -from funasr import AutoModel - -model = AutoModel(model="iic/LCB-NET", - model_revision="v1.0.0") - -res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text")) - -print(res) \ No newline at end of file diff --git a/examples/industrial_data_pretraining/lcbnet/demo.sh b/examples/industrial_data_pretraining/lcbnet/demo.sh deleted file mode 100755 index 2f226bc03..000000000 --- a/examples/industrial_data_pretraining/lcbnet/demo.sh +++ /dev/null @@ -1,72 +0,0 @@ -file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/" -CUDA_VISIBLE_DEVICES="0,1" -inference_device="cuda" - -if [ ${inference_device} == "cuda" ]; then - nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -else - inference_batch_size=1 - CUDA_VISIBLE_DEVICES="" - for JOB in $(seq ${nj}); do - CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1," - done -fi - -inference_dir="outputs/slidespeech_dev" -_logdir="${inference_dir}/logdir" -echo "inference_dir: ${inference_dir}" - -mkdir -p "${_logdir}" -key_file1=${file_dir}/dev/wav.scp -key_file2=${file_dir}/dev/ocr.txt -split_scps1= -split_scps2= -for JOB in $(seq "${nj}"); do - split_scps1+=" ${_logdir}/wav.${JOB}.scp" - split_scps2+=" ${_logdir}/ocr.${JOB}.txt" -done -utils/split_scp.pl "${key_file1}" ${split_scps1} -utils/split_scp.pl "${key_file2}" ${split_scps2} - -gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) -for JOB in $(seq ${nj}); do - { - id=$((JOB-1)) - gpuid=${gpuid_list_array[$id]} - - export CUDA_VISIBLE_DEVICES=${gpuid} - - python -m funasr.bin.inference \ - --config-path=${file_dir} \ - --config-name="config.yaml" \ - ++init_param=${file_dir}/model.pt \ - ++tokenizer_conf.token_list=${file_dir}/tokens.txt \ - ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \ - +data_type='["kaldi_ark", "text"]' \ - ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \ - ++normalize_conf.stats_file=${file_dir}/am.mvn \ - ++output_dir="${inference_dir}/${JOB}" \ - ++device="${inference_device}" \ - ++ncpu=1 \ - ++disable_log=true &> ${_logdir}/log.${JOB}.txt - - }& -done -wait - - -mkdir -p ${inference_dir}/1best_recog - -for JOB in $(seq "${nj}"); do - cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token" -done - -echo "Computing WER ..." 
-sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc -cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref -cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list -python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer -tail -n 3 ${inference_dir}/1best_recog/token.cer - -./run_bwer_recall.sh ${inference_dir}/1best_recog/ -tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5 diff --git a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh deleted file mode 100755 index 7d6b6ff8b..000000000 --- a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh +++ /dev/null @@ -1,11 +0,0 @@ -#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave -#hotword_type=ocr_1ngram_top10_hotwords_list -hot_exp_suf=$1 - - -python compute_wer_details.py --v 1 \ - --ref ${hot_exp_suf}/token.ref \ - --ref_ocr ${hot_exp_suf}/ocr.list \ - --rec_name base \ - --rec_file ${hot_exp_suf}/token.proc \ - > ${hot_exp_suf}/BWER-UWER.results diff --git a/examples/industrial_data_pretraining/lcbnet/utils b/examples/industrial_data_pretraining/lcbnet/utils deleted file mode 120000 index be5e5a322..000000000 --- a/examples/industrial_data_pretraining/lcbnet/utils +++ /dev/null @@ -1 +0,0 @@ -../../aishell/paraformer/utils \ No newline at end of file diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 551dd8bf8..a44c649ae 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -7,10 +7,10 @@ from funasr import AutoModel model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4", - # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - # vad_model_revision="v2.0.4", - # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", - # punc_model_revision="v2.0.4", + vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + vad_model_revision="v2.0.4", + punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", + punc_model_revision="v2.0.4", # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", # spk_model_revision="v2.0.2", ) @@ -43,4 +43,4 @@ import soundfile wav_file = os.path.join(model.model_path, "example/asr_example.wav") speech, sample_rate = soundfile.read(wav_file) res = model.generate(input=[speech], batch_size_s=300, is_final=True) -''' +''' \ No newline at end of file diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index ec3c3f370..921ede809 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -28,7 +28,7 @@ try: from funasr.models.campplus.cluster_backend import ClusterBackend except: print("If you want to use the speaker diarization, please `pip install hdbscan`") -import pdb + def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): """ @@ -46,7 +46,6 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): chars = string.ascii_letters + string.digits if isinstance(data_in, str) and data_in.startswith('http'): # url data_in = download_from_url(data_in) - if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; 
filelist: wav.scp, file.jsonl;text.txt; _, file_extension = os.path.splitext(data_in) file_extension = file_extension.lower() @@ -147,7 +146,7 @@ class AutoModel: kwargs = download_model(**kwargs) set_all_random_seed(kwargs.get("seed", 0)) - + device = kwargs.get("device", "cuda") if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: device = "cpu" @@ -169,6 +168,7 @@ class AutoModel: vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 else: vocab_size = -1 + # build frontend frontend = kwargs.get("frontend", None) kwargs["input_size"] = None @@ -181,6 +181,7 @@ class AutoModel: # build model model_class = tables.model_classes.get(kwargs["model"]) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) + model.to(device) # init_param @@ -223,9 +224,9 @@ class AutoModel: batch_size = kwargs.get("batch_size", 1) # if kwargs.get("device", "cpu") == "cpu": # batch_size = 1 - + key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key) - + speed_stats = {} asr_result_list = [] num_samples = len(data_list) @@ -238,7 +239,6 @@ class AutoModel: data_batch = data_list[beg_idx:end_idx] key_batch = key_list[beg_idx:end_idx] batch = {"data_in": data_batch, "key": key_batch} - if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank batch["data_in"] = data_batch[0] batch["data_lengths"] = input_len diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py index c4bdbd774..8ac1ca853 100644 --- a/funasr/frontends/default.py +++ b/funasr/frontends/default.py @@ -3,6 +3,7 @@ from typing import Optional from typing import Tuple from typing import Union import logging +import humanfriendly import numpy as np import torch import torch.nn as nn @@ -15,10 +16,8 @@ from funasr.frontends.utils.log_mel import LogMel from funasr.frontends.utils.stft import Stft from funasr.frontends.utils.frontend import Frontend from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.register import tables -@tables.register("frontend_classes", "DefaultFrontend") class DefaultFrontend(nn.Module): """Conventional frontend structure for ASR. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN @@ -26,7 +25,7 @@ class DefaultFrontend(nn.Module): def __init__( self, - fs: int = 16000, + fs: Union[int, str] = 16000, n_fft: int = 512, win_length: int = None, hop_length: int = 128, @@ -41,14 +40,14 @@ class DefaultFrontend(nn.Module): frontend_conf: Optional[dict] = None, apply_stft: bool = True, use_channel: int = None, - **kwargs, ): super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) # Deepcopy (In general, dict shouldn't be used as default arg) frontend_conf = copy.deepcopy(frontend_conf) self.hop_length = hop_length - self.fs = fs if apply_stft: self.stft = Stft( @@ -85,12 +84,8 @@ class DefaultFrontend(nn.Module): return self.n_mels def forward( - self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list] + self, input: torch.Tensor, input_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - if isinstance(input_lengths, list): - input_lengths = torch.tensor(input_lengths) - if input.dtype == torch.float64: - input = input.float() # 1. Domain-conversion: e.g. 
Stft: time -> time-freq if self.stft is not None: input_stft, feats_lens = self._compute_stft(input, input_lengths) @@ -150,7 +145,7 @@ class MultiChannelFrontend(nn.Module): def __init__( self, - fs: int = 16000, + fs: Union[int, str] = 16000, n_fft: int = 512, win_length: int = None, hop_length: int = None, @@ -173,6 +168,9 @@ class MultiChannelFrontend(nn.Module): mc: bool = True ): super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + # Deepcopy (In general, dict shouldn't be used as default arg) frontend_conf = copy.deepcopy(frontend_conf) if win_length is None and hop_length is None: diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py index be973c641..1d252c206 100644 --- a/funasr/models/conformer/encoder.py +++ b/funasr/models/conformer/encoder.py @@ -47,7 +47,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad from funasr.models.transformer.utils.subsampling import StreamingConvInput from funasr.register import tables -import pdb + class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model. diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py index 7d6f729a8..49868a8f4 100644 --- a/funasr/models/contextual_paraformer/model.py +++ b/funasr/models/contextual_paraformer/model.py @@ -29,7 +29,7 @@ from funasr.train_utils.device_funcs import force_gatherable from funasr.models.transformer.utils.add_sos_eos import add_sos_eos from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -import pdb + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -63,6 +63,7 @@ class ContextualParaformer(Paraformer): crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0) bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0) + if bias_encoder_type == 'lstm': self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate) self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim) @@ -102,16 +103,17 @@ class ContextualParaformer(Paraformer): text_lengths = text_lengths[:, 0] if len(speech_lengths.size()) > 1: speech_lengths = speech_lengths[:, 0] - + batch_size = speech.shape[0] hotword_pad = kwargs.get("hotword_pad") hotword_lengths = kwargs.get("hotword_lengths") dha_pad = kwargs.get("dha_pad") - + # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + loss_ctc, cer_ctc = None, None stats = dict() @@ -126,11 +128,12 @@ class ContextualParaformer(Paraformer): stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None stats["cer_ctc"] = cer_ctc + # 2b. Attention decoder branch loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss( encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths ) - + # 3. 
CTC-Att loss definition if self.ctc_weight == 0.0: loss = loss_att + loss_pre * self.predictor_weight @@ -168,24 +171,22 @@ class ContextualParaformer(Paraformer): ): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) - if self.predictor_bias == 1: _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_pad_lens = ys_pad_lens + self.predictor_bias - pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) + # -1. bias encoder if self.use_decoder_embedding: hw_embed = self.decoder.embed(hotword_pad) else: hw_embed = self.bias_embed(hotword_pad) - hw_embed, (_, _) = self.bias_encoder(hw_embed) _ind = np.arange(0, hotword_pad.shape[0]).tolist() selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]] contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device) - + # 0. sampler decoder_out_1st = None if self.sampling_ratio > 0.0: @@ -194,7 +195,7 @@ class ContextualParaformer(Paraformer): pre_acoustic_embeds, contextual_info) else: sematic_embeds = pre_acoustic_embeds - + # 1. Forward decoder decoder_outs = self.decoder( encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info @@ -210,7 +211,7 @@ class ContextualParaformer(Paraformer): loss_ideal = None ''' loss_ideal = None - + if decoder_out_1st is None: decoder_out_1st = decoder_out # 2. Compute attention loss @@ -287,11 +288,10 @@ class ContextualParaformer(Paraformer): enforce_sorted=False) _, (h_n, _) = self.bias_encoder(hw_embed) hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1) - + decoder_outs = self.decoder( encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale ) - decoder_out = decoder_outs[0] decoder_out = torch.log_softmax(decoder_out, dim=-1) return decoder_out, ys_pad_lens @@ -305,42 +305,38 @@ class ContextualParaformer(Paraformer): **kwargs, ): # init beamsearch - is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None if self.beam_search is None and (is_use_lm or is_use_ctc): logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - + meta_data = {} # extract fbank feats time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend) time3 = time.perf_counter() meta_data["extract_feat"] = f"{time3 - time2:0.3f}" meta_data[ "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - + speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) # hotword self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) - + # Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - + # predictor predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ @@ -348,7 +344,8 @@ class 
ContextualParaformer(Paraformer): pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] - + + decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, diff --git a/funasr/models/lcbnet/__init__.py b/funasr/models/lcbnet/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py deleted file mode 100644 index 8e8c5943a..000000000 --- a/funasr/models/lcbnet/attention.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 yufan -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Multi-Head Attention Return Weight layer definition.""" - -import math - -import torch -from torch import nn - -class MultiHeadedAttentionReturnWeight(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, n_feat, dropout_rate): - """Construct an MultiHeadedAttentionReturnWeight object.""" - super(MultiHeadedAttentionReturnWeight, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, query, key, value): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, value, scores, mask): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). 
- - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - min_value = torch.finfo(scores.dtype).min - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x), self.attn # (batch, time1, d_model) - - def forward(self, query, key, value, mask): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask) - - diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py deleted file mode 100644 index c65823cb0..000000000 --- a/funasr/models/lcbnet/encoder.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright 2019 Shigeki Karita -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Transformer encoder definition.""" - -from typing import List -from typing import Optional -from typing import Tuple - -import torch -from torch import nn -import logging - -from funasr.models.transformer.attention import MultiHeadedAttention -from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight -from funasr.models.transformer.embedding import PositionalEncoding -from funasr.models.transformer.layer_norm import LayerNorm - -from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward -from funasr.models.transformer.utils.repeat import repeat -from funasr.register import tables - -class EncoderLayer(nn.Module): - """Encoder layer module. - - Args: - size (int): Input dimension. - self_attn (torch.nn.Module): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance - can be used as the argument. - feed_forward (torch.nn.Module): Feed-forward module instance. - `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance - can be used as the argument. - dropout_rate (float): Dropout rate. - normalize_before (bool): Whether to use layer_norm before the first block. - concat_after (bool): Whether to concat attention layer's input and output. - if True, additional linear will be applied. - i.e. x -> x + linear(concat(x, att(x))) - if False, no additional linear will be applied. i.e. x -> x + att(x) - stochastic_depth_rate (float): Proability to skip this layer. - During training, the layer may skip residual computation and return input - as-is with given probability. 
- """ - - def __init__( - self, - size, - self_attn, - feed_forward, - dropout_rate, - normalize_before=True, - concat_after=False, - stochastic_depth_rate=0.0, - ): - """Construct an EncoderLayer object.""" - super(EncoderLayer, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.norm1 = LayerNorm(size) - self.norm2 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.size = size - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear = nn.Linear(size + size, size) - self.stochastic_depth_rate = stochastic_depth_rate - - def forward(self, x, mask, cache=None): - """Compute encoded features. - - Args: - x_input (torch.Tensor): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time). - cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time). - - """ - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. - stoch_layer_coeff = 1.0 - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - if cache is not None: - x = torch.cat([cache, x], dim=1) - return x, mask - - residual = x - if self.normalize_before: - x = self.norm1(x) - - if cache is None: - x_q = x - else: - assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) - x_q = x[:, -1:, :] - residual = residual[:, -1:, :] - mask = None if mask is None else mask[:, -1:, :] - - if self.concat_after: - x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) - x = residual + stoch_layer_coeff * self.concat_linear(x_concat) - else: - x = residual + stoch_layer_coeff * self.dropout( - self.self_attn(x_q, x, x, mask) - ) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - if cache is not None: - x = torch.cat([cache, x], dim=1) - - return x, mask - -@tables.register("encoder_classes", "TransformerTextEncoder") -class TransformerTextEncoder(nn.Module): - """Transformer text encoder module. - - Args: - input_size: input dim - output_size: dimension of attention - attention_heads: the number of heads of multi head attention - linear_units: the number of units of position-wise feed forward - num_blocks: the number of decoder blocks - dropout_rate: dropout rate - attention_dropout_rate: dropout rate in attention - positional_dropout_rate: dropout rate after adding positional encoding - input_layer: input layer type - pos_enc_class: PositionalEncoding or ScaledPositionalEncoding - normalize_before: whether to use layer_norm before the first block - concat_after: whether to concat attention layer's input and output - if True, additional linear will be applied. - i.e. x -> x + linear(concat(x, att(x))) - if False, no additional linear will be applied. - i.e. 
x -> x + att(x) - positionwise_layer_type: linear of conv1d - positionwise_conv_kernel_size: kernel size of positionwise conv1d layer - padding_idx: padding_idx for input_layer=embed - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - pos_enc_class=PositionalEncoding, - normalize_before: bool = True, - concat_after: bool = False, - ): - super().__init__() - self._output_size = output_size - - self.embed = torch.nn.Sequential( - torch.nn.Embedding(input_size, output_size), - pos_enc_class(output_size, positional_dropout_rate), - ) - - self.normalize_before = normalize_before - - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - self.encoders = repeat( - num_blocks, - lambda lnum: EncoderLayer( - output_size, - MultiHeadedAttention( - attention_heads, output_size, attention_dropout_rate - ), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - xs_pad = self.embed(xs_pad) - - xs_pad, masks = self.encoders(xs_pad, masks) - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - return xs_pad, olens, None - - - - -@tables.register("encoder_classes", "FusionSANEncoder") -class SelfSrcAttention(nn.Module): - """Single decoder layer module. - - Args: - size (int): Input dimension. - self_attn (torch.nn.Module): Self-attention module instance. - `MultiHeadedAttention` instance can be used as the argument. - src_attn (torch.nn.Module): Self-attention module instance. - `MultiHeadedAttention` instance can be used as the argument. - feed_forward (torch.nn.Module): Feed-forward module instance. - `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance - can be used as the argument. - dropout_rate (float): Dropout rate. - normalize_before (bool): Whether to use layer_norm before the first block. - concat_after (bool): Whether to concat attention layer's input and output. - if True, additional linear will be applied. - i.e. x -> x + linear(concat(x, att(x))) - if False, no additional linear will be applied. i.e. 
x -> x + att(x) - - - """ - def __init__( - self, - size, - attention_heads, - attention_dim, - linear_units, - self_attention_dropout_rate, - src_attention_dropout_rate, - positional_dropout_rate, - dropout_rate, - normalize_before=True, - concat_after=False, - ): - """Construct an SelfSrcAttention object.""" - super(SelfSrcAttention, self).__init__() - self.size = size - self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate) - self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate) - self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate) - self.norm1 = LayerNorm(size) - self.norm2 = LayerNorm(size) - self.norm3 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear1 = nn.Linear(size + size, size) - self.concat_linear2 = nn.Linear(size + size, size) - - def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): - """Compute decoded features. - - Args: - tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). - tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). - memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). - memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). - cache (List[torch.Tensor]): List of cached tensors. - Each tensor shape should be (#batch, maxlen_out - 1, size). - - Returns: - torch.Tensor: Output tensor(#batch, maxlen_out, size). - torch.Tensor: Mask for output tensor (#batch, maxlen_out). - torch.Tensor: Encoded memory (#batch, maxlen_in, size). - torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
- - """ - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - - if cache is None: - tgt_q = tgt - tgt_q_mask = tgt_mask - else: - # compute only the last frame query keeping dim: max_time_out -> 1 - assert cache.shape == ( - tgt.shape[0], - tgt.shape[1] - 1, - self.size, - ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" - tgt_q = tgt[:, -1:, :] - residual = residual[:, -1:, :] - tgt_q_mask = None - if tgt_mask is not None: - tgt_q_mask = tgt_mask[:, -1:, :] - - if self.concat_after: - tgt_concat = torch.cat( - (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 - ) - x = residual + self.concat_linear1(tgt_concat) - else: - x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - if self.concat_after: - x_concat = torch.cat( - (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 - ) - x = residual + self.concat_linear2(x_concat) - else: - x, score = self.src_attn(x, memory, memory, memory_mask) - x = residual + self.dropout(x) - if not self.normalize_before: - x = self.norm2(x) - - residual = x - if self.normalize_before: - x = self.norm3(x) - x = residual + self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm3(x) - - if cache is not None: - x = torch.cat([cache, x], dim=1) - - return x, tgt_mask, memory, memory_mask - - -@tables.register("encoder_classes", "ConvBiasPredictor") -class ConvPredictor(nn.Module): - def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048): - super().__init__() - self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate) - self.norm1 = LayerNorm(size) - self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate) - self.norm2 = LayerNorm(size) - self.pad = nn.ConstantPad1d((l_order, r_order), 0) - self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size) - self.output_linear = nn.Linear(size, 1) - - - def forward(self, text_enc, asr_enc): - # stage1 cross-attention - residual = text_enc - text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None) - - # stage2 FFN - residual = text_enc - text_enc = self.norm1(text_enc) - text_enc = residual + self.feed_forward(text_enc) - - # stage Conv predictor - text_enc = self.norm2(text_enc) - context = text_enc.transpose(1, 2) - queries = self.pad(context) - memory = self.conv1d(queries) - output = memory + context - output = output.transpose(1, 2) - output = torch.relu(output) - output = self.output_linear(output) - if output.dim()==3: - output = output.squeeze(2) - return output diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py deleted file mode 100644 index 3ac319c61..000000000 --- a/funasr/models/lcbnet/model.py +++ /dev/null @@ -1,495 +0,0 @@ -#!/usr/bin/env python3 -# -*- encoding: utf-8 -*- -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
-# MIT License (https://opensource.org/licenses/MIT) - -import logging -from typing import Union, Dict, List, Tuple, Optional - -import time -import torch -import torch.nn as nn -from torch.cuda.amp import autocast - -from funasr.losses.label_smoothing_loss import LabelSmoothingLoss -from funasr.models.ctc.ctc import CTC -from funasr.models.transformer.utils.add_sos_eos import add_sos_eos -from funasr.metrics.compute_acc import th_accuracy -# from funasr.models.e2e_asr_common import ErrorCalculator -from funasr.train_utils.device_funcs import force_gatherable -from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -from funasr.utils import postprocess_utils -from funasr.utils.datadir_writer import DatadirWriter -from funasr.register import tables - -import pdb -@tables.register("model_classes", "LCBNet") -class LCBNet(nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION - https://arxiv.org/abs/2401.06390 - """ - - def __init__( - self, - specaug: str = None, - specaug_conf: dict = None, - normalize: str = None, - normalize_conf: dict = None, - encoder: str = None, - encoder_conf: dict = None, - decoder: str = None, - decoder_conf: dict = None, - text_encoder: str = None, - text_encoder_conf: dict = None, - bias_predictor: str = None, - bias_predictor_conf: dict = None, - fusion_encoder: str = None, - fusion_encoder_conf: dict = None, - ctc: str = None, - ctc_conf: dict = None, - ctc_weight: float = 0.5, - interctc_weight: float = 0.0, - select_num: int = 2, - select_length: int = 3, - insert_blank: bool = True, - input_size: int = 80, - vocab_size: int = -1, - ignore_id: int = -1, - blank_id: int = 0, - sos: int = 1, - eos: int = 2, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: bool = True, - sym_space: str = "", - sym_blank: str = "", - # extract_feats_in_collect_stats: bool = True, - share_embedding: bool = False, - # preencoder: Optional[AbsPreEncoder] = None, - # postencoder: Optional[AbsPostEncoder] = None, - **kwargs, - ): - - super().__init__() - - if specaug is not None: - specaug_class = tables.specaug_classes.get(specaug) - specaug = specaug_class(**specaug_conf) - if normalize is not None: - normalize_class = tables.normalize_classes.get(normalize) - normalize = normalize_class(**normalize_conf) - encoder_class = tables.encoder_classes.get(encoder) - encoder = encoder_class(input_size=input_size, **encoder_conf) - encoder_output_size = encoder.output_size() - - # lcbnet modules: text encoder, fusion encoder and bias predictor - text_encoder_class = tables.encoder_classes.get(text_encoder) - text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf) - fusion_encoder_class = tables.encoder_classes.get(fusion_encoder) - fusion_encoder = fusion_encoder_class(**fusion_encoder_conf) - bias_predictor_class = tables.encoder_classes.get(bias_predictor) - bias_predictor = bias_predictor_class(**bias_predictor_conf) - - - if decoder is not None: - decoder_class = tables.decoder_classes.get(decoder) - decoder = decoder_class( - vocab_size=vocab_size, - encoder_output_size=encoder_output_size, - **decoder_conf, - ) - if ctc_weight > 0.0: - - if ctc_conf is None: - ctc_conf = {} - - ctc = CTC( - odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf - ) - - self.blank_id = blank_id - self.sos = vocab_size - 1 - self.eos = vocab_size - 1 - self.vocab_size = vocab_size - self.ignore_id = 
ignore_id - self.ctc_weight = ctc_weight - self.specaug = specaug - self.normalize = normalize - self.encoder = encoder - # lcbnet - self.text_encoder = text_encoder - self.fusion_encoder = fusion_encoder - self.bias_predictor = bias_predictor - self.select_num = select_num - self.select_length = select_length - self.insert_blank = insert_blank - - if not hasattr(self.encoder, "interctc_use_conditioning"): - self.encoder.interctc_use_conditioning = False - if self.encoder.interctc_use_conditioning: - self.encoder.conditioning_layer = torch.nn.Linear( - vocab_size, self.encoder.output_size() - ) - self.interctc_weight = interctc_weight - - # self.error_calculator = None - if ctc_weight == 1.0: - self.decoder = None - else: - self.decoder = decoder - - self.criterion_att = LabelSmoothingLoss( - size=vocab_size, - padding_idx=ignore_id, - smoothing=lsm_weight, - normalize_length=length_normalized_loss, - ) - # - # if report_cer or report_wer: - # self.error_calculator = ErrorCalculator( - # token_list, sym_space, sym_blank, report_cer, report_wer - # ) - # - self.error_calculator = None - if ctc_weight == 0.0: - self.ctc = None - else: - self.ctc = ctc - - self.share_embedding = share_embedding - if self.share_embedding: - self.decoder.embed = None - - self.length_normalized_loss = length_normalized_loss - self.beam_search = None - - def forward( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - **kwargs, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Encoder + Decoder + Calc loss - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - text: (Batch, Length) - text_lengths: (Batch,) - """ - - if len(text_lengths.size()) > 1: - text_lengths = text_lengths[:, 0] - if len(speech_lengths.size()) > 1: - speech_lengths = speech_lengths[:, 0] - - batch_size = speech.shape[0] - - # 1. Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - - loss_att, acc_att, cer_att, wer_att = None, None, None, None - loss_ctc, cer_ctc = None, None - stats = dict() - - # decoder: CTC branch - if self.ctc_weight != 0.0: - loss_ctc, cer_ctc = self._calc_ctc_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # Collect CTC branch stats - stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc"] = cer_ctc - - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - loss_ic, cer_ic = self._calc_ctc_loss( - intermediate_out, encoder_out_lens, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic - - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic - - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight - ) * loss_ctc + self.interctc_weight * loss_interctc - - # decoder: Attention decoder branch - loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. 
CTC-Att loss definition - if self.ctc_weight == 0.0: - loss = loss_att - elif self.ctc_weight == 1.0: - loss = loss_ctc - else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att - - # Collect Attn branch stats - stats["loss_att"] = loss_att.detach() if loss_att is not None else None - stats["acc"] = acc_att - stats["cer"] = cer_att - stats["wer"] = wer_att - - # Collect total loss stats - stats["loss"] = torch.clone(loss.detach()) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - if self.length_normalized_loss: - batch_size = int((text_lengths + 1).sum()) - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - return loss, stats, weight - - - def encode( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Frontend + Encoder. Note that this method is used by asr_inference.py - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - ind: int - """ - with autocast(False): - # Data augmentation - if self.specaug is not None and self.training: - speech, speech_lengths = self.specaug(speech, speech_lengths) - # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - speech, speech_lengths = self.normalize(speech, speech_lengths) - # Forward encoder - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder( - speech, speech_lengths, ctc=self.ctc - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens - return encoder_out, encoder_out_lens - - def _calc_att_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) - ys_in_lens = ys_pad_lens + 1 - - # 1. Forward decoder - decoder_out, _ = self.decoder( - encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens - ) - - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_out_pad) - acc_att = th_accuracy( - decoder_out.view(-1, self.vocab_size), - ys_out_pad, - ignore_label=self.ignore_id, - ) - - # Compute cer/wer using attention-decoder - if self.training or self.error_calculator is None: - cer_att, wer_att = None, None - else: - ys_hat = decoder_out.argmax(dim=-1) - cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) - - return loss_att, acc_att, cer_att, wer_att - - def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - # Calc CTC loss - loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) - - # Calc CER using CTC - cer_ctc = None - if not self.training and self.error_calculator is not None: - ys_hat = self.ctc.argmax(encoder_out).data - cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) - return loss_ctc, cer_ctc - - def init_beam_search(self, - **kwargs, - ): - from funasr.models.transformer.search import BeamSearch - from funasr.models.transformer.scorers.ctc import CTCPrefixScorer - from funasr.models.transformer.scorers.length_bonus import LengthBonus - - # 1. 
Build ASR model - scorers = {} - - if self.ctc != None: - ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) - scorers.update( - ctc=ctc - ) - token_list = kwargs.get("token_list") - scorers.update( - decoder=self.decoder, - length_bonus=LengthBonus(len(token_list)), - ) - - - # 3. Build ngram model - # ngram is not supported now - ngram = None - scorers["ngram"] = ngram - - weights = dict( - decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3), - ctc=kwargs.get("decoding_ctc_weight", 0.3), - lm=kwargs.get("lm_weight", 0.0), - ngram=kwargs.get("ngram_weight", 0.0), - length_bonus=kwargs.get("penalty", 0.0), - ) - beam_search = BeamSearch( - beam_size=kwargs.get("beam_size", 20), - weights=weights, - scorers=scorers, - sos=self.sos, - eos=self.eos, - vocab_size=len(token_list), - token_list=token_list, - pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", - ) - - self.beam_search = beam_search - - def inference(self, - data_in, - data_lengths=None, - key: list=None, - tokenizer=None, - frontend=None, - **kwargs, - ): - - if kwargs.get("batch_size", 1) > 1: - raise NotImplementedError("batch decoding is not implemented") - - # init beamsearch - if self.beam_search is None: - logging.info("enable beam_search") - self.init_beam_search(**kwargs) - self.nbest = kwargs.get("nbest", 1) - - meta_data = {} - if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank - speech, speech_lengths = data_in, data_lengths - if len(speech.shape) < 3: - speech = speech[None, :, :] - if speech_lengths is None: - speech_lengths = speech.shape[1] - else: - # extract fbank feats - time1 = time.perf_counter() - sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), - data_type=kwargs.get("data_type", "sound"), - tokenizer=tokenizer) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - audio_sample_list = sample_list[0] - if len(sample_list) >1: - ocr_sample_list = sample_list[1] - else: - ocr_sample_list = [[294, 0]] - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), - frontend=frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - frame_shift = 10 - meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000 - - speech = speech.to(device=kwargs["device"]) - speech_lengths = speech_lengths.to(device=kwargs["device"]) - # Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - if isinstance(encoder_out, tuple): - encoder_out = encoder_out[0] - - ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list] - ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"]) - ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"]) - ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths) - fusion_out, _, _, _ = self.fusion_encoder(encoder_out,None, ocr, None) - encoder_out = encoder_out + fusion_out - # c. 
Passed the encoder result and the beam search - nbest_hyps = self.beam_search( - x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) - ) - - nbest_hyps = nbest_hyps[: self.nbest] - - results = [] - b, n, d = encoder_out.size() - for i in range(b): - - for nbest_idx, hyp in enumerate(nbest_hyps): - ibest_writer = None - if kwargs.get("output_dir") is not None: - if not hasattr(self, "writer"): - self.writer = DatadirWriter(kwargs.get("output_dir")) - ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) - - # Change integer-ids to tokens - token = tokenizer.ids2tokens(token_int) - text = tokenizer.tokens2text(token) - - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - result_i = {"key": key[i], "token": token, "text": text_postprocessed} - results.append(result_i) - - if ibest_writer is not None: - ibest_writer["token"][key[i]] = " ".join(token) - ibest_writer["text"][key[i]] = text_postprocessed - - return results, meta_data - diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py index a8b1f1fb1..20b0cc838 100644 --- a/funasr/models/seaco_paraformer/model.py +++ b/funasr/models/seaco_paraformer/model.py @@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -import pdb + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast else: @@ -128,7 +128,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): hotword_pad = kwargs.get("hotword_pad") hotword_lengths = kwargs.get("hotword_lengths") dha_pad = kwargs.get("dha_pad") - + batch_size = speech.shape[0] # for data-parallel text = text[:, : text_lengths.max()] @@ -209,20 +209,17 @@ class SeacoParaformer(BiCifParaformer, Paraformer): nfilter=50, seaco_weight=1.0): # decoder forward - decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) - decoder_pred = torch.log_softmax(decoder_out, dim=-1) if hw_list is not None: hw_lengths = [len(i) for i in hw_list] hw_list_ = [torch.Tensor(i).long() for i in hw_list] hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device) selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device)) - contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) num_hot_word = contextual_info.shape[1] _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) - + # ASF Core if nfilter > 0 and nfilter < num_hot_word: hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) @@ -242,7 +239,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) merged = 
self._merge(cif_attended, dec_attended) - + dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation dha_pred = torch.log_softmax(dha_output, dim=-1) def _merge_res(dec_output, dha_output): @@ -256,8 +253,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer): # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask) logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) return logits - merged_pred = _merge_res(decoder_pred, dha_pred) + # import pdb; pdb.set_trace() return merged_pred else: return decoder_pred @@ -307,6 +304,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer): logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) + meta_data = {} # extract fbank feats @@ -332,7 +330,6 @@ class SeacoParaformer(BiCifParaformer, Paraformer): if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - # predictor predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ @@ -341,14 +338,15 @@ class SeacoParaformer(BiCifParaformer, Paraformer): if torch.max(pre_token_length) < 1: return [] + decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) - # decoder_out, _ = decoder_outs[0], decoder_outs[1] _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, pre_token_length) + results = [] b, n, d = decoder_out.size() for i in range(b): diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index 0c46449e4..ea2372537 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -7,7 +7,7 @@ import logging import torch import torch.nn import torch.optim -import pdb + def filter_state_dict( dst_state: Dict[str, Union[float, torch.Tensor]], @@ -63,7 +63,6 @@ def load_pretrained_model( dst_state = obj.state_dict() print(f"ckpt: {path}") - if oss_bucket is None: src_state = torch.load(path, map_location=map_location) else: diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py index 84c38f9b9..7748172f6 100644 --- a/funasr/utils/load_utils.py +++ b/funasr/utils/load_utils.py @@ -13,25 +13,29 @@ try: from funasr.download.file import download_from_url except: print("urllib is not installed, if you infer from url, please install it first.") -import pdb + def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs): if isinstance(data_or_path_or_list, (list, tuple)): if data_type is not None and isinstance(data_type, (list, tuple)): + data_types = [data_type] * len(data_or_path_or_list) data_or_path_or_list_ret = [[] for d in data_type] for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)): + for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)): + data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs) data_or_path_or_list_ret[j].append(data_or_path_or_list_j) return data_or_path_or_list_ret else: return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list] + if isinstance(data_or_path_or_list, str) and 
data_or_path_or_list.startswith('http'): # download url to local file data_or_path_or_list = download_from_url(data_or_path_or_list) - + if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file if data_type is None or data_type == "sound": data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list) @@ -52,22 +56,10 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: data_or_path_or_list = tokenizer.encode(data_or_path_or_list) elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,] - elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark": - data_mat = kaldiio.load_mat(data_or_path_or_list) - if isinstance(data_mat, tuple): - audio_fs, mat = data_mat - else: - mat = data_mat - if mat.dtype == 'int16' or mat.dtype == 'int32': - mat = mat.astype(np.float64) - mat = mat / 32768 - if mat.ndim ==2: - mat = mat[:,0] - data_or_path_or_list = mat else: pass # print(f"unsupport data type: {data_or_path_or_list}, return raw data") - + if audio_fs != fs and data_type != "text": resampler = torchaudio.transforms.Resample(audio_fs, fs) data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :] @@ -89,6 +81,8 @@ def load_bytes(input): return array def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs): + # import pdb; + # pdb.set_trace() if isinstance(data, np.ndarray): data = torch.from_numpy(data) if len(data.shape) < 2: @@ -106,7 +100,9 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, data_list.append(data_i) data_len.append(data_i.shape[0]) data = pad_sequence(data_list, batch_first=True) # data: [batch, N] - + # import pdb; + # pdb.set_trace() + # if data_type == "sound": data, data_len = frontend(data, data_len, **kwargs) if isinstance(data_len, (list, tuple)):
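The load_utils.py hunks above keep the common input path: URLs are downloaded, local audio is read with torchaudio, anything that is not text is resampled to the frontend rate, and variable-length waveforms are zero-padded into a [batch, N] tensor before the fbank frontend is applied. Below is a minimal sketch of that resample-and-pad step, using synthetic waveforms in place of torchaudio.load() output; the helper name and sample rates are illustrative, and the frontend call itself is omitted.

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

TARGET_FS = 16000  # frontend sampling rate assumed here for the ASR models


def prepare_batch(waveforms, audio_fs):
    """Resample 1-D waveforms to TARGET_FS and pad them into a [batch, N] tensor."""
    if audio_fs != TARGET_FS:
        resampler = torchaudio.transforms.Resample(audio_fs, TARGET_FS)
        waveforms = [resampler(w[None, :])[0, :] for w in waveforms]
    data_len = [w.shape[0] for w in waveforms]
    batch = pad_sequence(waveforms, batch_first=True)  # zero-pad to the longest waveform
    return batch, torch.tensor(data_len, dtype=torch.int32)


if __name__ == "__main__":
    wavs = [torch.randn(8000), torch.randn(12000)]  # 1.0 s and 1.5 s at 8 kHz
    batch, lens = prepare_batch(wavs, audio_fs=8000)
    print(batch.shape, lens.tolist())  # torch.Size([2, 24000]) [16000, 24000]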