Revert "Dev yf" (#1418)

This commit is contained in:
zhifu gao 2024-03-04 17:50:29 +08:00 committed by GitHub
parent 920331972a
commit d2c1204d91
21 changed files with 61 additions and 1886 deletions

View File

View File

@ -2,7 +2,7 @@
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model_revision="v2.0.4"
python ../../../funasr/bin/inference.py \
python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \

View File

@ -1,9 +0,0 @@
python -m funasr.bin.inference \
--config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
--config-name="config.yaml" \
++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
++output_dir="./outputs/debug2" \
++device="" \

View File

@ -1,6 +0,0 @@
export FUNASR_DIR=$PWD/../../../
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH

View File

@ -1,702 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from enum import Enum
import re, sys, unicodedata
import codecs
import argparse
from tqdm import tqdm
import os
import pdb
remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
"!",
",",
"?",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
]
class Code(Enum):
match = 1
substitution = 2
insertion = 3
deletion = 4
class WordError(object):
def __init__(self):
self.errors = {
Code.substitution: 0,
Code.insertion: 0,
Code.deletion: 0,
}
self.ref_words = 0
def get_wer(self):
assert self.ref_words != 0
errors = (
self.errors[Code.substitution]
+ self.errors[Code.insertion]
+ self.errors[Code.deletion]
)
return 100.0 * errors / self.ref_words
def get_result_string(self):
return (
f"error_rate={self.get_wer():.4f}, "
f"ref_words={self.ref_words}, "
f"subs={self.errors[Code.substitution]}, "
f"ins={self.errors[Code.insertion]}, "
f"dels={self.errors[Code.deletion]}"
)
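WordError simply accumulates substitution/insertion/deletion counts against a reference word count; a toy sanity check (numbers are made up, not from any real run):

we = WordError()
we.ref_words = 10
we.errors[Code.substitution] = 2
we.errors[Code.deletion] = 1
we.errors[Code.insertion] = 1
# 4 errors over 10 reference words -> 40% error rate
print(we.get_result_string())  # error_rate=40.0000, ref_words=10, subs=2, ins=1, dels=1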
def characterize(string):
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
# https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == "Lo": # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = " "
if char == "<":
sep = ">"
j = i + 1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c == sep):
break
j += 1
if j < len(string) and string[j] == ">":
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x:
return ""
chars = []
i = 0
T = len(x)
while i < T:
if x[i] == "<":
while i < T and x[i] != ">":
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return "".join(chars)
def normalize(sentence, ignore_words, cs, split=None):
"""sentence, ignore_words are both in unicode"""
new_sentence = []
for token in sentence:
x = token
if not cs:
x = x.upper()
if x in ignore_words:
continue
if remove_tag:
x = stripoff_tags(x)
if not x:
continue
if split and x in split:
new_sentence += split[x]
else:
new_sentence.append(x)
return new_sentence
class Calculator:
def __init__(self):
self.data = {}
self.space = []
self.cost = {}
self.cost["cor"] = 0
self.cost["sub"] = 1
self.cost["del"] = 1
self.cost["ins"] = 1
def calculate(self, lab, rec):
# Initialization
lab.insert(0, "")
rec.insert(0, "")
while len(self.space) < len(lab):
self.space.append([])
for row in self.space:
for element in row:
element["dist"] = 0
element["error"] = "non"
while len(row) < len(rec):
row.append({"dist": 0, "error": "non"})
for i in range(len(lab)):
self.space[i][0]["dist"] = i
self.space[i][0]["error"] = "del"
for j in range(len(rec)):
self.space[0][j]["dist"] = j
self.space[0][j]["error"] = "ins"
self.space[0][0]["error"] = "non"
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
continue
min_dist = sys.maxsize
min_error = "none"
dist = self.space[i - 1][j]["dist"] + self.cost["del"]
error = "del"
if dist < min_dist:
min_dist = dist
min_error = error
dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
error = "ins"
if dist < min_dist:
min_dist = dist
min_error = error
if lab_token == rec_token.replace("<BIAS>", ""):
dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
error = "cor"
else:
dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
error = "sub"
if dist < min_dist:
min_dist = dist
min_error = error
self.space[i][j]["dist"] = min_dist
self.space[i][j]["error"] = min_error
# Tracing back
result = {
"lab": [],
"rec": [],
"code": [],
"all": 0,
"cor": 0,
"sub": 0,
"ins": 0,
"del": 0,
}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]["error"] == "cor": # correct
if len(lab[i]) > 0:
self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
result["all"] = result["all"] + 1
result["cor"] = result["cor"] + 1
result["lab"].insert(0, lab[i])
result["rec"].insert(0, rec[j])
result["code"].insert(0, Code.match)
i = i - 1
j = j - 1
elif self.space[i][j]["error"] == "sub": # substitution
if len(lab[i]) > 0:
self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
result["all"] = result["all"] + 1
result["sub"] = result["sub"] + 1
result["lab"].insert(0, lab[i])
result["rec"].insert(0, rec[j])
result["code"].insert(0, Code.substitution)
i = i - 1
j = j - 1
elif self.space[i][j]["error"] == "del": # deletion
if len(lab[i]) > 0:
self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
result["all"] = result["all"] + 1
result["del"] = result["del"] + 1
result["lab"].insert(0, lab[i])
result["rec"].insert(0, "")
result["code"].insert(0, Code.deletion)
i = i - 1
elif self.space[i][j]["error"] == "ins": # insertion
if len(rec[j]) > 0:
self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
result["ins"] = result["ins"] + 1
result["lab"].insert(0, "")
result["rec"].insert(0, rec[j])
result["code"].insert(0, Code.insertion)
j = j - 1
elif self.space[i][j]["error"] == "non": # starting point
break
else: # shouldn't reach here
print(
"this should not happen , i = {i} , j = {j} , error = {error}".format(
i=i, j=j, error=self.space[i][j]["error"]
)
)
return result
def overall(self):
result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
for token in self.data:
result["all"] = result["all"] + self.data[token]["all"]
result["cor"] = result["cor"] + self.data[token]["cor"]
result["sub"] = result["sub"] + self.data[token]["sub"]
result["ins"] = result["ins"] + self.data[token]["ins"]
result["del"] = result["del"] + self.data[token]["del"]
return result
def cluster(self, data):
result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
for token in data:
if token in self.data:
result["all"] = result["all"] + self.data[token]["all"]
result["cor"] = result["cor"] + self.data[token]["cor"]
result["sub"] = result["sub"] + self.data[token]["sub"]
result["ins"] = result["ins"] + self.data[token]["ins"]
result["del"] = result["del"] + self.data[token]["del"]
return result
def keys(self):
return list(self.data.keys())
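A minimal sketch of how Calculator is used further down in main(): calculate() aligns a label/recognition token pair with a unit-cost edit distance and returns per-utterance counts (the token lists below are invented for illustration; calculate() mutates its arguments, hence the copies):

calc = Calculator()
lab = ["今", "天", "天", "气"]
rec = ["今", "天", "气"]
result = calc.calculate(lab.copy(), rec.copy())
# one reference token was dropped by the recognizer -> 3 correct, 1 deletion
print(result["all"], result["cor"], result["sub"], result["ins"], result["del"])  # 4 3 0 0 1
wer = 100.0 * (result["ins"] + result["sub"] + result["del"]) / result["all"]  # 25.0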
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word):
unicode_names = [unicodedata.name(char) for char in word]
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith("DIGIT"): # 1
unicode_names[i] = "Number" # 'DIGIT'
elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
i
].startswith("CJK COMPATIBILITY IDEOGRAPH"):
# 明 / 郎
unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
i
].startswith("LATIN SMALL LETTER"):
# A / a
unicode_names[i] = "English" # 'LATIN LETTER'
elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め
unicode_names[i] = "Japanese" # 'GANA LETTER'
elif (
unicode_names[i].startswith("AMPERSAND")
or unicode_names[i].startswith("APOSTROPHE")
or unicode_names[i].startswith("COMMERCIAL AT")
or unicode_names[i].startswith("DEGREE CELSIUS")
or unicode_names[i].startswith("EQUALS SIGN")
or unicode_names[i].startswith("FULL STOP")
or unicode_names[i].startswith("HYPHEN-MINUS")
or unicode_names[i].startswith("LOW LINE")
or unicode_names[i].startswith("NUMBER SIGN")
or unicode_names[i].startswith("PLUS SIGN")
or unicode_names[i].startswith("SEMICOLON")
):
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else:
return "Other"
if len(unicode_names) == 0:
return "Other"
if len(unicode_names) == 1:
return unicode_names[0]
for i in range(len(unicode_names) - 1):
if unicode_names[i] != unicode_names[i + 1]:
return "Other"
return unicode_names[0]
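default_cluster() maps a token to a coarse script class by inspecting the Unicode names of its characters; a few illustrative calls (expected return values shown as comments):

print(default_cluster("天"))      # "Mandarin"
print(default_cluster("hello"))   # "English"
print(default_cluster("7"))       # "Number"
print(default_cluster("天a"))     # "Other" (mixed scripts fall back to Other)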
def get_args():
parser = argparse.ArgumentParser(description="compute WER together with hotword-biasing details")
parser.add_argument("--ref", type=str, help="reference transcript path")
parser.add_argument("--ref_ocr", type=str, help="per-utterance OCR text path (used as the hotword list)")
parser.add_argument("--rec_name", type=str, action="append", default=[], help="name of a recognition system")
parser.add_argument("--rec_file", type=str, action="append", default=[], help="recognition result path for that system")
parser.add_argument("--verbose", type=int, default=1, help="verbosity of the per-utterance output")
parser.add_argument("--char", type=bool, default=True, help="score at character level")
args = parser.parse_args()
return args
def main(args):
cluster_file = ""
ignore_words = set()
tochar = args.char
verbose = args.verbose
padding_symbol = " "
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
if not case_sensitive:
ig = set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
ref_file = args.ref
ref_ocr = args.ref_ocr
rec_files = args.rec_file
rec_names = args.rec_name
assert len(rec_files) == len(rec_names)
# load ocr
ref_ocr_dict = {}
with codecs.open(ref_ocr, "r", "utf-8") as fh:
for line in fh:
if "$" in line:
line = line.replace("$", " ")
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0:
continue
fid = array[0]
ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
rec_sets = {}
calculators_dict = dict()
ub_wer_dict = dict()
hotwords_related_dict = dict()  # records recall-related statistics
for i, hyp_file in enumerate(rec_files):
rec_sets[rec_names[i]] = dict()
with codecs.open(hyp_file, "r", "utf-8") as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0:
continue
fid = array[0]
rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
calculators_dict[rec_names[i]] = Calculator()
ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
# tp: the hotword is in the label and also in the rec
# tn: the hotword is neither in the label nor in the rec
# fp: the hotword is not in the label but appears in the rec
# fn: the hotword is in the label but missing from the rec
# record wrong label but in ocr
wrong_rec_but_in_ocr_dict = {}
for rec_name in rec_names:
wrong_rec_but_in_ocr_dict[rec_name] = 0
_file_total_len = 0
with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
_file_total_len = int(pipe.read().strip())
# compute error rate on the interaction of reference file and hyp file
for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array) == 0: continue
fid = array[0]
lab = normalize(array[1:], ignore_words, case_sensitive, split)
if verbose:
print('\nutt: %s' % fid)
ocr_text = ref_ocr_dict[fid]
ocr_set = set(ocr_text)
print('ocr: {}'.format(" ".join(ocr_text)))
list_match = []  # label tokens that also appear in the ocr text
list_not_match = []
tmp_error = 0
tmp_match = 0
for index in range(len(lab)):
# text_list.append(uttlist[index+1])
if lab[index] not in ocr_set:
tmp_error += 1
list_not_match.append(lab[index])
else:
tmp_match += 1
list_match.append(lab[index])
print('label in ocr: {}'.format(" ".join(list_match)))
# for each reco file
base_wrong_ocr_wer = None
ocr_wrong_ocr_wer = None
for rec_name in rec_names:
rec_set = rec_sets[rec_name]
if fid not in rec_set:
continue
rec = rec_set[fid]
# print(rec)
for word in rec + lab:
if word not in default_words:
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters:
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name]:
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
# print(result['rec'])
wrong_rec_but_in_ocr = []
for idx in range(len(result['lab'])):
if result['lab'][idx] != "":
if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
if result['lab'][idx] in list_match:
wrong_rec_but_in_ocr.append(result['lab'][idx])
wrong_rec_but_in_ocr_dict[rec_name] += 1
print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
if rec_name == "base":
base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
if "ocr" in rec_name or "hot" in rec_name:
ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
# recall = 0
# false_alarm = 0
# for idx in range(len(result['lab'])):
# if "<BIAS>" in result['rec'][idx]:
# if result['rec'][idx].replace("<BIAS>", "") in list_match:
# recall += 1
# else:
# false_alarm += 1
# print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
# recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
# ))
# tp: the hotword is in the label and also in the rec
# tn: the hotword is neither in the label nor in the rec
# fp: the hotword is not in the label but appears in the rec
# fn: the hotword is in the label but missing from the rec
_rec_list = [word.replace("<BIAS>", "") for word in rec]
_label_list = [word for word in lab]
_tp = _tn = _fp = _fn = 0
hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
for badhotword in hot_bad_list:
count = len([word for word in _rec_list if word == badhotword])
# print(f"bad {badhotword} count: {count}")
# for word in _rec_list:
# if badhotword == word:
# count += 1
if count == 0:
hotwords_related_dict[rec_name]['tn'] += 1
_tn += 1
# fp: 0
else:
hotwords_related_dict[rec_name]['fp'] += count
_fp += count
# tn: 0
# if badhotword in _rec_list:
# hotwords_related_dict[rec_name]['fp'] += 1
# else:
# hotwords_related_dict[rec_name]['tn'] += 1
for hotword in hot_true_list:
true_count = len([word for word in _label_list if hotword == word])
rec_count = len([word for word in _rec_list if hotword == word])
# print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
if rec_count == true_count:
hotwords_related_dict[rec_name]['tp'] += true_count
_tp += true_count
elif rec_count > true_count:
hotwords_related_dict[rec_name]['tp'] += true_count
# fp: not in the label but appears in the rec
hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
_tp += true_count
_fp += rec_count - true_count
else:
hotwords_related_dict[rec_name]['tp'] += rec_count
# fn: the hotword is in the label but missing from the rec
hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
_tp += rec_count
_fn += true_count - rec_count
print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
_tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
))
# if hotword in _rec_list:
# hotwords_related_dict[rec_name]['tp'] += 1
# else:
# hotwords_related_dict[rec_name]['fn'] += 1
# compute u-wer, b-wer and overall wer
for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
if code == Code.match:
ub_wer_dict[rec_name]["wer"].ref_words += 1
if lab_word in hot_true_list:
# tmp_ref.append(ref_tokens[ref_idx])
ub_wer_dict[rec_name]["b_wer"].ref_words += 1
else:
ub_wer_dict[rec_name]["u_wer"].ref_words += 1
elif code == Code.substitution:
ub_wer_dict[rec_name]["wer"].ref_words += 1
ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
if lab_word in hot_true_list:
# tmp_ref.append(ref_tokens[ref_idx])
ub_wer_dict[rec_name]["b_wer"].ref_words += 1
ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
else:
ub_wer_dict[rec_name]["u_wer"].ref_words += 1
ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
elif code == Code.deletion:
ub_wer_dict[rec_name]["wer"].ref_words += 1
ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
if lab_word in hot_true_list:
# tmp_ref.append(ref_tokens[ref_idx])
ub_wer_dict[rec_name]["b_wer"].ref_words += 1
ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
else:
ub_wer_dict[rec_name]["u_wer"].ref_words += 1
ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
elif code == Code.insertion:
ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
if rec_word in hot_true_list:
ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
else:
ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])):
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length - len_lab)
space['rec'].append(length - len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end=' ')
else:
print('lab:', end=' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['lab'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end=' ')
else:
print('rec:', end=' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['rec'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print()
# print('\n', end='\n')
lab1 = lab2
rec1 = rec2
print('\n', end='\n')
# break
if verbose:
print('===========================================================================')
print()
print(wrong_rec_but_in_ocr_dict)
for rec_name in rec_names:
result = calculators_dict[rec_name].overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
hotwords_related_dict[rec_name]['tp'],
hotwords_related_dict[rec_name]['tn'],
hotwords_related_dict[rec_name]['fp'],
hotwords_related_dict[rec_name]['fn'],
sum([v for k, v in hotwords_related_dict[rec_name].items()]),
hotwords_related_dict[rec_name]['tp'] / (
hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
))
# tp: the hotword is in the label and also in the rec
# tn: the hotword is neither in the label nor in the rec
# fp: the hotword is not in the label but appears in the rec
# fn: the hotword is in the label but missing from the rec
if not verbose:
print()
print()
if __name__ == "__main__":
args = get_args()
# print("")
print(args)
main(args)

View File

@ -1,13 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="iic/LCB-NET",
model_revision="v1.0.0")
res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text"))
print(res)

View File

@ -1,72 +0,0 @@
file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
CUDA_VISIBLE_DEVICES="0,1"
inference_device="cuda"
if [ ${inference_device} == "cuda" ]; then
nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
else
inference_batch_size=1
CUDA_VISIBLE_DEVICES=""
for JOB in $(seq ${nj}); do
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
done
fi
inference_dir="outputs/slidespeech_dev"
_logdir="${inference_dir}/logdir"
echo "inference_dir: ${inference_dir}"
mkdir -p "${_logdir}"
key_file1=${file_dir}/dev/wav.scp
key_file2=${file_dir}/dev/ocr.txt
split_scps1=
split_scps2=
for JOB in $(seq "${nj}"); do
split_scps1+=" ${_logdir}/wav.${JOB}.scp"
split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
done
utils/split_scp.pl "${key_file1}" ${split_scps1}
utils/split_scp.pl "${key_file2}" ${split_scps2}
gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
for JOB in $(seq ${nj}); do
{
id=$((JOB-1))
gpuid=${gpuid_list_array[$id]}
export CUDA_VISIBLE_DEVICES=${gpuid}
python -m funasr.bin.inference \
--config-path=${file_dir} \
--config-name="config.yaml" \
++init_param=${file_dir}/model.pt \
++tokenizer_conf.token_list=${file_dir}/tokens.txt \
++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
+data_type='["kaldi_ark", "text"]' \
++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
++normalize_conf.stats_file=${file_dir}/am.mvn \
++output_dir="${inference_dir}/${JOB}" \
++device="${inference_device}" \
++ncpu=1 \
++disable_log=true &> ${_logdir}/log.${JOB}.txt
}&
done
wait
mkdir -p ${inference_dir}/1best_recog
for JOB in $(seq "${nj}"); do
cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
done
echo "Computing WER ..."
sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
tail -n 3 ${inference_dir}/1best_recog/token.cer
./run_bwer_recall.sh ${inference_dir}/1best_recog/
tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5

View File

@ -1,11 +0,0 @@
#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
#hotword_type=ocr_1ngram_top10_hotwords_list
hot_exp_suf=$1
python compute_wer_details.py --v 1 \
--ref ${hot_exp_suf}/token.ref \
--ref_ocr ${hot_exp_suf}/ocr.list \
--rec_name base \
--rec_file ${hot_exp_suf}/token.proc \
> ${hot_exp_suf}/BWER-UWER.results

View File

@ -1 +0,0 @@
../../aishell/paraformer/utils

View File

@ -7,10 +7,10 @@ from funasr import AutoModel
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.4",
# vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_model_revision="v2.0.4",
# punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
# punc_model_revision="v2.0.4",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.4",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.4",
# spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
# spk_model_revision="v2.0.2",
)
@ -43,4 +43,4 @@ import soundfile
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
'''
'''

View File

@ -28,7 +28,7 @@ try:
from funasr.models.campplus.cluster_backend import ClusterBackend
except:
print("If you want to use the speaker diarization, please `pip install hdbscan`")
import pdb
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
@ -46,7 +46,6 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
@ -147,7 +146,7 @@ class AutoModel:
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
device = "cpu"
@ -169,6 +168,7 @@ class AutoModel:
vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
else:
vocab_size = -1
# build frontend
frontend = kwargs.get("frontend", None)
kwargs["input_size"] = None
@ -181,6 +181,7 @@ class AutoModel:
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
model.to(device)
# init_param
@ -223,9 +224,9 @@ class AutoModel:
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
@ -238,7 +239,6 @@ class AutoModel:
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len

View File

@ -3,6 +3,7 @@ from typing import Optional
from typing import Tuple
from typing import Union
import logging
import humanfriendly
import numpy as np
import torch
import torch.nn as nn
@ -15,10 +16,8 @@ from funasr.frontends.utils.log_mel import LogMel
from funasr.frontends.utils.stft import Stft
from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@tables.register("frontend_classes", "DefaultFrontend")
class DefaultFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@ -26,7 +25,7 @@ class DefaultFrontend(nn.Module):
def __init__(
self,
fs: int = 16000,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
@ -41,14 +40,14 @@ class DefaultFrontend(nn.Module):
frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
**kwargs,
):
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
self.fs = fs
if apply_stft:
self.stft = Stft(
@ -85,12 +84,8 @@ class DefaultFrontend(nn.Module):
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(input_lengths, list):
input_lengths = torch.tensor(input_lengths)
if input.dtype == torch.float64:
input = input.float()
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
@ -150,7 +145,7 @@ class MultiChannelFrontend(nn.Module):
def __init__(
self,
fs: int = 16000,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = None,
@ -173,6 +168,9 @@ class MultiChannelFrontend(nn.Module):
mc: bool = True
):
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
if win_length is None and hop_length is None:
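humanfriendly.parse_size() is what lets fs be given as a string such as "16k" in frontend configs; a small sketch of the values it produces for typical inputs (assuming the humanfriendly package is installed, as imported above):

import humanfriendly

assert humanfriendly.parse_size("16k") == 16000
assert humanfriendly.parse_size("16000") == 16000
fs = humanfriendly.parse_size("8k")  # 8000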

View File

@ -47,7 +47,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
import pdb
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.

View File

@ -29,7 +29,7 @@ from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@ -63,6 +63,7 @@ class ContextualParaformer(Paraformer):
crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
if bias_encoder_type == 'lstm':
self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
@ -102,16 +103,17 @@ class ContextualParaformer(Paraformer):
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
dha_pad = kwargs.get("dha_pad")
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
stats = dict()
@ -126,11 +128,12 @@ class ContextualParaformer(Paraformer):
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# 2b. Attention decoder branch
loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att + loss_pre * self.predictor_weight
@ -168,24 +171,22 @@ class ContextualParaformer(Paraformer):
):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
ignore_id=self.ignore_id)
# -1. bias encoder
if self.use_decoder_embedding:
hw_embed = self.decoder.embed(hotword_pad)
else:
hw_embed = self.bias_embed(hotword_pad)
hw_embed, (_, _) = self.bias_encoder(hw_embed)
_ind = np.arange(0, hotword_pad.shape[0]).tolist()
selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
# 0. sampler
decoder_out_1st = None
if self.sampling_ratio > 0.0:
@ -194,7 +195,7 @@ class ContextualParaformer(Paraformer):
pre_acoustic_embeds, contextual_info)
else:
sematic_embeds = pre_acoustic_embeds
# 1. Forward decoder
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@ -210,7 +211,7 @@ class ContextualParaformer(Paraformer):
loss_ideal = None
'''
loss_ideal = None
if decoder_out_1st is None:
decoder_out_1st = decoder_out
# 2. Compute attention loss
@ -287,11 +288,10 @@ class ContextualParaformer(Paraformer):
enforce_sorted=False)
_, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
)
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@ -305,42 +305,38 @@ class ContextualParaformer(Paraformer):
**kwargs,
):
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data[
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# hotword
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@ -348,7 +344,8 @@ class ContextualParaformer(Paraformer):
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
pre_acoustic_embeds,
pre_token_length,

View File

@ -1,112 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2024 yufan
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention Return Weight layer definition."""
import math
import torch
from torch import nn
class MultiHeadedAttentionReturnWeight(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttentionReturnWeight object."""
super(MultiHeadedAttentionReturnWeight, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(self, query, key, value):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score.
torch.Tensor: Attention weights (#batch, n_head, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x), self.attn # (batch, time1, d_model)
def forward(self, query, key, value, mask):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Attention weights (#batch, n_head, time1, time2).
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
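A minimal shape-check sketch for the module above; unlike the stock MultiHeadedAttention it returns both the projected output and the attention weights (tensor sizes below are illustrative only):

import torch

attn = MultiHeadedAttentionReturnWeight(n_head=4, n_feat=256, dropout_rate=0.1)
query = torch.randn(2, 10, 256)   # (batch, time1, n_feat)
memory = torch.randn(2, 30, 256)  # (batch, time2, n_feat)
out, weight = attn(query, memory, memory, mask=None)
print(out.shape)     # torch.Size([2, 10, 256])
print(weight.shape)  # torch.Size([2, 4, 10, 30])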

View File

@ -1,392 +0,0 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import torch
from torch import nn
import logging
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
class EncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
stochastic_depth_rate (float): Probability to skip this layer.
During training, the layer may skip residual computation and return input
as-is with given probability.
"""
def __init__(
self,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
def forward(self, x, mask, cache=None):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x_q, x, x, mask)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
"""Transformer text encoder module.
Args:
input_size: input dim
output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
attention_dropout_rate: dropout rate in attention
positional_dropout_rate: dropout rate after adding positional encoding
input_layer: input layer type
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
positionwise_layer_type: linear or conv1d
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
padding_idx: padding_idx for input_layer=embed
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
):
super().__init__()
self._output_size = output_size
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size),
pos_enc_class(output_size, positional_dropout_rate),
)
self.normalize_before = normalize_before
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
MultiHeadedAttention(
attention_heads, output_size, attention_dropout_rate
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad = self.embed(xs_pad)
xs_pad, masks = self.encoders(xs_pad, masks)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
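A minimal sketch of driving the text encoder with a batch of padded token ids (vocabulary size, batch shape and lengths are invented for illustration):

import torch

text_encoder = TransformerTextEncoder(input_size=5000, output_size=256, num_blocks=2)
tokens = torch.randint(0, 5000, (2, 7))  # (batch, max_text_len) padded OCR token ids
lengths = torch.tensor([7, 5])
enc_out, enc_lens, _ = text_encoder(tokens, lengths)
print(enc_out.shape)  # torch.Size([2, 7, 256])
print(enc_lens)       # tensor([7, 5])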
@tables.register("encoder_classes", "FusionSANEncoder")
class SelfSrcAttention(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Source-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
attention_heads,
attention_dim,
linear_units,
self_attention_dropout_rate,
src_attention_dropout_rate,
positional_dropout_rate,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an SelfSrcAttention object."""
super(SelfSrcAttention, self).__init__()
self.size = size
self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x, score = self.src_attn(x, memory, memory, memory_mask)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
super().__init__()
self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
self.norm1 = LayerNorm(size)
self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
self.norm2 = LayerNorm(size)
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
self.output_linear = nn.Linear(size, 1)
def forward(self, text_enc, asr_enc):
# stage1 cross-attention
residual = text_enc
text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
# stage2 FFN
residual = text_enc
text_enc = self.norm1(text_enc)
text_enc = residual + self.feed_forward(text_enc)
# stage Conv predictor
text_enc = self.norm2(text_enc)
context = text_enc.transpose(1, 2)
queries = self.pad(context)
memory = self.conv1d(queries)
output = memory + context
output = output.transpose(1, 2)
output = torch.relu(output)
output = self.output_linear(output)
if output.dim()==3:
output = output.squeeze(2)
return output
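A minimal shape sketch for the bias predictor above: it cross-attends the text-encoder output to the acoustic encoder output and emits one biasing score per text position (sizes are illustrative only):

import torch

predictor = ConvPredictor(size=256, l_order=3, r_order=3)
text_enc = torch.randn(2, 20, 256)   # (batch, text_len, size) from the text encoder
asr_enc = torch.randn(2, 120, 256)   # (batch, frames, size) from the acoustic encoder
scores = predictor(text_enc, asr_enc)
print(scores.shape)  # torch.Size([2, 20]), one score per text position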

View File

@ -1,495 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import logging
from typing import Union, Dict, List, Tuple, Optional
import time
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
import pdb
@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
https://arxiv.org/abs/2401.06390
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
text_encoder: str = None,
text_encoder_conf: dict = None,
bias_predictor: str = None,
bias_predictor_conf: dict = None,
fusion_encoder: str = None,
fusion_encoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
select_num: int = 2,
select_length: int = 3,
insert_blank: bool = True,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
super().__init__()
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
# lcbnet modules: text encoder, fusion encoder and bias predictor
text_encoder_class = tables.encoder_classes.get(text_encoder)
text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
bias_predictor_class = tables.encoder_classes.get(bias_predictor)
bias_predictor = bias_predictor_class(**bias_predictor_conf)
if decoder is not None:
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**decoder_conf,
)
if ctc_weight > 0.0:
if ctc_conf is None:
ctc_conf = {}
ctc = CTC(
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
)
self.blank_id = blank_id
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.specaug = specaug
self.normalize = normalize
self.encoder = encoder
# lcbnet
self.text_encoder = text_encoder
self.fusion_encoder = fusion_encoder
self.bias_predictor = bias_predictor
self.select_num = select_num
self.select_length = select_length
self.insert_blank = insert_blank
if not hasattr(self.encoder, "interctc_use_conditioning"):
self.encoder.interctc_use_conditioning = False
if self.encoder.interctc_use_conditioning:
self.encoder.conditioning_layer = torch.nn.Linear(
vocab_size, self.encoder.output_size()
)
self.interctc_weight = interctc_weight
# self.error_calculator = None
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
#
# if report_cer or report_wer:
# self.error_calculator = ErrorCalculator(
# token_list, sym_space, sym_blank, report_cer, report_wer
# )
#
self.error_calculator = None
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.share_embedding = share_embedding
if self.share_embedding:
self.decoder.embed = None
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# decoder: CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
        # decoder: Attention decoder branch (skipped for pure-CTC training,
        # where self.decoder is None)
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
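    # Worked example of the hybrid objective computed in forward() above
    # (numbers are purely illustrative): with ctc_weight = 0.3,
    # loss_ctc = 40.0 and loss_att = 10.0,
    #     loss = 0.3 * 40.0 + 0.7 * 10.0 = 19.0
    # and when interctc_weight > 0 the CTC term is first blended with the
    # mean of the intermediate-layer CTC losses:
    #     loss_ctc = (1 - interctc_weight) * loss_ctc
    #                + interctc_weight * mean(loss_interctc_layer_i)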
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
        """
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
speech, speech_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
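    # Shape sketch for encode() (sizes are illustrative, assuming an 80-dim
    # fbank frontend and an encoder with 4x subsampling and output size 512):
    #   speech: (B, T, 80), speech_lengths: (B,)
    #   encoder_out: (B, T // 4, 512), encoder_out_lens: (B,)
    # With intermediate CTC enabled, the first return value is instead the
    # tuple (encoder_out, [(layer_idx, intermediate_out), ...]).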
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
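    # Worked example for the add_sos_eos() call above (token ids are
    # illustrative): with sos = eos = V (the last vocabulary index) and a
    # label row [7, 8, 9],
    #   ys_in_pad  = [V, 7, 8, 9]   # decoder input, <sos> prepended
    #   ys_out_pad = [7, 8, 9, V]   # decoder target, <eos> appended
    # hence ys_in_lens = ys_pad_lens + 1, and the label-smoothed cross entropy
    # is taken against ys_out_pad with ignore_id masking padded positions.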
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def init_beam_search(self,
**kwargs,
):
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
        if self.ctc is not None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(
ctc=ctc
)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
ctc=kwargs.get("decoding_ctc_weight", 0.3),
lm=kwargs.get("lm_weight", 0.0),
ngram=kwargs.get("ngram_weight", 0.0),
length_bonus=kwargs.get("penalty", 0.0),
)
beam_search = BeamSearch(
beam_size=kwargs.get("beam_size", 20),
weights=weights,
scorers=scorers,
sos=self.sos,
eos=self.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
)
self.beam_search = beam_search
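    # Hypothetical usage sketch: a call such as
    #   self.init_beam_search(token_list=token_list, beam_size=10,
    #                         decoding_ctc_weight=0.3)
    # builds a joint attention/CTC-prefix-scoring beam search with
    #   weights = {"decoder": 0.7, "ctc": 0.3, "lm": 0.0,
    #              "ngram": 0.0, "length_bonus": 0.0}
    # over the shared sos/eos token (vocab_size - 1).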
def inference(self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
audio_sample_list = sample_list[0]
            if len(sample_list) > 1:
ocr_sample_list = sample_list[1]
else:
ocr_sample_list = [[294, 0]]
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
frame_shift = 10
meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
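        # OCR-based contextual biasing (descriptive note): non-zero OCR token
        # ids are shifted by +1 below (presumably to keep index 0 reserved),
        # run through the text encoder, fused with the acoustic encoder output
        # by the fusion encoder, and the fused representation is added back to
        # encoder_out as a residual bias before beam search.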
ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
        fusion_out, _, _, _ = self.fusion_encoder(encoder_out, None, ocr, None)
        encoder_out = encoder_out + fusion_out
        # c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
b, n, d = encoder_out.size()
for i in range(b):
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank, sos and eos symbol ids
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
result_i = {"key": key[i], "token": token, "text": text_postprocessed}
results.append(result_i)
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["text"][key[i]] = text_postprocessed
return results, meta_data
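# Hypothetical end-to-end sketch of calling inference() directly (the paired
# audio + OCR-text input layout is an assumption inferred from the loading
# code above, not taken from the repo's docs; in practice decoding is usually
# driven through FunASR's AutoModel pipeline):
#
#   results, meta = model.inference(
#       data_in=[["/path/to/utt.wav", "ocr words from the slides"]],
#       data_type=["sound", "text"],
#       key=["utt1"],
#       tokenizer=tokenizer,
#       frontend=frontend,
#       device="cuda:0",
#       beam_size=10,
#   )
#   # results -> [{"key": "utt1", "token": [...], "text": "..."}]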

View File

@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
@ -128,7 +128,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
dha_pad = kwargs.get("dha_pad")
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
@ -209,20 +209,17 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
nfilter=50,
seaco_weight=1.0):
# decoder forward
decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
decoder_pred = torch.log_softmax(decoder_out, dim=-1)
if hw_list is not None:
hw_lengths = [len(i) for i in hw_list]
hw_list_ = [torch.Tensor(i).long() for i in hw_list]
hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
num_hot_word = contextual_info.shape[1]
_contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
# ASF Core
if nfilter > 0 and nfilter < num_hot_word:
hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
@ -242,7 +239,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
merged = self._merge(cif_attended, dec_attended)
dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation
dha_pred = torch.log_softmax(dha_output, dim=-1)
def _merge_res(dec_output, dha_output):
@ -256,8 +253,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
# logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
return logits
merged_pred = _merge_res(decoder_pred, dha_pred)
# import pdb; pdb.set_trace()
return merged_pred
else:
return decoder_pred
@ -307,6 +304,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
# extract fbank feats
@ -332,7 +330,6 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@ -341,14 +338,15 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
if torch.max(pre_token_length) < 1:
return []
decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
pre_acoustic_embeds,
pre_token_length,
hw_list=self.hotword_list)
# decoder_out, _ = decoder_outs[0], decoder_outs[1]
_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
pre_token_length)
results = []
b, n, d = decoder_out.size()
for i in range(b):

View File

@ -7,7 +7,7 @@ import logging
import torch
import torch.nn
import torch.optim
import pdb
def filter_state_dict(
dst_state: Dict[str, Union[float, torch.Tensor]],
@ -63,7 +63,6 @@ def load_pretrained_model(
dst_state = obj.state_dict()
print(f"ckpt: {path}")
if oss_bucket is None:
src_state = torch.load(path, map_location=map_location)
else:

View File

@ -13,25 +13,29 @@ try:
from funasr.download.file import download_from_url
except:
print("urllib is not installed, if you infer from url, please install it first.")
import pdb
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
if isinstance(data_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
data_types = [data_type] * len(data_or_path_or_list)
data_or_path_or_list_ret = [[] for d in data_type]
for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
return data_or_path_or_list_ret
else:
return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
data_or_path_or_list = download_from_url(data_or_path_or_list)
if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
if data_type is None or data_type == "sound":
data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@ -52,22 +56,10 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
data_mat = kaldiio.load_mat(data_or_path_or_list)
if isinstance(data_mat, tuple):
audio_fs, mat = data_mat
else:
mat = data_mat
if mat.dtype == 'int16' or mat.dtype == 'int32':
mat = mat.astype(np.float64)
mat = mat / 32768
if mat.ndim ==2:
mat = mat[:,0]
data_or_path_or_list = mat
else:
pass
# print(f"unsupport data type: {data_or_path_or_list}, return raw data")
if audio_fs != fs and data_type != "text":
resampler = torchaudio.transforms.Resample(audio_fs, fs)
data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@ -89,6 +81,8 @@ def load_bytes(input):
return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
# import pdb;
# pdb.set_trace()
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
@ -106,7 +100,9 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None,
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
# import pdb;
# pdb.set_trace()
# if data_type == "sound":
data, data_len = frontend(data, data_len, **kwargs)
if isinstance(data_len, (list, tuple)):