mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

Revert "Dev yf" (#1418)

This commit is contained in:
parent 920331972a
commit d2c1204d91
0   examples/industrial_data_pretraining/contextual_paraformer/demo.py   Executable file → Normal file
2   examples/industrial_data_pretraining/contextual_paraformer/demo.sh   Executable file → Normal file
@@ -2,7 +2,7 @@
 model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"

-python ../../../funasr/bin/inference.py \
+python funasr/bin/inference.py \
 +model=${model} \
 +model_revision=${model_revision} \
 +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
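For reference, the same contextual Paraformer can also be driven from the Python `AutoModel` API used elsewhere in this diff. A minimal sketch, assuming the model ID and revision from the hunk above; the `hotword` argument is illustrative and not part of this hunk:

from funasr import AutoModel

model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
                  model_revision="v2.0.4")
# `hotword` biases decoding toward the given phrase (illustrative value).
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                     hotword="魔搭")
print(res)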
@@ -1,9 +0,0 @@
python -m funasr.bin.inference \
    --config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
    --config-name="config.yaml" \
    ++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
    ++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
    ++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
    ++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
    ++output_dir="./outputs/debug2" \
    ++device="" \
@@ -1,6 +0,0 @@
export FUNASR_DIR=$PWD/../../../

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
@@ -1,702 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from enum import Enum
import re, sys, unicodedata
import codecs
import argparse
from tqdm import tqdm
import os
import pdb

remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
    "!", ",", "?", "、", "。", "!", ",", ";", "?", ":",
    "「", "」", "︰", "『", "』", "《", "》",
]


class Code(Enum):
    match = 1
    substitution = 2
    insertion = 3
    deletion = 4


class WordError(object):
    def __init__(self):
        self.errors = {
            Code.substitution: 0,
            Code.insertion: 0,
            Code.deletion: 0,
        }
        self.ref_words = 0

    def get_wer(self):
        assert self.ref_words != 0
        errors = (
            self.errors[Code.substitution]
            + self.errors[Code.insertion]
            + self.errors[Code.deletion]
        )
        return 100.0 * errors / self.ref_words

    def get_result_string(self):
        return (
            f"error_rate={self.get_wer():.4f}, "
            f"ref_words={self.ref_words}, "
            f"subs={self.errors[Code.substitution]}, "
            f"ins={self.errors[Code.insertion]}, "
            f"dels={self.errors[Code.deletion]}"
        )
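# Illustrative check (not from the original file): with 10 reference words,
# 2 substitutions, 1 deletion and no insertions, WER = 100 * (2 + 1) / 10 = 30.0:
#
#   we = WordError()
#   we.ref_words = 10
#   we.errors[Code.substitution] = 2
#   we.errors[Code.deletion] = 1
#   assert we.get_wer() == 30.0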
def characterize(string):
    res = []
    i = 0
    while i < len(string):
        char = string[i]
        if char in puncts:
            i += 1
            continue
        cat1 = unicodedata.category(char)
        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
        if cat1 == "Zs" or cat1 == "Cn" or char in spacelist:  # space or not assigned
            i += 1
            continue
        if cat1 == "Lo":  # letter-other
            res.append(char)
            i += 1
        else:
            # some input looks like: <unk><noise>, we want to separate it into two words.
            sep = " "
            if char == "<":
                sep = ">"
            j = i + 1
            while j < len(string):
                c = string[j]
                if ord(c) >= 128 or (c in spacelist) or (c == sep):
                    break
                j += 1
            if j < len(string) and string[j] == ">":
                j += 1
            res.append(string[i:j])
            i = j
    return res
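# Illustrative behaviour (not from the original file):
#
#   characterize("你好<unk>ok,")  ->  ["你", "好", "<unk>", "ok"]
#
# CJK characters (category "Lo") are emitted one per token, <...> tags are
# kept whole, ASCII runs are grouped into words, and anything in `puncts`
# (here the full-width comma) is dropped.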
def stripoff_tags(x):
    if not x:
        return ""
    chars = []
    i = 0
    T = len(x)
    while i < T:
        if x[i] == "<":
            while i < T and x[i] != ">":
                i += 1
            i += 1
        else:
            chars.append(x[i])
            i += 1
    return "".join(chars)


def normalize(sentence, ignore_words, cs, split=None):
    """sentence, ignore_words are both in unicode"""
    new_sentence = []
    for token in sentence:
        x = token
        if not cs:
            x = x.upper()
        if x in ignore_words:
            continue
        if remove_tag:
            x = stripoff_tags(x)
        if not x:
            continue
        if split and x in split:
            new_sentence += split[x]
        else:
            new_sentence.append(x)
    return new_sentence


class Calculator:
    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {}
        self.cost["cor"] = 0
        self.cost["sub"] = 1
        self.cost["del"] = 1
        self.cost["ins"] = 1

    def calculate(self, lab, rec):
        # Initialization
        lab.insert(0, "")
        rec.insert(0, "")
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element["dist"] = 0
                element["error"] = "non"
            while len(row) < len(rec):
                row.append({"dist": 0, "error": "non"})
        for i in range(len(lab)):
            self.space[i][0]["dist"] = i
            self.space[i][0]["error"] = "del"
        for j in range(len(rec)):
            self.space[0][j]["dist"] = j
            self.space[0][j]["error"] = "ins"
        self.space[0][0]["error"] = "non"
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = "none"
                dist = self.space[i - 1][j]["dist"] + self.cost["del"]
                error = "del"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
                error = "ins"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token.replace("<BIAS>", ""):
                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
                    error = "cor"
                else:
                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
                    error = "sub"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]["dist"] = min_dist
                self.space[i][j]["error"] = min_error
        # Tracing back
        result = {
            "lab": [],
            "rec": [],
            "code": [],
            "all": 0,
            "cor": 0,
            "sub": 0,
            "ins": 0,
            "del": 0,
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]["error"] == "cor":  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
                    result["all"] = result["all"] + 1
                    result["cor"] = result["cor"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.match)
                i = i - 1
                j = j - 1
            elif self.space[i][j]["error"] == "sub":  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
                    result["all"] = result["all"] + 1
                    result["sub"] = result["sub"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.substitution)
                i = i - 1
                j = j - 1
            elif self.space[i][j]["error"] == "del":  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
                    result["all"] = result["all"] + 1
                    result["del"] = result["del"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, "")
                result["code"].insert(0, Code.deletion)
                i = i - 1
            elif self.space[i][j]["error"] == "ins":  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
                    result["ins"] = result["ins"] + 1
                result["lab"].insert(0, "")
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.insertion)
                j = j - 1
            elif self.space[i][j]["error"] == "non":  # starting point
                break
            else:  # shouldn't reach here
                print(
                    "this should not happen , i = {i} , j = {j} , error = {error}".format(
                        i=i, j=j, error=self.space[i][j]["error"]
                    )
                )
        return result

    def overall(self):
        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in self.data:
            result["all"] = result["all"] + self.data[token]["all"]
            result["cor"] = result["cor"] + self.data[token]["cor"]
            result["sub"] = result["sub"] + self.data[token]["sub"]
            result["ins"] = result["ins"] + self.data[token]["ins"]
            result["del"] = result["del"] + self.data[token]["del"]
        return result

    def cluster(self, data):
        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in data:
            if token in self.data:
                result["all"] = result["all"] + self.data[token]["all"]
                result["cor"] = result["cor"] + self.data[token]["cor"]
                result["sub"] = result["sub"] + self.data[token]["sub"]
                result["ins"] = result["ins"] + self.data[token]["ins"]
                result["del"] = result["del"] + self.data[token]["del"]
        return result

    def keys(self):
        return list(self.data.keys())
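# Illustrative check (not from the original file): aligning a 4-character
# reference against a hypothesis with one character dropped:
#
#   calc = Calculator()
#   result = calc.calculate(list("今天天气"), list("今天气"))
#   # result["all"] == 4, result["cor"] == 3, result["del"] == 1,
#   # and result["code"] contains exactly one Code.deletion entry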
def width(string):
    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)


def default_cluster(word):
    unicode_names = [unicodedata.name(char) for char in word]
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith("DIGIT"):  # 1
            unicode_names[i] = "Number"  # 'DIGIT'
        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
            i
        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
            # 明 / 郎
            unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
            i
        ].startswith("LATIN SMALL LETTER"):
            # A / a
            unicode_names[i] = "English"  # 'LATIN LETTER'
        elif unicode_names[i].startswith("HIRAGANA LETTER"):  # は こ め
            unicode_names[i] = "Japanese"  # 'GANA LETTER'
        elif (
            unicode_names[i].startswith("AMPERSAND")
            or unicode_names[i].startswith("APOSTROPHE")
            or unicode_names[i].startswith("COMMERCIAL AT")
            or unicode_names[i].startswith("DEGREE CELSIUS")
            or unicode_names[i].startswith("EQUALS SIGN")
            or unicode_names[i].startswith("FULL STOP")
            or unicode_names[i].startswith("HYPHEN-MINUS")
            or unicode_names[i].startswith("LOW LINE")
            or unicode_names[i].startswith("NUMBER SIGN")
            or unicode_names[i].startswith("PLUS SIGN")
            or unicode_names[i].startswith("SEMICOLON")
        ):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
        else:
            return "Other"
    if len(unicode_names) == 0:
        return "Other"
    if len(unicode_names) == 1:
        return unicode_names[0]
    for i in range(len(unicode_names) - 1):
        if unicode_names[i] != unicode_names[i + 1]:
            return "Other"
    return unicode_names[0]
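# Illustrative behaviour (not from the original file):
#
#   default_cluster("今天")  -> "Mandarin"
#   default_cluster("hello") -> "English"
#   default_cluster("123")   -> "Number"
#   default_cluster("hi今")  -> "Other"   (mixed scripts)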
def get_args():
    parser = argparse.ArgumentParser(description="wer cal")
    parser.add_argument("--ref", type=str, help="Text input path")
    parser.add_argument("--ref_ocr", type=str, help="Text input path")
    parser.add_argument("--rec_name", type=str, action="append", default=[])
    parser.add_argument("--rec_file", type=str, action="append", default=[])
    parser.add_argument("--verbose", type=int, default=1, help="show")
    parser.add_argument("--char", type=bool, default=True, help="show")
    args = parser.parse_args()
    return args
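# Typical invocation (cf. run_bwer_recall.sh later in this diff); --rec_name /
# --rec_file may be repeated to score several systems in one run:
#
#   python compute_wer_details.py --v 1 \
#       --ref token.ref --ref_ocr ocr.list \
#       --rec_name base --rec_file token.proc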
def main(args):
    cluster_file = ""
    ignore_words = set()
    tochar = args.char
    verbose = args.verbose
    padding_symbol = " "
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None

    if not case_sensitive:
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig

    default_clusters = {}
    default_words = {}
    ref_file = args.ref
    ref_ocr = args.ref_ocr
    rec_files = args.rec_file
    rec_names = args.rec_name
    assert len(rec_files) == len(rec_names)

    # load ocr
    ref_ocr_dict = {}
    with codecs.open(ref_ocr, "r", "utf-8") as fh:
        for line in fh:
            if "$" in line:
                line = line.replace("$", " ")
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)

    if split and not case_sensitive:
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit

    rec_sets = {}
    calculators_dict = dict()
    ub_wer_dict = dict()
    hotwords_related_dict = dict()  # recall-related statistics
    for i, hyp_file in enumerate(rec_files):
        rec_sets[rec_names[i]] = dict()
        with codecs.open(hyp_file, "r", "utf-8") as fh:
            for line in fh:
                if tochar:
                    array = characterize(line)
                else:
                    array = line.strip().split()
                if len(array) == 0:
                    continue
                fid = array[0]
                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)

        calculators_dict[rec_names[i]] = Calculator()
        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
        # tp: the hotword is in the label and in the rec
        # tn: the hotword is not in the label and not in the rec
        # fp: the hotword is not in the label but is in the rec
        # fn: the hotword is in the label but not in the rec

    # record wrong label but in ocr
    wrong_rec_but_in_ocr_dict = {}
    for rec_name in rec_names:
        wrong_rec_but_in_ocr_dict[rec_name] = 0

    _file_total_len = 0
    with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
        _file_total_len = int(pipe.read().strip())

    # compute error rate on the intersection of reference file and hyp file
    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0:
            continue
        fid = array[0]
        lab = normalize(array[1:], ignore_words, case_sensitive, split)

        if verbose:
            print('\nutt: %s' % fid)

        ocr_text = ref_ocr_dict[fid]
        ocr_set = set(ocr_text)
        print('ocr: {}'.format(" ".join(ocr_text)))
        list_match = []  # label tokens that also appear in the ocr
        list_not_mathch = []
        tmp_error = 0
        tmp_match = 0
        for index in range(len(lab)):
            # text_list.append(uttlist[index+1])
            if lab[index] not in ocr_set:
                tmp_error += 1
                list_not_mathch.append(lab[index])
            else:
                tmp_match += 1
                list_match.append(lab[index])
        print('label in ocr: {}'.format(" ".join(list_match)))

        # for each reco file
        base_wrong_ocr_wer = None
        ocr_wrong_ocr_wer = None

        for rec_name in rec_names:
            rec_set = rec_sets[rec_name]
            if fid not in rec_set:
                continue
            rec = rec_set[fid]

            # print(rec)
            for word in rec + lab:
                if word not in default_words:
                    default_cluster_name = default_cluster(word)
                    if default_cluster_name not in default_clusters:
                        default_clusters[default_cluster_name] = {}
                    if word not in default_clusters[default_cluster_name]:
                        default_clusters[default_cluster_name][word] = 1
                    default_words[word] = default_cluster_name

            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
            if verbose:
                if result['all'] != 0:
                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
                else:
                    wer = 0.0
                print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
                print('N=%d C=%d S=%d D=%d I=%d' %
                      (result['all'], result['cor'], result['sub'], result['del'], result['ins']))

            # print(result['rec'])
            wrong_rec_but_in_ocr = []
            for idx in range(len(result['lab'])):
                if result['lab'][idx] != "":
                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
                        if result['lab'][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result['lab'][idx])
                            wrong_rec_but_in_ocr_dict[rec_name] += 1
            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))

            if rec_name == "base":
                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
            if "ocr" in rec_name or "hot" in rec_name:
                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))

            # recall = 0
            # false_alarm = 0
            # for idx in range(len(result['lab'])):
            #     if "<BIAS>" in result['rec'][idx]:
            #         if result['rec'][idx].replace("<BIAS>", "") in list_match:
            #             recall += 1
            #         else:
            #             false_alarm += 1
            # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
            #     recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
            # ))
            # tp: the hotword is in the label and in the rec
            # tn: the hotword is not in the label and not in the rec
            # fp: the hotword is not in the label but is in the rec
            # fn: the hotword is in the label but not in the rec
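            # Worked example (illustrative): with OCR hotwords ["A", "B"],
            # lab = ["A", "A", "C"] and rec = ["A", "B"]:
            #   "A" is a true hotword: true_count=2, rec_count=1 -> tp += 1, fn += 1
            #   "B" is a bad hotword seen once in rec            -> fp += 1
            # so recall = tp / (tp + fn) = 1 / 2 = 50%.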
            _rec_list = [word.replace("<BIAS>", "") for word in rec]
            _label_list = [word for word in lab]
            _tp = _tn = _fp = _fn = 0
            hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
            hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
            for badhotword in hot_bad_list:
                count = len([word for word in _rec_list if word == badhotword])
                # print(f"bad {badhotword} count: {count}")
                # for word in _rec_list:
                #     if badhotword == word:
                #         count += 1
                if count == 0:
                    hotwords_related_dict[rec_name]['tn'] += 1
                    _tn += 1
                    # fp: 0
                else:
                    hotwords_related_dict[rec_name]['fp'] += count
                    _fp += count
                    # tn: 0
                # if badhotword in _rec_list:
                #     hotwords_related_dict[rec_name]['fp'] += 1
                # else:
                #     hotwords_related_dict[rec_name]['tn'] += 1
            for hotword in hot_true_list:
                true_count = len([word for word in _label_list if hotword == word])
                rec_count = len([word for word in _rec_list if hotword == word])
                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
                if rec_count == true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    _tp += true_count
                elif rec_count > true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    # fp: not in the label but in the rec
                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
                    _tp += true_count
                    _fp += rec_count - true_count
                else:
                    hotwords_related_dict[rec_name]['tp'] += rec_count
                    # fn: the hotword is in the label but not in the rec
                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
                    _tp += rec_count
                    _fn += true_count - rec_count
            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
            ))

            # if hotword in _rec_list:
            #     hotwords_related_dict[rec_name]['tp'] += 1
            # else:
            #     hotwords_related_dict[rec_name]['fn'] += 1
            # compute u_wer / b_wer / wer
            for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
                if code == Code.match:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                elif code == Code.substitution:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
                elif code == Code.deletion:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
                elif code == Code.insertion:
                    ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
                    if rec_word in hot_true_list:
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1

            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('lab:', end=' ')
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                if verbose > 1:
                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('rec:', end=' ')

                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                # print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2
                print('\n', end='\n')
        # break
        if verbose:
            print('===========================================================================')
            print()

    print(wrong_rec_but_in_ocr_dict)
    for rec_name in rec_names:
        result = calculators_dict[rec_name].overall()

        if result['all'] != 0:
            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
        else:
            wer = 0.0
        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
        print('N=%d C=%d S=%d D=%d I=%d' %
              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")

        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
            hotwords_related_dict[rec_name]['tp'],
            hotwords_related_dict[rec_name]['tn'],
            hotwords_related_dict[rec_name]['fp'],
            hotwords_related_dict[rec_name]['fn'],
            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
            hotwords_related_dict[rec_name]['tp'] / (
                hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
        ))

        # tp: the hotword is in the label and in the rec
        # tn: the hotword is not in the label and not in the rec
        # fp: the hotword is not in the label but is in the rec
        # fn: the hotword is in the label but not in the rec
    if not verbose:
        print()
    print()


if __name__ == "__main__":
    args = get_args()

    # print("")
    print(args)
    main(args)
@@ -1,13 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from funasr import AutoModel

model = AutoModel(model="iic/LCB-NET",
                  model_revision="v1.0.0")

res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav", "https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"), data_type=("sound", "text"))

print(res)
@@ -1,72 +0,0 @@
file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
CUDA_VISIBLE_DEVICES="0,1"
inference_device="cuda"

if [ ${inference_device} == "cuda" ]; then
    nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
else
    inference_batch_size=1
    CUDA_VISIBLE_DEVICES=""
    for JOB in $(seq ${nj}); do
        CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
    done
fi

inference_dir="outputs/slidespeech_dev"
_logdir="${inference_dir}/logdir"
echo "inference_dir: ${inference_dir}"

mkdir -p "${_logdir}"
key_file1=${file_dir}/dev/wav.scp
key_file2=${file_dir}/dev/ocr.txt
split_scps1=
split_scps2=
for JOB in $(seq "${nj}"); do
    split_scps1+=" ${_logdir}/wav.${JOB}.scp"
    split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
done
utils/split_scp.pl "${key_file1}" ${split_scps1}
utils/split_scp.pl "${key_file2}" ${split_scps2}

gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
for JOB in $(seq ${nj}); do
    {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}

        export CUDA_VISIBLE_DEVICES=${gpuid}

        python -m funasr.bin.inference \
            --config-path=${file_dir} \
            --config-name="config.yaml" \
            ++init_param=${file_dir}/model.pt \
            ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
            ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
            +data_type='["kaldi_ark", "text"]' \
            ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
            ++normalize_conf.stats_file=${file_dir}/am.mvn \
            ++output_dir="${inference_dir}/${JOB}" \
            ++device="${inference_device}" \
            ++ncpu=1 \
            ++disable_log=true &> ${_logdir}/log.${JOB}.txt

    } &
done
wait

mkdir -p ${inference_dir}/1best_recog

for JOB in $(seq "${nj}"); do
    cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
done

echo "Computing WER ..."
sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
tail -n 3 ${inference_dir}/1best_recog/token.cer

./run_bwer_recall.sh ${inference_dir}/1best_recog/
tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results | head -n 5
@@ -1,11 +0,0 @@
#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
#hotword_type=ocr_1ngram_top10_hotwords_list
hot_exp_suf=$1


python compute_wer_details.py --v 1 \
    --ref ${hot_exp_suf}/token.ref \
    --ref_ocr ${hot_exp_suf}/ocr.list \
    --rec_name base \
    --rec_file ${hot_exp_suf}/token.proc \
    > ${hot_exp_suf}/BWER-UWER.results
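The script above writes the biased/unbiased WER split reported by compute_wer_details.py: B-WER is measured over reference words that appear in the OCR hotword list, U-WER over the rest. A minimal sketch of that decomposition on already-aligned token pairs (the function and variable names here are hypothetical, not from the repository):

def split_wer(aligned_pairs, hotwords):
    # aligned_pairs: (label_token, hyp_token) per alignment slot; "" marks a gap
    b_err = b_ref = u_err = u_ref = 0
    for lab, hyp in aligned_pairs:
        if lab == "":  # insertion; charged by the inserted hyp word instead
            continue
        if lab in hotwords:
            b_ref += 1
            b_err += int(lab != hyp)
        else:
            u_ref += 1
            u_err += int(lab != hyp)
    return 100.0 * b_err / max(b_ref, 1), 100.0 * u_err / max(u_ref, 1)

print(split_wer([("今", "今"), ("天", "气")], {"天"}))  # (100.0, 0.0)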
@@ -1 +0,0 @@
../../aishell/paraformer/utils
@@ -7,10 +7,10 @@ from funasr import AutoModel

 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # punc_model_revision="v2.0.4",
+                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.4",
+                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.4",
                   # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
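Put together, the configuration this hunk restores runs recognition with VAD segmentation and punctuation restoration in one pipeline; a minimal sketch (the input path is a placeholder):

from funasr import AutoModel

model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.4",
                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.4",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.4")
res = model.generate(input="asr_example_zh.wav", batch_size_s=300)
print(res)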
@@ -43,4 +43,4 @@ import soundfile
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
'''
'''
@@ -28,7 +28,7 @@ try:
    from funasr.models.campplus.cluster_backend import ClusterBackend
except:
    print("If you want to use the speaker diarization, please `pip install hdbscan`")
import pdb


def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
@@ -46,7 +46,6 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith('http'):  # url
        data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(data_in):  # wav_path; filelist: wav.scp, file.jsonl; text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
@@ -147,7 +146,7 @@ class AutoModel:
        kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
@@ -169,6 +168,7 @@ class AutoModel:
            vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
        else:
            vocab_size = -1

        # build frontend
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
@@ -181,6 +181,7 @@ class AutoModel:
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)

        model.to(device)

        # init_param
@@ -223,9 +224,9 @@ class AutoModel:
        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
@@ -238,7 +239,6 @@ class AutoModel:
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}

            if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank":  # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len
@@ -3,6 +3,7 @@ from typing import Optional
from typing import Tuple
from typing import Union
import logging
import humanfriendly
import numpy as np
import torch
import torch.nn as nn
@@ -15,10 +16,8 @@ from funasr.frontends.utils.log_mel import LogMel
from funasr.frontends.utils.stft import Stft
from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables


@tables.register("frontend_classes", "DefaultFrontend")
class DefaultFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -26,7 +25,7 @@ class DefaultFrontend(nn.Module):

     def __init__(
         self,
-        fs: int = 16000,
+        fs: Union[int, str] = 16000,
         n_fft: int = 512,
         win_length: int = None,
         hop_length: int = 128,
@@ -41,14 +40,14 @@
        frontend_conf: Optional[dict] = None,
        apply_stft: bool = True,
        use_channel: int = None,
        **kwargs,
    ):
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length
        self.fs = fs

        if apply_stft:
            self.stft = Stft(
@@ -85,12 +84,8 @@
         return self.n_mels

     def forward(
-        self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
+        self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if isinstance(input_lengths, list):
-            input_lengths = torch.tensor(input_lengths)
-        if input.dtype == torch.float64:
-            input = input.float()
         # 1. Domain-conversion: e.g. Stft: time -> time-freq
         if self.stft is not None:
             input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -150,7 +145,7 @@ class MultiChannelFrontend(nn.Module):

     def __init__(
         self,
-        fs: int = 16000,
+        fs: Union[int, str] = 16000,
         n_fft: int = 512,
         win_length: int = None,
         hop_length: int = None,
@@ -173,6 +168,9 @@
         mc: bool = True
     ):
         super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
+
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         if win_length is None and hop_length is None:
@@ -47,7 +47,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
import pdb


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
@@ -29,7 +29,7 @@ from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb


if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -63,6 +63,7 @@ class ContextualParaformer(Paraformer):
        crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)

        if bias_encoder_type == 'lstm':
            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
@@ -102,16 +103,17 @@ class ContextualParaformer(Paraformer):
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None

        stats = dict()
@@ -126,11 +128,12 @@ class ContextualParaformer(Paraformer):
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 2b. Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
@@ -168,24 +171,22 @@ class ContextualParaformer(Paraformer):
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)

        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias

        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                     ignore_id=self.ignore_id)

        # -1. bias encoder
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)

        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
@@ -194,7 +195,7 @@ class ContextualParaformer(Paraformer):
                                                           pre_acoustic_embeds, contextual_info)
        else:
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -210,7 +211,7 @@ class ContextualParaformer(Paraformer):
            loss_ideal = None
        '''
        loss_ideal = None

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
@@ -287,11 +288,10 @@ class ContextualParaformer(Paraformer):
                                                     enforce_sorted=False)
        _, (h_n, _) = self.bias_encoder(hw_embed)
        hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)

        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )

        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
@@ -305,42 +305,38 @@ class ContextualParaformer(Paraformer):
        **kwargs,
    ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                               frontend=frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data[
            "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # hotword
        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@@ -348,7 +344,8 @@ class ContextualParaformer(Paraformer):
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []

        decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
                                                       pre_acoustic_embeds,
                                                       pre_token_length,
@@ -1,112 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2024 yufan
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention Return Weight layer definition."""

import math

import torch
from torch import nn


class MultiHeadedAttentionReturnWeight(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttentionReturnWeight object."""
        super(MultiHeadedAttentionReturnWeight, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x), self.attn  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
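A quick shape check for the deleted attention module above; a minimal sketch assuming the class definition is available in scope:

import torch

# batch 2, 5 query frames, 7 key/value frames, 4 heads over 256 features
mha = MultiHeadedAttentionReturnWeight(n_head=4, n_feat=256, dropout_rate=0.0)
q, kv = torch.randn(2, 5, 256), torch.randn(2, 7, 256)
out, attn = mha(q, kv, kv, mask=None)
assert out.shape == (2, 5, 256)      # attended values, projected back to n_feat
assert attn.shape == (2, 4, 5, 7)    # per-head attention weights are returned too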
@ -1,392 +0,0 @@
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Transformer encoder definition."""
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import logging
|
||||
|
||||
from funasr.models.transformer.attention import MultiHeadedAttention
|
||||
from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
|
||||
from funasr.models.transformer.embedding import PositionalEncoding
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.register import tables
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
||||
can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
stochastic_depth_rate (float): Proability to skip this layer.
|
||||
During training, the layer may skip residual computation and return input
|
||||
as-is with given probability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
stochastic_depth_rate=0.0,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
self.stochastic_depth_rate = stochastic_depth_rate
|
||||
|
||||
def forward(self, x, mask, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
skip_layer = False
|
||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||
stoch_layer_coeff = 1.0
|
||||
if self.training and self.stochastic_depth_rate > 0:
|
||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||
|
||||
if skip_layer:
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
return x, mask
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
|
||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(x_q, x, x, mask)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, mask
|
||||
|
||||
@tables.register("encoder_classes", "TransformerTextEncoder")
|
||||
class TransformerTextEncoder(nn.Module):
|
||||
"""Transformer text encoder module.
|
||||
|
||||
Args:
|
||||
input_size: input dim
|
||||
output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the number of units of position-wise feed forward
|
||||
num_blocks: the number of decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
attention_dropout_rate: dropout rate in attention
|
||||
positional_dropout_rate: dropout rate after adding positional encoding
|
||||
input_layer: input layer type
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before: whether to use layer_norm before the first block
|
||||
concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied.
|
||||
i.e. x -> x + att(x)
|
||||
positionwise_layer_type: linear of conv1d
|
||||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
output_size,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, output_size, attention_dropout_rate
|
||||
),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
return xs_pad, olens, None

@tables.register("encoder_classes", "FusionSANEncoder")
class SelfSrcAttention(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        attention_heads,
        attention_dim,
        linear_units,
        self_attention_dropout_rate,
        src_attention_dropout_rate,
        positional_dropout_rate,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SelfSrcAttention object."""
        super(SelfSrcAttention, self).__init__()
        self.size = size
        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            # src_attn returns (output, attention weight); keep only the output here
            x_src, _ = self.src_attn(x, memory, memory, memory_mask)
            x_concat = torch.cat((x, x_src), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x, score = self.src_attn(x, memory, memory, memory_mask)
            x = residual + self.dropout(x)
        if not self.normalize_before:
            x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
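
# --- editor's note (not part of the original file) -----------------------------
# Shape contract of SelfSrcAttention.forward as LCBNet uses it below: the audio
# encoder output plays the role of tgt (the query stream) and the text-encoder
# output is the memory, so the fused output keeps the audio time axis. A hedged
# sketch, assuming 256-dim features and the constructor signature above:
#
#   fusion = SelfSrcAttention(size=256, attention_heads=4, attention_dim=256,
#                             linear_units=2048, self_attention_dropout_rate=0.0,
#                             src_attention_dropout_rate=0.0,
#                             positional_dropout_rate=0.1, dropout_rate=0.1)
#   audio = torch.randn(1, 120, 256)  # (B, T_audio, D)
#   text = torch.randn(1, 40, 256)    # (B, T_text, D)
#   fused, _, _, _ = fusion(audio, None, text, None)  # fused: (1, 120, 256)
# --------------------------------------------------------------------------------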

@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
    def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
        super().__init__()
        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
        self.norm2 = LayerNorm(size)
        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
        self.output_linear = nn.Linear(size, 1)

    def forward(self, text_enc, asr_enc):
        # stage 1: cross-attention
        residual = text_enc
        text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)

        # stage 2: FFN
        residual = text_enc
        text_enc = self.norm1(text_enc)
        text_enc = residual + self.feed_forward(text_enc)

        # stage 3: conv predictor
        text_enc = self.norm2(text_enc)
        context = text_enc.transpose(1, 2)
        queries = self.pad(context)
        memory = self.conv1d(queries)
        output = memory + context
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.output_linear(output)
        if output.dim() == 3:
            output = output.squeeze(2)
        return output
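
# --- editor's illustration (not part of the original file) ---------------------
# Why the conv stage above preserves sequence length: with ConstantPad1d((l, r), 0)
# and kernel size l + r + 1, a stride-1 Conv1d maps T -> T, so output_linear can
# emit one bias score per text position. A standalone check in plain torch, using
# a depthwise conv (groups=channels) exactly as in ConvPredictor:
import torch
import torch.nn as nn

_l, _r, _d, _t = 3, 3, 8, 17
_pad = nn.ConstantPad1d((_l, _r), 0)
_conv = nn.Conv1d(_d, _d, _l + _r + 1, groups=_d)
_x = torch.randn(2, _d, _t)                  # (B, D, T)
assert _conv(_pad(_x)).shape == (2, _d, _t)  # length preserved
# ---------------------------------------------------------------------------------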

@ -1,495 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import logging
from typing import Union, Dict, List, Tuple, Optional

import time
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables

import pdb


@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
    https://arxiv.org/abs/2401.06390
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        text_encoder: str = None,
        text_encoder_conf: dict = None,
        bias_predictor: str = None,
        bias_predictor_conf: dict = None,
        fusion_encoder: str = None,
        fusion_encoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        select_num: int = 2,
        select_length: int = 3,
        insert_blank: bool = True,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        # lcbnet modules: text encoder, fusion encoder and bias predictor
        text_encoder_class = tables.encoder_classes.get(text_encoder)
        text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
        fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
        fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
        bias_predictor_class = tables.encoder_classes.get(bias_predictor)
        bias_predictor = bias_predictor_class(**bias_predictor_conf)

        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )

        self.blank_id = blank_id
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        # lcbnet
        self.text_encoder = text_encoder
        self.fusion_encoder = fusion_encoder
        self.bias_predictor = bias_predictor
        self.select_num = select_num
        self.select_length = select_length
        self.insert_blank = insert_blank

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight

        # self.error_calculator = None
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None

        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """

        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
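
    # --- editor's note (not part of the original file) --------------------------
    # The objective above is the standard hybrid CTC/attention interpolation:
    #     loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
    # e.g. with the default ctc_weight=0.5, loss_ctc=2.0 and loss_att=1.0 give
    # loss = 0.5 * 2.0 + 0.5 * 1.0 = 1.5. When interctc_weight > 0, loss_ctc is
    # first blended with the averaged intermediate-layer CTC losses the same way.
    # -----------------------------------------------------------------------------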

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att
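
    # --- editor's illustration (not part of the original file) ------------------
    # What add_sos_eos does for _calc_att_loss above (teacher-forcing targets),
    # sketched on one unpadded label sequence; LCBNet sets sos = eos = vocab_size - 1:
    #
    #   y      = [4, 5, 6]
    #   ys_in  = [sos, 4, 5, 6]   # decoder input, hence ys_in_lens = ys_pad_lens + 1
    #   ys_out = [4, 5, 6, eos]   # target scored by criterion_att
    # -----------------------------------------------------------------------------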

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
            ctc=kwargs.get("decoding_ctc_weight", 0.3),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 20),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )

        self.beam_search = beam_search

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beam search
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            audio_sample_list = sample_list[0]
            if len(sample_list) > 1:
                ocr_sample_list = sample_list[1]
            else:
                ocr_sample_list = [[294, 0]]
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = 10
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # shift OCR token ids by one so that id 0 stays reserved for <blank>
        ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
        ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
        ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
        ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
        fusion_out, _, _, _ = self.fusion_encoder(encoder_out, None, ocr, None)
        encoder_out = encoder_out + fusion_out
        # c. Pass the encoder result to beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
        )

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)

                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed

        return results, meta_data
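
# --- editor's usage sketch (not part of the original file) ----------------------
# A hedged example of driving LCBNet.inference() through funasr's AutoModel
# wrapper. The model id is a placeholder, and the paired-input layout assumes the
# per-sample [audio, text] convention that load_audio_text_image_video applies
# when data_type is a list, so sample_list carries both the audio and OCR streams:
#
#   from funasr import AutoModel
#   model = AutoModel(model="<lcbnet-model-id>")  # placeholder model id
#   res = model.generate(
#       input=[["asr_example.wav", "ocr words from the slides"]],
#       data_type=["sound", "text"],
#   )
#   print(res[0]["text"])
# ---------------------------------------------------------------------------------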

@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

import pdb

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
@ -128,7 +128,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")

        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
@ -209,20 +209,17 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                               nfilter=50,
                               seaco_weight=1.0):
        # decoder forward
        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
        if hw_list is not None:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_ = [torch.Tensor(i).long() for i in hw_list]
            hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
            selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))

            contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
            num_hot_word = contextual_info.shape[1]
            _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)

            # ASF Core
            if nfilter > 0 and nfilter < num_hot_word:
                hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
@ -242,7 +239,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
            cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
            dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
            merged = self._merge(cif_attended, dec_attended)

            dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
            dha_pred = torch.log_softmax(dha_output, dim=-1)
            def _merge_res(dec_output, dha_output):
@ -256,8 +253,8 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                # logits = dec_output * dha_mask + dha_output[:, :, :-1] * (1 - dha_mask)
                logits = dec_output * dha_mask + dha_output[:, :, :] * (1 - dha_mask)
                return logits

            merged_pred = _merge_res(decoder_pred, dha_pred)
            # import pdb; pdb.set_trace()
            return merged_pred
        else:
            return decoder_pred
@ -307,6 +304,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # extract fbank feats
@ -332,7 +330,6 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@ -341,14 +338,15 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
        if torch.max(pre_token_length) < 1:
            return []

        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                  pre_acoustic_embeds,
                                                  pre_token_length,
                                                  hw_list=self.hotword_list)

        # decoder_out, _ = decoder_outs[0], decoder_outs[1]
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):

@ -7,7 +7,7 @@ import logging
import torch
import torch.nn
import torch.optim
import pdb


def filter_state_dict(
    dst_state: Dict[str, Union[float, torch.Tensor]],
@ -63,7 +63,6 @@ def load_pretrained_model(
    dst_state = obj.state_dict()

    print(f"ckpt: {path}")

    if oss_bucket is None:
        src_state = torch.load(path, map_location=map_location)
    else:

@ -13,25 +13,29 @@ try:
    from funasr.download.file import download_from_url
except:
    print("urllib is not installed, if you infer from url, please install it first.")
import pdb


def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):

            data_types = [data_type] * len(data_or_path_or_list)
            data_or_path_or_list_ret = [[] for d in data_type]
            for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):

                for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):

                    data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    data_or_path_or_list_ret[j].append(data_or_path_or_list_j)

            return data_or_path_or_list_ret
        else:
            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@ -52,22 +56,10 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
    elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
        data_mat = kaldiio.load_mat(data_or_path_or_list)
        if isinstance(data_mat, tuple):
            audio_fs, mat = data_mat
        else:
            mat = data_mat
        if mat.dtype == 'int16' or mat.dtype == 'int32':
            mat = mat.astype(np.float64)
            mat = mat / 32768
        if mat.ndim == 2:
            mat = mat[:, 0]
        data_or_path_or_list = mat
    else:
        pass
        # print(f"unsupport data type: {data_or_path_or_list}, return raw data")

    if audio_fs != fs and data_type != "text":
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@ -89,6 +81,8 @@ def load_bytes(input):
    return array

def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
    # import pdb;
    # pdb.set_trace()
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
        if len(data.shape) < 2:
@ -106,7 +100,9 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None,
        data_list.append(data_i)
        data_len.append(data_i.shape[0])
    data = pad_sequence(data_list, batch_first=True)  # data: [batch, N]

    # import pdb;
    # pdb.set_trace()
    # if data_type == "sound":
    data, data_len = frontend(data, data_len, **kwargs)

    if isinstance(data_len, (list, tuple)):