mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1654)
* sensevoice finetune (squashed iterations)
* bugfix
* update with main (#1631): update seaco finetune; v1.0.24
* update with main (#1638): update seaco finetune; v1.0.24; update rwkv template
* sense voice (squashed iterations)
* whisper
* update style

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
This commit is contained in:
parent 7c3ba91f67
commit 861147c730
6  .pre-commit-config.yaml  (Normal file)
@@ -0,0 +1,6 @@
repos:
  - repo: https://github.com/psf/black
    rev: 24.4.0
    hooks:
      - id: black
        args: ['--line-length=100'] # example arguments; black defaults to 4-space indentation
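For reference, the formatting this hook applies can also be reproduced programmatically; a minimal sketch, assuming the black package is installed (format_str and Mode are black's public API):

import black

# Format a snippet with the same 100-character line length the hook configures.
src = "x = {'a':1,'b':2}"
print(black.format_str(src, mode=black.Mode(line_length=100)))
# x = {"a": 1, "b": 2}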
18  docs/conf.py
@@ -17,9 +17,9 @@

# -- Project information -----------------------------------------------------

project = 'FunASR'
copyright = '2022, Speech Lab, Alibaba Group'
author = 'Speech Lab, Alibaba Grou'
project = "FunASR"
copyright = "2022, Speech Lab, Alibaba Group"
author = "Speech Lab, Alibaba Group"


# -- General configuration ---------------------------------------------------
@@ -30,18 +30,18 @@ author = 'Speech Lab, Alibaba Grou'
extensions = [
    "nbsphinx",
    "sphinx.ext.autodoc",
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.mathjax",
    "sphinx.ext.todo",
    # "sphinxarg.ext",
    "sphinx_markdown_tables",
    'recommonmark',
    'sphinx_rtd_theme',
    "recommonmark",
    "sphinx_rtd_theme",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

source_suffix = [".rst", ".md"]

@@ -64,4 +64,4 @@ html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
# html_static_path = ['_static']

@@ -6,23 +6,23 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = 'MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0'
copyright = '2023, Speech Lab, Alibaba Group; ASLP Group, Northwestern Polytechnical University'
author = 'Speech Lab, Alibaba Group; Audio, Speech and Language Processing Group, Northwestern Polytechnical University'
project = "MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0"
copyright = "2023, Speech Lab, Alibaba Group; ASLP Group, Northwestern Polytechnical University"
author = "Speech Lab, Alibaba Group; Audio, Speech and Language Processing Group, Northwestern Polytechnical University"


extensions = [
    "nbsphinx",
    "sphinx.ext.autodoc",
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.mathjax",
    "sphinx.ext.todo",
    # "sphinxarg.ext",
    "sphinx_markdown_tables",
    # 'recommonmark',
    'sphinx_rtd_theme',
    'myst_parser',
    "sphinx_rtd_theme",
    "myst_parser",
]

myst_enable_extensions = [
@@ -33,13 +33,12 @@ myst_enable_extensions = [
]

myst_heading_anchors = 2
myst_highlight_code_blocks=True
myst_update_mathjax=False
myst_highlight_code_blocks = True
myst_update_mathjax = False

templates_path = ['_templates']
templates_path = ["_templates"]
source_suffix = [".rst", ".md"]

pygments_style = "sphinx"

html_theme = "sphinx_rtd_theme"


@@ -2,66 +2,98 @@ import os
import numpy as np
import sys

def compute_wer(ref_file,
                hyp_file,
                cer_detail_file):

def compute_wer(ref_file, hyp_file, cer_detail_file):
    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
        "Wrd": 0,
        "Corr": 0,
        "Ins": 0,
        "Del": 0,
        "Sub": 0,
        "Snt": 0,
        "Err": 0.0,
        "S.Err": 0.0,
        "wrong_words": 0,
        "wrong_sentences": 0,
    }

    hyp_dict = {}
    ref_dict = {}
    with open(hyp_file, 'r') as hyp_reader:
    with open(hyp_file, "r") as hyp_reader:
        for line in hyp_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            hyp_dict[key] = value
    with open(ref_file, 'r') as ref_reader:
    with open(ref_file, "r") as ref_reader:
        for line in ref_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            ref_dict[key] = value

    cer_detail_writer = open(cer_detail_file, 'w')
    cer_detail_writer = open(cer_detail_file, "w")
    for hyp_key in hyp_dict:
        if hyp_key in ref_dict:
            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
            rst['Wrd'] += out_item['nwords']
            rst['Corr'] += out_item['cor']
            rst['wrong_words'] += out_item['wrong']
            rst['Ins'] += out_item['ins']
            rst['Del'] += out_item['del']
            rst['Sub'] += out_item['sub']
            rst['Snt'] += 1
            if out_item['wrong'] > 0:
                rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
            cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
            cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
            rst["Wrd"] += out_item["nwords"]
            rst["Corr"] += out_item["cor"]
            rst["wrong_words"] += out_item["wrong"]
            rst["Ins"] += out_item["ins"]
            rst["Del"] += out_item["del"]
            rst["Sub"] += out_item["sub"]
            rst["Snt"] += 1
            if out_item["wrong"] > 0:
                rst["wrong_sentences"] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
            cer_detail_writer.write(
                "ref:" + "\t" + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + "\n"
            )
            cer_detail_writer.write(
                "hyp:" + "\t" + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + "\n"
            )

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
    if rst["Wrd"] > 0:
        rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
    if rst["Snt"] > 0:
        rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)

    cer_detail_writer.write('\n')
    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words']) + " / " + str(rst['Wrd']) +
                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
    cer_detail_writer.write("\n")
    cer_detail_writer.write(
        "%WER "
        + str(rst["Err"])
        + " [ "
        + str(rst["wrong_words"])
        + " / "
        + str(rst["Wrd"])
        + ", "
        + str(rst["Ins"])
        + " ins, "
        + str(rst["Del"])
        + " del, "
        + str(rst["Sub"])
        + " sub ]"
        + "\n"
    )
    cer_detail_writer.write(
        "%SER "
        + str(rst["S.Err"])
        + " [ "
        + str(rst["wrong_sentences"])
        + " / "
        + str(rst["Snt"])
        + " ]"
        + "\n"
    )
    cer_detail_writer.write(
        "Scored "
        + str(len(hyp_dict))
        + " sentences, "
        + str(len(hyp_dict) - rst["Snt"])
        + " not present in hyp."
        + "\n"
    )


def compute_wer_by_line(hyp,
                        ref):

def compute_wer_by_line(hyp, ref):
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))

@@ -96,14 +128,7 @@ def compute_wer_by_line(hyp,
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)
@@ -111,42 +136,57 @@ def compute_wer_by_line(hyp,
        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1
                rst["cor"] += 1

            i -= 1
            j -= 1

        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1
            rst["ins"] += 1

        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1
            rst["del"] += 1

        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1
            rst["sub"] += 1

        if i < 0 and j >= 0:
            rst['del'] += 1
            rst["del"] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1
            rst["ins"] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt
    rst["wrong"] = wrong_cnt

    return rst

def print_cer_detail(rst):
    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
            + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))

if __name__ == '__main__':

def print_cer_detail(rst):
    return (
        "("
        + "nwords="
        + str(rst["nwords"])
        + ",cor="
        + str(rst["cor"])
        + ",ins="
        + str(rst["ins"])
        + ",del="
        + str(rst["del"])
        + ",sub="
        + str(rst["sub"])
        + ") corr:"
        + "{:.2%}".format(rst["cor"] / rst["nwords"])
        + ",cer:"
        + "{:.2%}".format(rst["wrong"] / rst["nwords"])
    )


if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("usage : python compute-wer.py test.ref test.hyp test.wer")
        sys.exit(0)

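For context, a minimal usage sketch of the scoring script above. The module name compute_wer and the file names are this editor's assumptions; the Kaldi-style "uttid token1 token2 ..." input format is inferred from the parsing code in the diff:

# ref.txt / hyp.txt: one utterance per line, "uttid token1 token2 ..."
with open("ref.txt", "w") as f:
    f.write("utt1 hello world\n")
with open("hyp.txt", "w") as f:
    f.write("utt1 hello word\n")

# Assuming the script above is saved as compute_wer.py:
from compute_wer import compute_wer

# Writes per-utterance alignment details plus a %WER / %SER summary.
compute_wer("ref.txt", "hyp.txt", "detail.wer")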
@@ -5,6 +5,7 @@ import os
import torch
from kaldiio import WriteHelper
import re

text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
@@ -16,17 +17,17 @@ model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)

with open(text_file_json, 'r') as f:
with open(text_file_json, "r") as f:
    js = f.readlines()


f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
with WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer:
    with torch.no_grad():
        for idx, line in enumerate(js):
            id, tokens = line.strip().split(" ", 1)
            tokens = re.sub(" ", "", tokens.strip())
            tokens = ' '.join([j for j in tokens])
            tokens = " ".join([j for j in tokens])
            token_num = len(tokens.split(" "))
            outputs = extractor(tokens)
            outputs = np.array(outputs)
@@ -38,10 +39,11 @@ with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
                shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
                f_shape.write(shape_line)
            else:
                print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))

                print(
                    "{}, size has changed, {}, {}, {}".format(
                        id, token_num, token_num_embeds, tokens
                    )
                )


f_shape.close()

@@ -1,4 +1,3 @@

import sys
import re

@@ -7,25 +6,25 @@ out_f = sys.argv[2]


with open(in_f, "r", encoding="utf-8") as f:
        lines = f.readlines()
    lines = f.readlines()

with open(out_f, "w", encoding="utf-8") as f:
        for line in lines:
            outs = line.strip().split(" ", 1)
            if len(outs) == 2:
                idx, text = outs
                text = re.sub("</s>", "", text)
                text = re.sub("<s>", "", text)
                text = re.sub("@@", "", text)
                text = re.sub("@", "", text)
                text = re.sub("<unk>", "", text)
                text = re.sub(" ", "", text)
                text = text.lower()
            else:
                idx = outs[0]
                text = " "
    for line in lines:
        outs = line.strip().split(" ", 1)
        if len(outs) == 2:
            idx, text = outs
            text = re.sub("</s>", "", text)
            text = re.sub("<s>", "", text)
            text = re.sub("@@", "", text)
            text = re.sub("@", "", text)
            text = re.sub("<unk>", "", text)
            text = re.sub(" ", "", text)
            text = text.lower()
        else:
            idx = outs[0]
            text = " "

            text = [x for x in text]
            text = " ".join(text)
            out = "{} {}\n".format(idx, text)
            f.write(out)
        text = [x for x in text]
        text = " ".join(text)
        out = "{} {}\n".format(idx, text)
        f.write(out)

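As a hand-worked sanity check of the cleanup above (not output produced by this commit), here is what one hypothesis line becomes after the regex passes and character re-splitting:

import re

line = "utt1 HEL@@ LO <unk> WOR@@ LD</s>"
idx, text = line.strip().split(" ", 1)
for pat in ["</s>", "<s>", "@@", "@", "<unk>", " "]:
    text = re.sub(pat, "", text)       # strip BPE markers and special tokens
text = " ".join(list(text.lower()))    # re-split into single characters
print("{} {}".format(idx, text))       # utt1 h e l l o w o r l d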
@@ -37,9 +37,7 @@ def get_parser():
        help="number of characters to split, i.e., \
             aabb -> a a b b with -n 1 and aa bb with -n 2",
    )
    parser.add_argument(
        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
    )
    parser.add_argument("--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
    parser.add_argument(
        "--non-lang-syms",
@@ -80,9 +78,7 @@ def main():
    else:
        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)

    sys.stdout = codecs.getwriter("utf-8")(
        sys.stdout if is_python2 else sys.stdout.buffer
    )
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
    line = f.readline()
    n = args.nchar
    while line:

@@ -4,7 +4,7 @@ import argparse

def load_dict(seg_file):
    seg_dict = {}
    with open(seg_file, 'r') as infile:
    with open(seg_file, "r") as infile:
        for line in infile:
            s = line.strip().split()
            key = s[0]
@@ -28,8 +28,7 @@ def forward_segment(text, dic):
    return word_list


def tokenize(txt,
             seg_dict):
def tokenize(txt, seg_dict):
    out_txt = ""
    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    for word in txt:
@@ -87,20 +86,19 @@ def main():
    parser = get_parser()
    args = parser.parse_args()

    txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
    shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
    txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), "w")
    shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), "w")
    seg_dict = load_dict(args.seg_file)
    with open(args.text_file, 'r') as infile:
    with open(args.text_file, "r") as infile:
        for line in infile:
            s = line.strip().split()
            text_id = s[0]
            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
            text = tokenize(text_list, seg_dict)
            lens = len(text.strip().split())
            txt_writer.write(text_id + " " + text + '\n')
            shape_writer.write(text_id + " " + str(lens) + '\n')
            txt_writer.write(text_id + " " + text + "\n")
            shape_writer.write(text_id + " " + str(lens) + "\n")


if __name__ == '__main__':
if __name__ == "__main__":
    main()

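The forward_segment call above performs forward maximum matching over the segmentation dictionary. A minimal self-contained illustration of the idea with a toy dictionary (not the project's seg_file, and a re-implementation for illustration rather than the project's exact code):

def forward_segment(text, dic):
    # Greedy longest-match scan from left to right.
    word_list = []
    i = 0
    while i < len(text):
        longest = text[i]              # fall back to a single character
        for j in range(i + 1, len(text) + 1):
            if text[i:j] in dic:
                longest = text[i:j]    # keep the longest dictionary hit
        word_list.append(longest)
        i += len(longest)
    return word_list

print(forward_segment("研究生命科学", {"研究", "研究生", "生命", "科学"}))
# ['研究生', '命', '科学']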
@@ -14,50 +14,59 @@ import sys, os, argparse, codecs, string, re
# ================================================================================ #
# basic constant
# ================================================================================ #
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
CHINESE_DIGIS = "零一二三四五六七八九"
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"

ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]

POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']

FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
               '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
               '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
               '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
FILLER_CHARS = ["呃", "啊"]
ER_WHITELIST = (
    "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
    "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
    "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
    "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
)

# types of Chinese numeral systems
NUMBERING_TYPES = ['low', 'mid', 'high']
NUMBERING_TYPES = ["low", "mid", "high"]

CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
                 '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
                  '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
CURRENCY_NAMES = (
    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
CURRENCY_UNITS = (
    "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
)
COM_QUANTIFIERS = (
    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
)

# punctuation information is based on the Zhon project (https://github.com/tsroten/zhon.git)
CHINESE_PUNC_STOP = '!?。。'
CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
CHINESE_PUNC_STOP = "!?。。"
CHINESE_PUNC_NON_STOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP


# ================================================================================ #
# basic class
# ================================================================================ #
@@ -72,7 +81,7 @@ class ChineseChar(object):
    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        #self.__repr__ = self.__str__
        # self.__repr__ = self.__str__

    def __str__(self):
        return self.simplified or self.traditional or None
@@ -95,26 +104,49 @@ class ChineseNumberUnit(ChineseChar):
        self.big_t = big_t

    def __str__(self):
        return '10^{}'.format(self.power)
        return "10^{}".format(self.power)

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):

        if small_unit:
            return ChineseNumberUnit(power=index + 1,
                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
            return ChineseNumberUnit(
                power=index + 1,
                simplified=value[0],
                traditional=value[1],
                big_s=value[1],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[0]:
            return ChineseNumberUnit(power=index + 8,
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
            return ChineseNumberUnit(
                power=index + 8,
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[1]:
            return ChineseNumberUnit(power=(index + 2) * 4,
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
            return ChineseNumberUnit(
                power=(index + 2) * 4,
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[2]:
            return ChineseNumberUnit(power=pow(2, index + 3),
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
            return ChineseNumberUnit(
                power=pow(2, index + 3),
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        else:
            raise ValueError(
                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
                "Counting type should be in {0} ({1} provided).".format(
                    NUMBERING_TYPES, numbering_type
                )
            )


class ChineseNumberDigit(ChineseChar):
@@ -158,6 +190,7 @@ class NumberSystem(object):
    """
    Chinese numeral system
    """

    pass


@@ -207,27 +240,27 @@ def create_system(numbering_type=NUMBERING_TYPES[1]):

    # chinese number units of '亿' and larger
    all_larger_units = zip(
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
    larger_units = [CNU.create(i, v, numbering_type, False)
                    for i, v in enumerate(all_larger_units)]
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
    # chinese number units of '十, 百, 千, 万'
    all_smaller_units = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
    smaller_units = [CNU.create(i, v, small_unit=True)
                     for i, v in enumerate(all_smaller_units)]
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
    # digits
    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
                        BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
    chinese_digis = zip(
        CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
    )
    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]

    # symbols
    positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], '.', lambda x,
                  y: float(str(x) + '.' + str(y)))
    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
    system = NumberSystem()
    system.units = smaller_units + larger_units
@@ -251,13 +284,14 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
        return m

    def string2symbols(chinese_string, system):
        int_string, dec_string = chinese_string, ''
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], \
               [get_symbol(c, system) for c in dec_string]
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]

    def correct_symbols(integer_symbols, system):
        """
@@ -271,8 +305,7 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):

        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None))
                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))

        result = []
        unit_count = 0
@@ -288,9 +321,13 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
                result.append(current_unit)
            elif unit_count > 1:
                for i in range(len(result)):
                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
                        result[-i - 1] = CNU(result[-i - 1].power +
                                             current_unit.power, None, None, None, None)
                    if (
                        isinstance(result[-i - 1], CNU)
                        and result[-i - 1].power < current_unit.power
                    ):
                        result[-i - 1] = CNU(
                            result[-i - 1].power + current_unit.power, None, None, None, None
                        )
        return result

    def compute_value(integer_symbols):
@@ -307,8 +344,7 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    value[:-1] = list(map(lambda v: v *
                                          pow(10, s.power), value[:-1]))
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                    last_power = s.power
                value.append(0)
        return sum(value)
@@ -317,20 +353,28 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    dec_str = ''.join([str(d.value) for d in dec_part])
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return '{0}.{1}'.format(int_str, dec_str)
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str


def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
            use_zeros=True, use_units=True):
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):

    def get_value(value_string, use_zeros=True):

        striped_string = value_string.lstrip('0')
        striped_string = value_string.lstrip("0")

        # record nothing if all zeros
        if not striped_string:
@@ -345,14 +389,17 @@ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,

        # recursively record multiple digits
        else:
            result_unit = next(u for u in reversed(
                system.units) if u.power < len(striped_string))
            result_string = value_string[:-result_unit.power]
            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )

    system = create_system(numbering_type)

    int_dec = number_string.split('.')
    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
@@ -361,7 +408,8 @@ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string))
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
@@ -372,51 +420,62 @@ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
                    system.digits[2].big_s, system.digits[2].big_t)
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i +
                                             1] if i < len(result_symbols) - 1 else None
                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = 'big_'
        attr_name = "big_"
        if traditional:
            attr_name += 't'
            attr_name += "t"
        else:
            attr_name += 's'
            attr_name += "s"
    else:
        if traditional:
            attr_name = 'traditional'
            attr_name = "traditional"
        else:
            attr_name = 'simplified'
            attr_name = "simplified"

    result = ''.join([getattr(s, attr_name) for s in result_symbols])
    result = "".join([getattr(s, attr_name) for s in result_symbols])

    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))

    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s)
        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)

    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s)
        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)

    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result

    # ^10, 11, .., 19
    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
    ):
        result = result[1:]

    return result
@@ -440,6 +499,7 @@ class Cardinal:
    def cardinal2chntext(self):
        return num2chn(self.cardinal)


class Digit:
    """
    DIGIT class
@@ -476,17 +536,17 @@ class TelePhone:
    def telephone2chntext(self, fixed=False):

        if fixed:
            sil_parts = self.telephone.split('-')
            self.raw_chntext = '<SIL>'.join([
                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
            ])
            self.chntext = self.raw_chntext.replace('<SIL>', '')
            sil_parts = self.telephone.split("-")
            self.raw_chntext = "<SIL>".join(
                [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
            )
            self.chntext = self.raw_chntext.replace("<SIL>", "")
        else:
            sp_parts = self.telephone.strip('+').split()
            self.raw_chntext = '<SP>'.join([
                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
            ])
            self.chntext = self.raw_chntext.replace('<SP>', '')
            sp_parts = self.telephone.strip("+").split()
            self.raw_chntext = "<SP>".join(
                [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
            )
            self.chntext = self.raw_chntext.replace("<SP>", "")
        return self.chntext


@@ -500,12 +560,12 @@ class Fraction:
        self.chntext = chntext

    def chntext2fraction(self):
        denominator, numerator = self.chntext.split('分之')
        return chn2num(numerator) + '/' + chn2num(denominator)
        denominator, numerator = self.chntext.split("分之")
        return chn2num(numerator) + "/" + chn2num(denominator)

    def fraction2chntext(self):
        numerator, denominator = self.fraction.split('/')
        return num2chn(denominator) + '分之' + num2chn(numerator)
        numerator, denominator = self.fraction.split("/")
        return num2chn(denominator) + "分之" + num2chn(numerator)


class Date:
@@ -544,23 +604,23 @@ class Date:
    def date2chntext(self):
        date = self.date
        try:
            year, other = date.strip().split('年', 1)
            year = Digit(digit=year).digit2chntext() + '年'
            year, other = date.strip().split("年", 1)
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            other = date
            year = ''
            year = ""
        if other:
            try:
                month, day = other.strip().split('月', 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
                month, day = other.strip().split("月", 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                day = date
                month = ''
                month = ""
            if day:
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ''
            day = ''
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
@@ -580,7 +640,7 @@ class Money:

    def money2chntext(self):
        money = self.money
        pattern = re.compile(r'(\d+(\.\d+)?)')
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(money)
        if matchers:
            for matcher in matchers:
@@ -599,10 +659,10 @@ class Percentage:
        self.chntext = chntext

    def chntext2percentage(self):
        return chn2num(self.chntext.strip().strip('百分之')) + '%'
        return chn2num(self.chntext.strip().strip("百分之")) + "%"

    def percentage2chntext(self):
        return '百分之' + num2chn(self.percentage.strip().strip('%'))
        return "百分之" + num2chn(self.percentage.strip().strip("%"))


def remove_erhua(text, er_whitelist):
@@ -612,9 +672,9 @@ def remove_erhua(text, er_whitelist):
    """

    er_pattern = re.compile(er_whitelist)
    new_str=''
    while re.search('儿',text):
        a = re.search('儿',text).span()
    new_str = ""
    while re.search("儿", text):
        a = re.search("儿", text).span()
        remove_er_flag = 0

        if er_pattern.search(text):
@@ -622,23 +682,24 @@ def remove_erhua(text, er_whitelist):
            if b[0] <= a[0]:
                remove_er_flag = 1

        if remove_er_flag == 0 :
            new_str = new_str + text[0:a[0]]
            text = text[a[1]:]
        if remove_er_flag == 0:
            new_str = new_str + text[0 : a[0]]
            text = text[a[1] :]
        else:
            new_str = new_str + text[0:b[1]]
            text = text[b[1]:]
            new_str = new_str + text[0 : b[1]]
            text = text[b[1] :]

    text = new_str + text
    return text


# ================================================================================ #
# NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
    def __init__(self, raw_text):
        self.raw_text = '^' + raw_text + '$'
        self.norm_text = ''
        self.raw_text = "^" + raw_text + "$"
        self.norm_text = ""

    def _particular(self):
        text = self.norm_text
@@ -647,7 +708,7 @@ class NSWNormalizer:
        if matchers:
            # print('particular')
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
        self.norm_text = text
        return self.norm_text

@@ -658,15 +719,17 @@ class NSWNormalizer:
        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('date')
            # print('date')
            for matcher in matchers:
                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)

        # normalize money
        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
        pattern = re.compile(
            r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
        )
        matchers = pattern.findall(text)
        if matchers:
            #print('money')
            # print('money')
            for matcher in matchers:
                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)

@@ -679,39 +742,45 @@ class NSWNormalizer:
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        matchers = pattern.findall(text)
        if matchers:
            #print('telephone')
            # print('telephone')
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
                )
        # landline numbers
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('fixed telephone')
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
                )

        # normalize fractions
        pattern = re.compile(r"(\d+/\d+)")
        matchers = pattern.findall(text)
        if matchers:
            #print('fraction')
            # print('fraction')
            for matcher in matchers:
                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)

        # normalize percentages
        text = text.replace('%', '%')
        text = text.replace("%", "%")
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        matchers = pattern.findall(text)
        if matchers:
            #print('percentage')
            # print('percentage')
            for matcher in matchers:
                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
                text = text.replace(
                    matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
                )

        # normalize cardinal + quantifier
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        matchers = pattern.findall(text)
        if matchers:
            #print('cardinal+quantifier')
            # print('cardinal+quantifier')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)

@@ -719,7 +788,7 @@ class NSWNormalizer:
        pattern = re.compile(r"(\d{4,32})")
        matchers = pattern.findall(text)
        if matchers:
            #print('digit')
            # print('digit')
            for matcher in matchers:
                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)

@@ -727,74 +796,82 @@ class NSWNormalizer:
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('cardinal')
            # print('cardinal')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)

        self.norm_text = text
        self._particular()

        return self.norm_text.lstrip('^').rstrip('$')
        return self.norm_text.lstrip("^").rstrip("$")


|
||||
def nsw_test_case(raw_text):
|
||||
print('I:' + raw_text)
|
||||
print('O:' + NSWNormalizer(raw_text).normalize())
|
||||
print('')
|
||||
print("I:" + raw_text)
|
||||
print("O:" + NSWNormalizer(raw_text).normalize())
|
||||
print("")
|
||||
|
||||
|
||||
def nsw_test():
|
||||
nsw_test_case('固话:0595-23865596或23880880。')
|
||||
nsw_test_case('固话:0595-23865596或23880880。')
|
||||
nsw_test_case('手机:+86 19859213959或15659451527。')
|
||||
nsw_test_case('分数:32477/76391。')
|
||||
nsw_test_case('百分数:80.03%。')
|
||||
nsw_test_case('编号:31520181154418。')
|
||||
nsw_test_case('纯数:2983.07克或12345.60米。')
|
||||
nsw_test_case('日期:1999年2月20日或09年3月15号。')
|
||||
nsw_test_case('金钱:12块5,34.5元,20.1万')
|
||||
nsw_test_case('特殊:O2O或B2C。')
|
||||
nsw_test_case('3456万吨')
|
||||
nsw_test_case('2938个')
|
||||
nsw_test_case('938')
|
||||
nsw_test_case('今天吃了115个小笼包231个馒头')
|
||||
nsw_test_case('有62%的概率')
|
||||
nsw_test_case("固话:0595-23865596或23880880。")
|
||||
nsw_test_case("固话:0595-23865596或23880880。")
|
||||
nsw_test_case("手机:+86 19859213959或15659451527。")
|
||||
nsw_test_case("分数:32477/76391。")
|
||||
nsw_test_case("百分数:80.03%。")
|
||||
nsw_test_case("编号:31520181154418。")
|
||||
nsw_test_case("纯数:2983.07克或12345.60米。")
|
||||
nsw_test_case("日期:1999年2月20日或09年3月15号。")
|
||||
nsw_test_case("金钱:12块5,34.5元,20.1万")
|
||||
nsw_test_case("特殊:O2O或B2C。")
|
||||
nsw_test_case("3456万吨")
|
||||
nsw_test_case("2938个")
|
||||
nsw_test_case("938")
|
||||
nsw_test_case("今天吃了115个小笼包231个馒头")
|
||||
nsw_test_case("有62%的概率")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    #nsw_test()
if __name__ == "__main__":
    # nsw_test()

    p = argparse.ArgumentParser()
    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
    p.add_argument('ofile', help='output filename')
    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
    p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
    p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
    p.add_argument("ofile", help="output filename")
    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
    p.add_argument(
        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
    )
    p.add_argument(
        "--remove_fillers", type=bool, default=True, help='remove filler chars such as "呃, 啊"'
    )
    p.add_argument(
        "--remove_erhua", type=bool, default=True, help='remove erhua chars such as "这儿"'
    )
    p.add_argument(
        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
    )
    args = p.parse_args()

    ifile = codecs.open(args.ifile, 'r', 'utf8')
    ofile = codecs.open(args.ofile, 'w+', 'utf8')
    ifile = codecs.open(args.ifile, "r", "utf8")
    ofile = codecs.open(args.ofile, "w+", "utf8")

    n = 0
    for l in ifile:
        key = ''
        text = ''
        key = ""
        text = ""
        if args.has_key:
            cols = l.split(maxsplit=1)
            key = cols[0]
            if len(cols) == 2:
                text = cols[1].strip()
            else:
                text = ''
                text = ""
        else:
            text = l.strip()

        # cases
        if args.to_upper and args.to_lower:
            sys.stderr.write('text norm: to_upper OR to_lower?')
            sys.stderr.write("text norm: to_upper OR to_lower?")
            exit(1)
        if args.to_upper:
            text = text.upper()
@@ -804,7 +881,7 @@ if __name__ == '__main__':
        # Filler chars removal
        if args.remove_fillers:
            for ch in FILLER_CHARS:
                text = text.replace(ch, '')
                text = text.replace(ch, "")

        if args.remove_erhua:
            text = remove_erhua(text, ER_WHITELIST)
@@ -813,16 +890,16 @@ if __name__ == '__main__':
        text = NSWNormalizer(text).normalize()

        # Punctuations removal
        old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
        new_chars = ' ' * len(old_chars)
        del_chars = ''
        old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
        new_chars = " " * len(old_chars)
        del_chars = ""
        text = text.translate(str.maketrans(old_chars, new_chars, del_chars))

        #
        if args.has_key:
            ofile.write(key + '\t' + text + '\n')
            ofile.write(key + "\t" + text + "\n")
        else:
            ofile.write(text + '\n')
            ofile.write(text + "\n")

        n += 1
        if n % args.log_interval == 0:

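Taken together, the normalizer above can be exercised directly. A minimal usage sketch; the module name cn_tn is this editor's assumption, not stated by the commit:

from cn_tn import NSWNormalizer, chn2num

# Chinese number words back to digits.
print(chn2num("两百三十五"))                      # -> 235

# Non-standard words (here a percentage) rewritten as Chinese words.
print(NSWNormalizer("有62%的概率").normalize())   # -> 有百分之六十二的概率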
@@ -16,4 +16,4 @@ model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch")
for wav_id in multilingual_wavs:
    wav_file = f"{model.model_path}/examples/{wav_id}"
    res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
    print("detect sample {}: {}".format(wav_id, res))
    print("detect sample {}: {}".format(wav_id, res))

@@ -6,7 +6,7 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

multilingual_wavs=[
multilingual_wavs = [
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
@@ -14,9 +14,9 @@ multilingual_wavs=[
]

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='iic/speech_whisper-large_lid_multilingual_pytorch')
    task=Tasks.auto_speech_recognition, model="iic/speech_whisper-large_lid_multilingual_pytorch"
)

for wav in multilingual_wavs:
    rec_result = inference_pipeline(input=wav, inference_clip_length=250)
    print(rec_result)
    print(rec_result)

@@ -5,11 +5,16 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                  )
model = AutoModel(
    model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
    # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav",
    batch_size_s=300,
    batch_size_threshold_s=60,
)
print(res)

@@ -7,7 +7,10 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", device="cpu")
model = AutoModel(
    model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    device="cpu",
)

res = model.export(type="onnx", quantize=False)
print(res)
@@ -16,7 +19,10 @@ print(res)
# method2, inference from local path
from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", device="cpu")
model = AutoModel(
    model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    device="cpu",
)

res = model.export(type="onnx", quantize=False)
print(res)
print(res)

@@ -5,8 +5,9 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common"
                  )
model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common")

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)

@@ -7,6 +7,7 @@ from funasr import AutoModel

model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch")

res = model.generate(input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav")
res = model.generate(
    input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
)
print(res)


@@ -7,6 +7,8 @@ from funasr import AutoModel

model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404")

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                     hotword='达摩院 魔搭')
print(res)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword="达摩院 魔搭",
)
print(res)

@@ -7,7 +7,9 @@ from funasr import AutoModel

model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
)
print(res)


@@ -15,5 +17,7 @@ from funasr import AutoModel

model = AutoModel(model="iic/punc_ct-transformer_cn-en-common-vocab471067-large")

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
)
print(res)

@@ -7,8 +7,9 @@

from funasr import AutoModel

model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  )
model = AutoModel(
    model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
)

res = model.export(type="onnx", quantize=False)
print(res)
@@ -17,7 +18,9 @@ print(res)
# method2, inference from local path
from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
model = AutoModel(
    model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
)

res = model.export(type="onnx", quantize=False)
print(res)
print(res)

@@ -13,6 +13,6 @@ rec_result_all = "outputs: "
cache = {}
for vad in vads:
    rec_result = model.generate(input=vad, cache=cache)
    rec_result_all += rec_result[0]['text']
    rec_result_all += rec_result[0]["text"]

print(rec_result_all)

@@ -7,8 +7,9 @@

from funasr import AutoModel

model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                  )
model = AutoModel(
    model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
)

res = model.export(type="onnx", quantize=False)
print(res)
@@ -17,7 +18,9 @@ print(res)
# method2, inference from local path
from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
model = AutoModel(
    model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
)

res = model.export(type="onnx", quantize=False)
print(res)
print(res)

@ -6,12 +6,15 @@
|
||||
from funasr import AutoModel
|
||||
|
||||
# model="iic/emotion2vec_base"
|
||||
model = AutoModel(model="iic/emotion2vec_base_finetuned",
|
||||
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# vad_model_revision="master",
|
||||
# vad_kwargs={"max_single_segment_time": 2000},
|
||||
)
|
||||
model = AutoModel(
|
||||
model="iic/emotion2vec_base_finetuned",
|
||||
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# vad_model_revision="master",
|
||||
# vad_kwargs={"max_single_segment_time": 2000},
|
||||
)
|
||||
|
||||
wav_file = f"{model.model_path}/example/test.wav"
|
||||
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
|
||||
print(res)
|
||||
res = model.generate(
|
||||
wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False
|
||||
)
|
||||
print(res)
|
||||
|
||||
@@ -4,6 +4,7 @@
# MIT License (https://opensource.org/licenses/MIT)

from funasr import AutoModel

wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"

model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
@@ -14,28 +15,28 @@ print(res)
# beg/end: ms



import soundfile
import os

wav_file = os.path.join(model.model_path, "example/vad_example.wav")
speech, sample_rate = soundfile.read(wav_file)

chunk_size = 200 # ms
chunk_size = 200  # ms
chunk_stride = int(chunk_size * sample_rate / 1000)

cache = {}

total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
    speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
    speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        disable_pbar=True,
    )
    res = model.generate(
        input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        disable_pbar=True,
    )
    # print(res)
    if len(res[0]["value"]):
        print(res)
@@ -44,4 +45,4 @@ for i in range(total_chunk_num):
# 1. [[beg1, end1], [beg2, end2], .., [begN, endN]]; [[beg, end]]; [[beg1, end1], [beg2, end2]]
# 2. [[beg, -1]]
# 3. [[-1, end]]
# beg/end: ms
# beg/end: ms

@@ -17,7 +17,9 @@ print(res)

from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
model = AutoModel(
    model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
)

res = model.export(type="onnx", quantize=False)
print(res)

@@ -9,6 +9,7 @@ import argparse
from tqdm import tqdm
import os
import pdb

remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
@@ -51,9 +52,9 @@ class WordError(object):
    def get_wer(self):
        assert self.ref_words != 0
        errors = (
            self.errors[Code.substitution]
            + self.errors[Code.insertion]
            + self.errors[Code.deletion]
            self.errors[Code.substitution]
            + self.errors[Code.insertion]
            + self.errors[Code.deletion]
        )
        return 100.0 * errors / self.ref_words

@@ -299,30 +300,30 @@ def default_cluster(word):
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith("DIGIT"): # 1
            unicode_names[i] = "Number" # 'DIGIT'
        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
            i
        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[i].startswith(
            "CJK COMPATIBILITY IDEOGRAPH"
        ):
            # 明 / 郎
            unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
            i
        ].startswith("LATIN SMALL LETTER"):
        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[i].startswith(
            "LATIN SMALL LETTER"
        ):
            # A / a
            unicode_names[i] = "English" # 'LATIN LETTER'
        elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め
            unicode_names[i] = "Japanese" # 'GANA LETTER'
        elif (
            unicode_names[i].startswith("AMPERSAND")
            or unicode_names[i].startswith("APOSTROPHE")
            or unicode_names[i].startswith("COMMERCIAL AT")
            or unicode_names[i].startswith("DEGREE CELSIUS")
            or unicode_names[i].startswith("EQUALS SIGN")
            or unicode_names[i].startswith("FULL STOP")
            or unicode_names[i].startswith("HYPHEN-MINUS")
            or unicode_names[i].startswith("LOW LINE")
            or unicode_names[i].startswith("NUMBER SIGN")
            or unicode_names[i].startswith("PLUS SIGN")
            or unicode_names[i].startswith("SEMICOLON")
            unicode_names[i].startswith("AMPERSAND")
            or unicode_names[i].startswith("APOSTROPHE")
            or unicode_names[i].startswith("COMMERCIAL AT")
            or unicode_names[i].startswith("DEGREE CELSIUS")
            or unicode_names[i].startswith("EQUALS SIGN")
            or unicode_names[i].startswith("FULL STOP")
            or unicode_names[i].startswith("HYPHEN-MINUS")
            or unicode_names[i].startswith("LOW LINE")
            or unicode_names[i].startswith("NUMBER SIGN")
            or unicode_names[i].startswith("PLUS SIGN")
            or unicode_names[i].startswith("SEMICOLON")
        ):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
@@ -411,11 +412,13 @@ def main(args):
            if len(array) == 0:
                continue
            fid = array[0]
            rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
            rec_sets[rec_names[i]][fid] = normalize(
                array[1:], ignore_words, case_sensitive, split
            )

        calculators_dict[rec_names[i]] = Calculator()
        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
        hotwords_related_dict[rec_names[i]] = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
        # tp: the hotword is in the label and also in the rec
        # tn: the hotword is neither in the label nor in the rec
        # fp: the hotword is not in the label but appears in the rec
@@ -431,21 +434,22 @@ def main(args):
    _file_total_len = int(pipe.read().strip())

    # compute error rate on the interaction of reference file and hyp file
    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
    for line in tqdm(open(ref_file, "r", encoding="utf-8"), total=_file_total_len):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0: continue
            array = line.rstrip("\n").split()
        if len(array) == 0:
            continue
        fid = array[0]
        lab = normalize(array[1:], ignore_words, case_sensitive, split)

        if verbose:
            print('\nutt: %s' % fid)
            print("\nutt: %s" % fid)

        ocr_text = ref_ocr_dict[fid]
        ocr_set = set(ocr_text)
        print('ocr: {}'.format(" ".join(ocr_text)))
        print("ocr: {}".format(" ".join(ocr_text)))
        list_match = []  # label tokens that also appear in the ocr text
        list_not_mathch = []
        tmp_error = 0
@@ -458,7 +462,7 @@ def main(args):
            else:
                tmp_match += 1
                list_match.append(lab[index])
        print('label in ocr: {}'.format(" ".join(list_match)))
        print("label in ocr: {}".format(" ".join(list_match)))

        # for each reco file
        base_wrong_ocr_wer = None
@@ -482,33 +486,44 @@ def main(args):

            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
            if verbose:
                if result['all'] != 0:
                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
                if result["all"] != 0:
                    wer = (
                        float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
                    )
                else:
                    wer = 0.0
                print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
                print('N=%d C=%d S=%d D=%d I=%d' %
                    (result['all'], result['cor'], result['sub'], result['del'], result['ins']))

                print("WER(%s): %4.2f %%" % (rec_name, wer), end=" ")
                print(
                    "N=%d C=%d S=%d D=%d I=%d"
                    % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
                )

            # print(result['rec'])
            wrong_rec_but_in_ocr = []
            for idx in range(len(result['lab'])):
                if result['lab'][idx] != "":
                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
                        if result['lab'][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result['lab'][idx])
            for idx in range(len(result["lab"])):
                if result["lab"][idx] != "":
                    if result["lab"][idx] != result["rec"][idx].replace("<BIAS>", ""):
                        if result["lab"][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result["lab"][idx])
                            wrong_rec_but_in_ocr_dict[rec_name] += 1
            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
            print("wrong_rec_but_in_ocr: {}".format(" ".join(wrong_rec_but_in_ocr)))

            if rec_name == "base":
                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
            if "ocr" in rec_name or "hot" in rec_name:
                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                    print(
                        "{} {} helps, {} -> {}".format(
                            fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
                        )
                    )
                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                    print(
                        "{} {} hurts, {} -> {}".format(
                            fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
                        )
                    )

            # recall = 0
            # false_alarm = 0
@@ -537,11 +552,11 @@ def main(args):
                #     if badhotword == word:
                #         count += 1
                if count == 0:
                    hotwords_related_dict[rec_name]['tn'] += 1
                    hotwords_related_dict[rec_name]["tn"] += 1
                    _tn += 1
                    # fp: 0
                else:
                    hotwords_related_dict[rec_name]['fp'] += count
                    hotwords_related_dict[rec_name]["fp"] += count
                    _fp += count
                    # tn: 0
                # if badhotword in _rec_list:
@@ -553,23 +568,30 @@ def main(args):
                rec_count = len([word for word in _rec_list if hotword == word])
                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
                if rec_count == true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    hotwords_related_dict[rec_name]["tp"] += true_count
                    _tp += true_count
                elif rec_count > true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    hotwords_related_dict[rec_name]["tp"] += true_count
                    # fp: not in the label, but in the rec
                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
                    hotwords_related_dict[rec_name]["fp"] += rec_count - true_count
                    _tp += true_count
                    _fp += rec_count - true_count
                else:
                    hotwords_related_dict[rec_name]['tp'] += rec_count
                    hotwords_related_dict[rec_name]["tp"] += rec_count
                    # fn: the hotword is in the label, but not in the rec
                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
                    hotwords_related_dict[rec_name]["fn"] += true_count - rec_count
                    _tp += rec_count
                    _fn += true_count - rec_count
            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
            ))
            print(
                "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                    _tp,
                    _tn,
                    _fp,
                    _fn,
                    sum([_tp, _tn, _fp, _fn]),
                    _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0,
                )
            )

            # if hotword in _rec_list:
            #     hotwords_related_dict[rec_name]['tp'] += 1
@@ -612,77 +634,89 @@ def main(args):
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1

            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
            space["lab"] = []
            space["rec"] = []
            for idx in range(len(result["lab"])):
                len_lab = width(result["lab"][idx])
                len_rec = width(result["rec"][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
                space["lab"].append(length - len_lab)
                space["rec"].append(length - len_rec)
            upper_lab = len(result["lab"])
            upper_rec = len(result["rec"])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                    print("lab(%s):" % fid.encode("utf-8"), end=" ")
                else:
                    print('lab:', end=' ')
                    print("lab:", end=" ")
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                    token = result["lab"][idx]
                    print("{token}".format(token=token), end="")
                    for n in range(space["lab"][idx]):
                        print(padding_symbol, end="")
                    print(" ", end="")
                print()
                if verbose > 1:
                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                    print("rec(%s):" % fid.encode("utf-8"), end=" ")
                else:
                    print('rec:', end=' ')
                    print("rec:", end=" ")

                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                    token = result["rec"][idx]
                    print("{token}".format(token=token), end="")
                    for n in range(space["rec"][idx]):
                        print(padding_symbol, end="")
                    print(" ", end="")
                print()
                # print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2
            print('\n', end='\n')
            print("\n", end="\n")
            # break
        if verbose:
            print('===========================================================================')
            print("===========================================================================")
            print()

    print(wrong_rec_but_in_ocr_dict)
    for rec_name in rec_names:
        result = calculators_dict[rec_name].overall()

        if result['all'] != 0:
            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
        if result["all"] != 0:
            wer = float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
        else:
            wer = 0.0
        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
        print('N=%d C=%d S=%d D=%d I=%d' %
            (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
        print("{} Overall -> {:4.2f} %".format(rec_name, wer), end=" ")
        print(
            "N=%d C=%d S=%d D=%d I=%d"
            % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
        )
        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")

        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
            hotwords_related_dict[rec_name]['tp'],
            hotwords_related_dict[rec_name]['tn'],
            hotwords_related_dict[rec_name]['fp'],
            hotwords_related_dict[rec_name]['fn'],
            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
            hotwords_related_dict[rec_name]['tp'] / (
                hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
        ))
        print(
            "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                hotwords_related_dict[rec_name]["tp"],
                hotwords_related_dict[rec_name]["tn"],
                hotwords_related_dict[rec_name]["fp"],
                hotwords_related_dict[rec_name]["fn"],
                sum([v for k, v in hotwords_related_dict[rec_name].items()]),
                (
                    hotwords_related_dict[rec_name]["tp"]
                    / (
                        hotwords_related_dict[rec_name]["tp"]
                        + hotwords_related_dict[rec_name]["fn"]
                    )
                    * 100
                    if hotwords_related_dict[rec_name]["tp"] + hotwords_related_dict[rec_name]["fn"]
                    != 0
                    else 0
                ),
            )
        )

        # tp: the hotword is in the label and also in the rec
        # tn: the hotword is neither in the label nor in the rec
@@ -695,8 +729,7 @@ def main(args):

if __name__ == "__main__":
    args = get_args()


    # print("")
    print(args)
    main(args)


@@ -5,11 +5,15 @@

from funasr import AutoModel

model = AutoModel(model="iic/LCB-NET",
    model_revision="v1.0.0")
model = AutoModel(model="iic/LCB-NET", model_revision="v1.0.0")



res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text"))
res = model.generate(
    input=(
        "https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav",
        "https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt",
    ),
    data_type=("sound", "text"),
)

print(res)

@@ -7,9 +7,12 @@ from funasr import AutoModel

model = AutoModel(model="iic/speech_timestamp_prediction-v1-16k-offline")

res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    "欢迎大家来到魔搭社区进行体验"),
    data_type=("sound", "text"),
    batch_size=2,
)
res = model.generate(
    input=(
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
        "欢迎大家来到魔搭社区进行体验",
    ),
    data_type=("sound", "text"),
    batch_size=2,
)
print(res)

@@ -5,12 +5,15 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
    model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword='达摩院 磨搭')
print(res)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword="达摩院 磨搭",
)
print(res)

@@ -5,22 +5,29 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 60000},
    punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
    model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 60000},
    punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)

print(res)


''' can not use currently
""" can not use currently
from funasr import AutoFrontend

frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
@@ -30,4 +37,4 @@ fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/
for batch_idx, fbank_dict in enumerate(fbanks):
    res = model.generate(**fbank_dict)
    print(res)
'''
"""

@@ -9,7 +9,9 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",)
model = AutoModel(
    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)

res = model.export(type="onnx", quantize=False)
print(res)
@@ -18,7 +20,9 @@ print(res)
# method2, inference from local path
from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
model = AutoModel(
    model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
)

res = model.export(type="onnx", quantize=False)
print(res)
print(res)

@@ -7,17 +7,18 @@ import os

from funasr import AutoModel

chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
chunk_size = [0, 10, 5]  # [0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 4  # number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1  # number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")

wav_file = os.path.join(model.model_path, "example/asr_example.wav")
res = model.generate(input=wav_file,
    chunk_size=chunk_size,
    encoder_chunk_look_back=encoder_chunk_look_back,
    decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(
    input=wav_file,
    chunk_size=chunk_size,
    encoder_chunk_look_back=encoder_chunk_look_back,
    decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)


@@ -27,18 +28,19 @@ import soundfile
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)

chunk_stride = chunk_size[1] * 960 # 600ms、480ms
chunk_stride = chunk_size[1] * 960  # 600ms、480ms

cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
    speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
    speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        encoder_chunk_look_back=encoder_chunk_look_back,
        decoder_chunk_look_back=decoder_chunk_look_back,
    )
    res = model.generate(
        input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        encoder_chunk_look_back=encoder_chunk_look_back,
        decoder_chunk_look_back=decoder_chunk_look_back,
    )
    print(res)

@@ -9,7 +9,9 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", )
model = AutoModel(
    model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online",
)

res = model.export(type="onnx", quantize=False)
print(res)
@@ -19,7 +21,9 @@ print(res)
from funasr import AutoModel


model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
model = AutoModel(
    model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
)

res = model.export(type="onnx", quantize=False)
print(res)
print(res)

@@ -12,7 +12,7 @@ model = AutoModel(model="Qwen-Audio-Chat")
audio_in = "https://github.com/QwenLM/Qwen-Audio/raw/main/assets/audio/1272-128104-0000.flac"

# 1st dialogue turn
prompt = 'what does the person say?'
prompt = "what does the person say?"
cache = {"history": None}
res = model.generate(input=audio_in, prompt=prompt, cache=cache)
print(res)
@@ -22,4 +22,3 @@ print(res)
prompt = 'Find the start time and end time of the word "middle classes"'
res = model.generate(input=None, prompt=prompt, cache=cache)
print(res)


@@ -7,14 +7,17 @@

from funasr import AutoModel

model = AutoModel(model="Qwen-Audio-Chat",
    model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio-Chat",
)
model = AutoModel(
    model="Qwen-Audio-Chat",
    model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio-Chat",
)

audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
audio_in = (
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)

# 1st dialogue turn
prompt = 'what does the person say?'
prompt = "what does the person say?"
cache = {"history": None}
res = model.generate(input=audio_in, prompt=prompt, cache=cache)
print(res)
@@ -24,4 +27,3 @@ print(res)
prompt = 'Find the start time and end time of the word "middle classes"'
res = model.generate(input=None, prompt=prompt, cache=cache)
print(res)


@@ -7,9 +7,10 @@

from funasr import AutoModel

model = AutoModel(model="Qwen-Audio",
    model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio",
)
model = AutoModel(
    model="Qwen-Audio",
    model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio",
)

audio_in = "https://github.com/QwenLM/Qwen-Audio/raw/main/assets/audio/1272-128104-0000.flac"
prompt = "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"

@@ -5,17 +5,20 @@

from funasr import AutoModel

chunk_size = [5, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
chunk_size = [5, 10, 5]  # [0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0  # number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0  # number of encoder chunks to lookback for decoder cross-attention

model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming")
model = AutoModel(
    model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming"
)
cache = {}
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    chunk_size=chunk_size,
    encoder_chunk_look_back=encoder_chunk_look_back,
    decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    chunk_size=chunk_size,
    encoder_chunk_look_back=encoder_chunk_look_back,
    decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)


@@ -25,18 +28,19 @@ import os
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)

chunk_stride = chunk_size[1] * 960 # 600ms、480ms
chunk_stride = chunk_size[1] * 960  # 600ms、480ms

cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
    speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
    speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        encoder_chunk_look_back=encoder_chunk_look_back,
        decoder_chunk_look_back=decoder_chunk_look_back,
    )
    res = model.generate(
        input=speech_chunk,
        cache=cache,
        is_final=is_final,
        chunk_size=chunk_size,
        encoder_chunk_look_back=encoder_chunk_look_back,
        decoder_chunk_look_back=decoder_chunk_look_back,
    )
    print(res)

@@ -5,24 +5,26 @@

from funasr import AutoModel

model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    # punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
    model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    # punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)


# example1
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword='达摩院 魔搭',
    # return_raw_text=True, # return raw text recognition results splited by space of equal length with timestamp
    # preset_spk_num=2, # preset speaker num for speaker cluster model
    # sentence_timestamp=True, # return sentence level information when spk_model is not given
)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword="达摩院 魔搭",
    # return_raw_text=True, # return raw text recognition results splited by space of equal length with timestamp
    # preset_spk_num=2, # preset speaker num for speaker cluster model
    # sentence_timestamp=True, # return sentence level information when spk_model is not given
)
print(res)


'''
"""
# tensor or numpy as input
# example2
import torchaudio
@@ -39,4 +41,4 @@ import soundfile
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
'''
"""

@@ -5,20 +5,23 @@

from funasr import AutoModel

model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
)
model = AutoModel(
    model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
)


input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
input_wav = (
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)

DecodingOptions = {
    "task": ("ASR", "AED", "SER"),
    "language": "auto",
    "fp16": True,
    "gain_event": True,
}
    "task": ("ASR", "AED", "SER"),
    "language": "auto",
    "fp16": True,
    "gain_event": True,
}

res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions)
print(res)

@@ -7,8 +7,11 @@ from funasr import AutoModel

# Transducer, BAT and RWKV_BAT models are just same to use, use the correct model_revision
# https://modelscope.cn/models?name=transducer&page=1&tasks=auto-speech-recognition&type=audio
model = AutoModel(model="iic/speech_bat_asr-zh-cn-16k-aishell1-vocab4234-pytorch",
)
model = AutoModel(
    model="iic/speech_bat_asr-zh-cn-16k-aishell1-vocab4234-pytorch",
)

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)

@@ -6,14 +6,18 @@
from funasr import AutoModel


model = AutoModel(model="iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",)
model = AutoModel(
    model="iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",
)


res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)


''' can not use currently
""" can not use currently
from funasr import AutoFrontend

frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
@@ -23,4 +27,4 @@ fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/
for batch_idx, fbank_dict in enumerate(fbanks):
    res = model.generate(**fbank_dict)
    print(res)
'''
"""

@@ -7,22 +7,24 @@

from funasr import AutoModel

model = AutoModel(model="iic/Whisper-large-v3",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
)
model = AutoModel(
    model="iic/Whisper-large-v3",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
)

DecodingOptions = {
    "task": "transcribe",
    "language": None,
    "beam_size": None,
    "fp16": True,
    "without_timestamps": False,
    "prompt": None,
}
    "task": "transcribe",
    "language": None,
    "beam_size": None,
    "fp16": True,
    "without_timestamps": False,
    "prompt": None,
}
res = model.generate(
    DecodingOptions=DecodingOptions,
    batch_size_s=0,
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
    DecodingOptions=DecodingOptions,
    batch_size_s=0,
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
)

print(res)

@@ -10,15 +10,25 @@ from funasr import AutoModel
# model = AutoModel(model="Whisper-small", hub="openai")
# model = AutoModel(model="Whisper-medium", hub="openai")
# model = AutoModel(model="Whisper-large-v2", hub="openai")
model = AutoModel(model="Whisper-large-v3",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="openai",
)
model = AutoModel(
    model="Whisper-large-v3",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="openai",
)

DecodingOptions = {
    "task": "transcribe",
    "language": None,
    "beam_size": None,
    "fp16": True,
    "without_timestamps": False,
    "prompt": None,
}
res = model.generate(
    language=None,
    task="transcribe",
    batch_size_s=0,
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
    DecodingOptions=DecodingOptions,
    batch_size_s=0,
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
)

print(res)

@@ -1,5 +1,7 @@


from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import (
    ClassifyFst,
)
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import (
    VerbalizeFinalFst,
)

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_SIGMA, GraphFst
from pynini.lib import pynutil
@@ -6,7 +5,7 @@ from pynini.lib import pynutil

class CardinalFst(GraphFst):
    """
    Finite state transducer for classifying cardinals. Numbers below ten are not converted.
    Finite state transducer for classifying cardinals. Numbers below ten are not converted.
    Allows both compound numeral strings or separated by whitespace.
    "und" (en: "and") can be inserted between "hundert" and following number or "tausend" and following single or double digit number.

@@ -18,7 +17,7 @@ class CardinalFst(GraphFst):
    e.g. ein tausend -> cardinal { integer: "1000" } }
    e.g. eintausend -> cardinal { integer: "1000" } }
    e.g. ein tausend zwanzig -> cardinal { integer: "1020" } }


    Args:
        tn_cardinal_tagger: TN cardinal tagger
    """
@@ -31,24 +30,36 @@ class CardinalFst(GraphFst):

        graph = (tn_cardinal_tagger.graph @ optional_delete_space).invert().optimize()
        self.graph_hundred_component_at_least_one_none_zero_digit = (
            (tn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit @ optional_delete_space)
            (
                tn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit
                @ optional_delete_space
            )
            .invert()
            .optimize()
        )

        self.graph_ties = (tn_cardinal_tagger.two_digit_non_zero @ optional_delete_space).invert().optimize()
        self.graph_ties = (
            (tn_cardinal_tagger.two_digit_non_zero @ optional_delete_space).invert().optimize()
        )
        # this is to make sure if there is an ambiguity with decimal, decimal is chosen, e.g. 1000000 vs. 1 million
        graph = pynutil.add_weight(graph, weight=0.001)
        self.graph_no_exception = graph
        self.digit = pynini.arcmap(tn_cardinal_tagger.digit, map_type="rmweight").invert().optimize()
        graph_exception = pynini.project(self.digit, 'input')
        self.digit = (
            pynini.arcmap(tn_cardinal_tagger.digit, map_type="rmweight").invert().optimize()
        )
        graph_exception = pynini.project(self.digit, "input")
        self.graph = (pynini.project(graph, "input") - graph_exception.arcsort()) @ graph

        self.optional_minus_graph = pynini.closure(
            pynutil.insert("negative: ") + pynini.cross("minus ", "\"-\" "), 0, 1
            pynutil.insert("negative: ") + pynini.cross("minus ", '"-" '), 0, 1
        )

        final_graph = self.optional_minus_graph + pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
        final_graph = (
            self.optional_minus_graph
            + pynutil.insert('integer: "')
            + self.graph
            + pynutil.insert('"')
        )

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
    DAMO_DIGIT,
@@ -35,20 +34,28 @@ class DateFst(GraphFst):
    ):
        super().__init__(name="date", kind="classify", deterministic=deterministic)

        add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
        add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
            pynutil.insert("0") + DAMO_DIGIT
        )
        optional_delete_space = pynini.closure(DAMO_SIGMA | pynutil.delete(" ", weight=0.0001))
        tagger = tn_date_verbalizer.graph.invert().optimize()

        delete_day_marker = (
            pynutil.delete("day: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
            pynutil.delete('day: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
        ) @ itn_cardinal_tagger.graph_no_exception

        month_as_number = pynutil.delete("month: \"") + itn_cardinal_tagger.graph_no_exception + pynutil.delete("\"")
        month_as_string = pynutil.delete("month: \"") + tn_date_tagger.month_abbr.invert() + pynutil.delete("\"")
        month_as_number = (
            pynutil.delete('month: "')
            + itn_cardinal_tagger.graph_no_exception
            + pynutil.delete('"')
        )
        month_as_string = (
            pynutil.delete('month: "') + tn_date_tagger.month_abbr.invert() + pynutil.delete('"')
        )

        convert_year = (tn_date_tagger.year @ optional_delete_space).invert().optimize()
        delete_year_marker = (
            pynutil.delete("year: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
            pynutil.delete('year: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
        ) @ convert_year

        # day. month as string (year)
@@ -73,5 +80,5 @@ class DateFst(GraphFst):

        final_graph = tagger @ verbalizer

        graph = pynutil.insert("name: \"") + convert_space(final_graph) + pynutil.insert("\"")
        graph = pynutil.insert('name: "') + convert_space(final_graph) + pynutil.insert('"')
        self.fst = graph.optimize()

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.de.taggers.decimal import get_quantity, quantities
from fun_text_processing.text_normalization.en.graph_utils import DAMO_SIGMA, GraphFst
@@ -15,18 +14,24 @@ class DecimalFst(GraphFst):
        tn_decimal_tagger: TN decimal tagger
    """

    def __init__(self, itn_cardinal_tagger: GraphFst, tn_decimal_tagger: GraphFst, deterministic: bool = True):
    def __init__(
        self, itn_cardinal_tagger: GraphFst, tn_decimal_tagger: GraphFst, deterministic: bool = True
    ):
        super().__init__(name="decimal", kind="classify", deterministic=deterministic)

        self.graph = tn_decimal_tagger.graph.invert().optimize()

        delete_point = pynutil.delete(" komma")

        allow_spelling = pynini.cdrewrite(pynini.cross("eine ", "eins ") + quantities, "[BOS]", "[EOS]", DAMO_SIGMA)
        allow_spelling = pynini.cdrewrite(
            pynini.cross("eine ", "eins ") + quantities, "[BOS]", "[EOS]", DAMO_SIGMA
        )

        graph_fractional = pynutil.insert("fractional_part: \"") + self.graph + pynutil.insert("\"")
        graph_fractional = pynutil.insert('fractional_part: "') + self.graph + pynutil.insert('"')
        graph_integer = (
            pynutil.insert("integer_part: \"") + itn_cardinal_tagger.graph_no_exception + pynutil.insert("\"")
            pynutil.insert('integer_part: "')
            + itn_cardinal_tagger.graph_no_exception
            + pynutil.insert('"')
        )
        final_graph_wo_sign = graph_integer + delete_point + pynini.accep(" ") + graph_fractional

@@ -35,7 +40,8 @@ class DecimalFst(GraphFst):
            @ (
                final_graph_wo_sign
                | get_quantity(
                    final_graph_wo_sign, itn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit
                    final_graph_wo_sign,
                    itn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit,
                )
            ).optimize()
        )

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
from pynini.lib import pynutil
@@ -8,18 +7,23 @@ class ElectronicFst(GraphFst):
    """
    Finite state transducer for classifying electronic: email addresses, etc.
    e.g. c d f eins at a b c punkt e d u -> tokens { name: "cdf1.abc.edu" }


    Args:
        tn_electronic_tagger: TN eletronic tagger
        tn_electronic_verbalizer: TN eletronic verbalizer
    """

    def __init__(self, tn_electronic_tagger: GraphFst, tn_electronic_verbalizer: GraphFst, deterministic: bool = True):
    def __init__(
        self,
        tn_electronic_tagger: GraphFst,
        tn_electronic_verbalizer: GraphFst,
        deterministic: bool = True,
    ):
        super().__init__(name="electronic", kind="classify", deterministic=deterministic)

        tagger = pynini.invert(tn_electronic_verbalizer.graph).optimize()
        verbalizer = pynini.invert(tn_electronic_tagger.graph).optimize()
        final_graph = tagger @ verbalizer

        graph = pynutil.insert("name: \"") + final_graph + pynutil.insert("\"")
        graph = pynutil.insert('name: "') + final_graph + pynutil.insert('"')
        self.fst = graph.optimize()

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
    DAMO_NOT_QUOTE,
@@ -15,28 +14,41 @@ class FractionFst(GraphFst):
    e.g. ein halb -> tokens { name: "1/2" }
    e.g. ein ein halb -> tokens { name: "1 1/2" }
    e.g. drei zwei ein hundertstel -> tokens { name: "3 2/100" }


    Args:
        itn_cardinal_tagger: ITN cardinal tagger
        tn_fraction_verbalizer: TN fraction verbalizer
    """

    def __init__(self, itn_cardinal_tagger: GraphFst, tn_fraction_verbalizer: GraphFst, deterministic: bool = True):
    def __init__(
        self,
        itn_cardinal_tagger: GraphFst,
        tn_fraction_verbalizer: GraphFst,
        deterministic: bool = True,
    ):
        super().__init__(name="fraction", kind="classify", deterministic=deterministic)
        tagger = tn_fraction_verbalizer.graph.invert().optimize()

        delete_optional_sign = pynini.closure(pynutil.delete("negative: ") + pynini.cross("\"true\" ", "-"), 0, 1)
        delete_optional_sign = pynini.closure(
            pynutil.delete("negative: ") + pynini.cross('"true" ', "-"), 0, 1
        )
        delete_integer_marker = (
            pynutil.delete("integer_part: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
            pynutil.delete('integer_part: "')
            + pynini.closure(DAMO_NOT_QUOTE, 1)
            + pynutil.delete('"')
        ) @ itn_cardinal_tagger.graph_no_exception

        delete_numerator_marker = (
            pynutil.delete("numerator: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
            pynutil.delete('numerator: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
        ) @ itn_cardinal_tagger.graph_no_exception

        delete_denominator_marker = (
            pynutil.insert('/')
            + (pynutil.delete("denominator: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\""))
            pynutil.insert("/")
            + (
                pynutil.delete('denominator: "')
                + pynini.closure(DAMO_NOT_QUOTE, 1)
                + pynutil.delete('"')
            )
            @ itn_cardinal_tagger.graph_no_exception
        )

@@ -50,5 +62,5 @@ class FractionFst(GraphFst):

        self.graph = tagger @ verbalizer

        graph = pynutil.insert("name: \"") + convert_space(self.graph) + pynutil.insert("\"")
        graph = pynutil.insert('name: "') + convert_space(self.graph) + pynutil.insert('"')
        self.fst = graph.optimize()

@@ -1,6 +1,8 @@

import pynini
from fun_text_processing.text_normalization.de.taggers.measure import singular_to_plural, unit_singular
from fun_text_processing.text_normalization.de.taggers.measure import (
    singular_to_plural,
    unit_singular,
)
from fun_text_processing.text_normalization.en.graph_utils import (
    DAMO_SIGMA,
    GraphFst,
@@ -35,25 +37,31 @@ class MeasureFst(GraphFst):
        super().__init__(name="measure", kind="classify", deterministic=deterministic)

        cardinal_graph = (
            pynini.cdrewrite(pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA)
            pynini.cdrewrite(
                pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA
            )
            @ itn_cardinal_tagger.graph_no_exception
        )

        graph_unit_singular = pynini.invert(unit_singular) # singular -> abbr
        unit = (pynini.invert(singular_to_plural()) @ graph_unit_singular) | graph_unit_singular # plural -> abbr
        unit = (
            pynini.invert(singular_to_plural()) @ graph_unit_singular
        ) | graph_unit_singular # plural -> abbr
        unit = convert_space(unit)
        graph_unit_singular = convert_space(graph_unit_singular)

        optional_graph_negative = pynini.closure(
            pynutil.insert("negative: ") + pynini.cross("minus", "\"true\"") + delete_extra_space, 0, 1
            pynutil.insert("negative: ") + pynini.cross("minus", '"true"') + delete_extra_space,
            0,
            1,
        )

        unit_misc = pynutil.insert("/") + pynutil.delete("pro") + delete_space + graph_unit_singular

        unit = (
            pynutil.insert("units: \"")
            pynutil.insert('units: "')
            + (unit | unit_misc | pynutil.add_weight(unit + delete_space + unit_misc, 0.01))
            + pynutil.insert("\"")
            + pynutil.insert('"')
        )

        subgraph_decimal = (
@@ -68,9 +76,9 @@ class MeasureFst(GraphFst):
        subgraph_fraction = (
            pynutil.insert("decimal { ")
            + optional_graph_negative
            + pynutil.insert("integer_part: \"")
            + pynutil.insert('integer_part: "')
            + itn_fraction_tagger.graph
            + pynutil.insert("\" }")
            + pynutil.insert('" }')
            + delete_extra_space
            + unit
        )
@@ -78,9 +86,9 @@ class MeasureFst(GraphFst):
        subgraph_cardinal = (
            pynutil.insert("cardinal { ")
            + optional_graph_negative
            + pynutil.insert("integer: \"")
            + pynutil.insert('integer: "')
            + cardinal_graph
            + pynutil.insert("\"")
            + pynutil.insert('"')
            + pynutil.insert(" }")
            + delete_extra_space
            + unit

@@ -1,6 +1,9 @@

import pynini
from fun_text_processing.text_normalization.de.taggers.money import maj_singular, min_plural, min_singular
from fun_text_processing.text_normalization.de.taggers.money import (
    maj_singular,
    min_plural,
    min_singular,
)
from fun_text_processing.text_normalization.en.graph_utils import (
    DAMO_DIGIT,
    DAMO_SIGMA,
@@ -23,26 +26,35 @@ class MoneyFst(GraphFst):
        itn_decimal_tagger: ITN Decimal Tagger
    """

    def __init__(self, itn_cardinal_tagger: GraphFst, itn_decimal_tagger: GraphFst, deterministic: bool = True):
    def __init__(
        self,
        itn_cardinal_tagger: GraphFst,
        itn_decimal_tagger: GraphFst,
        deterministic: bool = True,
    ):
        super().__init__(name="money", kind="classify", deterministic=deterministic)
        cardinal_graph = (
            pynini.cdrewrite(pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA)
            pynini.cdrewrite(
                pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA
            )
            @ itn_cardinal_tagger.graph_no_exception
        )
        graph_decimal_final = itn_decimal_tagger.final_graph_wo_negative

        graph_unit = pynini.invert(maj_singular)
        graph_unit = pynutil.insert("currency: \"") + convert_space(graph_unit) + pynutil.insert("\"")
        graph_unit = pynutil.insert('currency: "') + convert_space(graph_unit) + pynutil.insert('"')

        add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
        add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
            pynutil.insert("0") + DAMO_DIGIT
        )
        min_unit = pynini.project(min_singular | min_plural, "output")
        # elf euro (und) vier cent, vier cent
        cents_standalone = (
            pynutil.insert("fractional_part: \"")
            pynutil.insert('fractional_part: "')
            + cardinal_graph @ add_leading_zero_to_double_digit
            + delete_space
            + pynutil.delete(min_unit)
            + pynutil.insert("\"")
            + pynutil.insert('"')
        )

        optional_cents_standalone = pynini.closure(
@@ -56,23 +68,23 @@ class MoneyFst(GraphFst):
        # elf euro vierzig, only after integer
        optional_cents_suffix = pynini.closure(
            delete_extra_space
            + pynutil.insert("fractional_part: \"")
            + pynutil.insert('fractional_part: "')
            + pynutil.add_weight(cardinal_graph @ add_leading_zero_to_double_digit, -0.7)
            + pynutil.insert("\""),
            + pynutil.insert('"'),
            0,
            1,
        )

        graph_integer = (
            pynutil.insert("integer_part: \"")
            pynutil.insert('integer_part: "')
            + cardinal_graph
            + pynutil.insert("\"")
            + pynutil.insert('"')
            + delete_extra_space
            + graph_unit
            + (optional_cents_standalone | optional_cents_suffix)
        )
        graph_decimal = graph_decimal_final + delete_extra_space + graph_unit
        graph_decimal |= pynutil.insert("currency: \"€\" integer_part: \"0\" ") + cents_standalone
        graph_decimal |= pynutil.insert('currency: "€" integer_part: "0" ') + cents_standalone
        final_graph = graph_integer | graph_decimal
        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

@ -1,4 +1,3 @@
|
||||
|
||||
import pynini
|
||||
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst
|
||||
from pynini.lib import pynutil
|
||||
@ -14,16 +13,21 @@ class OrdinalFst(GraphFst):
|
||||
tn_ordinal_verbalizer: TN Ordinal Verbalizer
|
||||
"""
|
||||
|
||||
def __init__(self, itn_cardinal_tagger: GraphFst, tn_ordinal_verbalizer: GraphFst, deterministic: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
itn_cardinal_tagger: GraphFst,
|
||||
tn_ordinal_verbalizer: GraphFst,
|
||||
deterministic: bool = True,
|
||||
):
|
||||
super().__init__(name="ordinal", kind="classify", deterministic=deterministic)
|
||||
|
||||
tagger = tn_ordinal_verbalizer.graph.invert().optimize()
|
||||
|
||||
graph = (
|
||||
pynutil.delete("integer: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
|
||||
pynutil.delete('integer: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
|
||||
) @ itn_cardinal_tagger.graph
|
||||
|
||||
final_graph = tagger @ graph + pynutil.insert(".")
|
||||
|
||||
graph = pynutil.insert("name: \"") + final_graph + pynutil.insert("\"")
|
||||
graph = pynutil.insert('name: "') + final_graph + pynutil.insert('"')
|
||||
self.fst = graph.optimize()
|
||||
|
||||
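The taggers above all follow one pattern: pynutil.insert wraps the converted span in a named field, and the matching verbalizer later strips the wrapper with pynutil.delete. A toy round trip in that style, with a hypothetical one-word grammar ("zwei" -> "2") standing in for the full cardinal graph:

```python
import pynini
from pynini.lib import pynutil

digits = pynini.cross("zwei", "2")  # stand-in for itn_cardinal_tagger.graph
tagger = pynutil.insert('integer: "') + digits + pynutil.insert('"')
verbalizer = (
    pynutil.delete('integer: "')
    + pynini.closure(pynini.union(*"0123456789"), 1)
    + pynutil.delete('"')
)

tagged = pynini.shortestpath("zwei" @ tagger).string()
print(tagged)                                             # integer: "2"
print(pynini.shortestpath(tagged @ verbalizer).string())  # 2
```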
@@ -1,14 +1,17 @@
 
 import pynini
-from fun_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space, insert_space
+from fun_text_processing.text_normalization.en.graph_utils import (
+    GraphFst,
+    convert_space,
+    insert_space,
+)
 from pynini.lib import pynutil
 
 
 class TelephoneFst(GraphFst):
     """
-    Finite state transducer for classifying telephone numbers, e.g.
+    Finite state transducer for classifying telephone numbers, e.g.
     null vier eins eins eins zwei drei vier eins zwei drei vier -> tokens { name: "(0411) 1234-1234" }
 
     Args:
         tn_cardinal_tagger: TN Cardinal Tagger
     """
@@ -35,6 +38,6 @@ class TelephoneFst(GraphFst):
             + digit
         )
         graph = convert_space(pynini.invert(number_part))
-        final_graph = pynutil.insert("name: \"") + graph + pynutil.insert("\"")
+        final_graph = pynutil.insert('name: "') + graph + pynutil.insert('"')
 
         self.fst = final_graph.optimize()

@@ -1,5 +1,3 @@
-
-
 import pynini
 from fun_text_processing.text_normalization.en.graph_utils import DAMO_SIGMA, GraphFst
 from pynini.lib import pynutil
@@ -17,7 +15,7 @@ class TimeFst(GraphFst):
     e.g. drei vor zwölf -> time { minutes: "57" hours: "11" }
     e.g. drei nach zwölf -> time { minutes: "3" hours: "12" }
     e.g. drei uhr zehn minuten zehn sekunden -> time { hours: "3" hours: "10" sekunden: "10"}
-
+
     Args:
         tn_time_verbalizer: TN time verbalizer
     """

@@ -1,4 +1,3 @@
-
 import os
 
 import pynini
@@ -15,15 +14,27 @@ from fun_text_processing.inverse_text_normalization.de.taggers.time import TimeF
 from fun_text_processing.inverse_text_normalization.de.taggers.whitelist import WhiteListFst
 from fun_text_processing.inverse_text_normalization.en.taggers.punctuation import PunctuationFst
 from fun_text_processing.inverse_text_normalization.en.taggers.word import WordFst
-from fun_text_processing.text_normalization.de.taggers.cardinal import CardinalFst as TNCardinalTagger
+from fun_text_processing.text_normalization.de.taggers.cardinal import (
+    CardinalFst as TNCardinalTagger,
+)
 from fun_text_processing.text_normalization.de.taggers.date import DateFst as TNDateTagger
 from fun_text_processing.text_normalization.de.taggers.decimal import DecimalFst as TNDecimalTagger
-from fun_text_processing.text_normalization.de.taggers.electronic import ElectronicFst as TNElectronicTagger
-from fun_text_processing.text_normalization.de.taggers.whitelist import WhiteListFst as TNWhitelistTagger
+from fun_text_processing.text_normalization.de.taggers.electronic import (
+    ElectronicFst as TNElectronicTagger,
+)
+from fun_text_processing.text_normalization.de.taggers.whitelist import (
+    WhiteListFst as TNWhitelistTagger,
+)
 from fun_text_processing.text_normalization.de.verbalizers.date import DateFst as TNDateVerbalizer
-from fun_text_processing.text_normalization.de.verbalizers.electronic import ElectronicFst as TNElectronicVerbalizer
-from fun_text_processing.text_normalization.de.verbalizers.fraction import FractionFst as TNFractionVerbalizer
-from fun_text_processing.text_normalization.de.verbalizers.ordinal import OrdinalFst as TNOrdinalVerbalizer
+from fun_text_processing.text_normalization.de.verbalizers.electronic import (
+    ElectronicFst as TNElectronicVerbalizer,
+)
+from fun_text_processing.text_normalization.de.verbalizers.fraction import (
+    FractionFst as TNFractionVerbalizer,
+)
+from fun_text_processing.text_normalization.de.verbalizers.ordinal import (
+    OrdinalFst as TNOrdinalVerbalizer,
+)
 from fun_text_processing.text_normalization.de.verbalizers.time import TimeFst as TNTimeVerbalizer
 from fun_text_processing.text_normalization.en.graph_utils import (
     GraphFst,
@@ -39,7 +50,7 @@ import logging
 class ClassifyFst(GraphFst):
     """
     Final class that composes all other classification grammars. This class can process an entire sentence, that is lower cased.
-    For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File.
+    For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File.
     More details to deployment at NeMo/tools/text_processing_deployment.
 
     Args:
@@ -47,11 +58,13 @@ class ClassifyFst(GraphFst):
         overwrite_cache: set to True to overwrite .far files
     """
 
-    def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, deterministic: bool = True):
+    def __init__(
+        self, cache_dir: str = None, overwrite_cache: bool = False, deterministic: bool = True
+    ):
         super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)
 
         far_file = None
-        if cache_dir is not None and cache_dir != 'None':
+        if cache_dir is not None and cache_dir != "None":
             os.makedirs(cache_dir, exist_ok=True)
             far_file = os.path.join(cache_dir, "_de_itn.far")
         if not overwrite_cache and far_file and os.path.exists(far_file):
@@ -63,9 +76,15 @@ class ClassifyFst(GraphFst):
         tn_date_tagger = TNDateTagger(cardinal=tn_cardinal_tagger, deterministic=False)
         tn_decimal_tagger = TNDecimalTagger(cardinal=tn_cardinal_tagger, deterministic=False)
         tn_ordinal_verbalizer = TNOrdinalVerbalizer(deterministic=False)
-        tn_fraction_verbalizer = TNFractionVerbalizer(ordinal=tn_ordinal_verbalizer, deterministic=False)
-        tn_time_verbalizer = TNTimeVerbalizer(cardinal_tagger=tn_cardinal_tagger, deterministic=False)
-        tn_date_verbalizer = TNDateVerbalizer(ordinal=tn_ordinal_verbalizer, deterministic=False)
+        tn_fraction_verbalizer = TNFractionVerbalizer(
+            ordinal=tn_ordinal_verbalizer, deterministic=False
+        )
+        tn_time_verbalizer = TNTimeVerbalizer(
+            cardinal_tagger=tn_cardinal_tagger, deterministic=False
+        )
+        tn_date_verbalizer = TNDateVerbalizer(
+            ordinal=tn_ordinal_verbalizer, deterministic=False
+        )
         tn_electronic_tagger = TNElectronicTagger(deterministic=False)
         tn_electronic_verbalizer = TNElectronicVerbalizer(deterministic=False)
         tn_whitelist_tagger = TNWhitelistTagger(input_case="cased", deterministic=False)
@@ -73,19 +92,27 @@ class ClassifyFst(GraphFst):
         cardinal = CardinalFst(tn_cardinal_tagger=tn_cardinal_tagger)
         cardinal_graph = cardinal.fst
 
-        ordinal = OrdinalFst(itn_cardinal_tagger=cardinal, tn_ordinal_verbalizer=tn_ordinal_verbalizer)
+        ordinal = OrdinalFst(
+            itn_cardinal_tagger=cardinal, tn_ordinal_verbalizer=tn_ordinal_verbalizer
+        )
         ordinal_graph = ordinal.fst
         decimal = DecimalFst(itn_cardinal_tagger=cardinal, tn_decimal_tagger=tn_decimal_tagger)
         decimal_graph = decimal.fst
 
-        fraction = FractionFst(itn_cardinal_tagger=cardinal, tn_fraction_verbalizer=tn_fraction_verbalizer)
+        fraction = FractionFst(
+            itn_cardinal_tagger=cardinal, tn_fraction_verbalizer=tn_fraction_verbalizer
+        )
         fraction_graph = fraction.fst
 
         measure_graph = MeasureFst(
-            itn_cardinal_tagger=cardinal, itn_decimal_tagger=decimal, itn_fraction_tagger=fraction
+            itn_cardinal_tagger=cardinal,
+            itn_decimal_tagger=decimal,
+            itn_fraction_tagger=fraction,
         ).fst
         date_graph = DateFst(
-            itn_cardinal_tagger=cardinal, tn_date_verbalizer=tn_date_verbalizer, tn_date_tagger=tn_date_tagger
+            itn_cardinal_tagger=cardinal,
+            tn_date_verbalizer=tn_date_verbalizer,
+            tn_date_tagger=tn_date_tagger,
        ).fst
         word_graph = WordFst().fst
         time_graph = TimeFst(tn_time_verbalizer=tn_time_verbalizer).fst
@@ -93,7 +120,8 @@ class ClassifyFst(GraphFst):
         whitelist_graph = WhiteListFst(tn_whitelist_tagger=tn_whitelist_tagger).fst
         punct_graph = PunctuationFst().fst
         electronic_graph = ElectronicFst(
-            tn_electronic_tagger=tn_electronic_tagger, tn_electronic_verbalizer=tn_electronic_verbalizer
+            tn_electronic_tagger=tn_electronic_tagger,
+            tn_electronic_verbalizer=tn_electronic_verbalizer,
         ).fst
         telephone_graph = TelephoneFst(tn_cardinal_tagger=tn_cardinal_tagger).fst
 
@@ -112,10 +140,16 @@ class ClassifyFst(GraphFst):
             | pynutil.add_weight(word_graph, 100)
         )
 
-        punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=1.1) + pynutil.insert(" }")
+        punct = (
+            pynutil.insert("tokens { ")
+            + pynutil.add_weight(punct_graph, weight=1.1)
+            + pynutil.insert(" }")
+        )
         token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
         token_plus_punct = (
-            pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
+            pynini.closure(punct + pynutil.insert(" "))
+            + token
+            + pynini.closure(pynutil.insert(" ") + punct)
         )
 
         graph = token_plus_punct + pynini.closure(delete_extra_space + token_plus_punct)
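ClassifyFst unions all of these grammars and relies on pynutil.add_weight to rank them: when several grammars accept the same token, the cheapest path wins, which is why word_graph carries weight 100 and only fires as a fallback. A stripped-down illustration of that selection rule, with toy grammars in place of the real ones:

```python
import pynini
from pynini.lib import pynutil

as_word = pynutil.insert('name: "') + pynini.accep("zwei") + pynutil.insert('"')
as_number = pynutil.insert('integer: "') + pynini.cross("zwei", "2") + pynutil.insert('"')
classify = pynutil.add_weight(as_number, 1.1) | pynutil.add_weight(as_word, 100)
print(pynini.shortestpath("zwei" @ classify).string())  # integer: "2"
```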
@@ -1,4 +1,3 @@
-
 import pynini
 from fun_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space
 from pynini.lib import pynutil
@@ -16,5 +15,5 @@ class WhiteListFst(GraphFst):
         super().__init__(name="whitelist", kind="classify", deterministic=deterministic)
 
         whitelist = pynini.invert(tn_whitelist_tagger.graph)
-        graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
+        graph = pynutil.insert('name: "') + convert_space(whitelist) + pynutil.insert('"')
         self.fst = graph.optimize()

@@ -1,4 +1,3 @@
-
 import pynini
 from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst
 from pynini.lib import pynutil
@@ -16,7 +15,9 @@ class CardinalFst(GraphFst):
     def __init__(self, tn_cardinal_verbalizer: GraphFst, deterministic: bool = True):
         super().__init__(name="cardinal", kind="verbalize", deterministic=deterministic)
         self.numbers = tn_cardinal_verbalizer.numbers
-        optional_sign = pynini.closure(pynutil.delete("negative: \"") + DAMO_NOT_QUOTE + pynutil.delete("\" "), 0, 1)
+        optional_sign = pynini.closure(
+            pynutil.delete('negative: "') + DAMO_NOT_QUOTE + pynutil.delete('" '), 0, 1
+        )
         graph = optional_sign + self.numbers
         delete_tokens = self.delete_tokens(graph)
         self.fst = delete_tokens.optimize()

@@ -1,6 +1,9 @@
 
 import pynini
-from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_preserve_order
+from fun_text_processing.text_normalization.en.graph_utils import (
+    DAMO_NOT_QUOTE,
+    GraphFst,
+    delete_preserve_order,
+)
 from pynini.lib import pynutil
 
 
@@ -17,13 +20,17 @@ class DecimalFst(GraphFst):
         super().__init__(name="decimal", kind="verbalize", deterministic=deterministic)
         delete_space = pynutil.delete(" ")
         optional_sign = pynini.closure(
-            pynutil.delete("negative: \"") + DAMO_NOT_QUOTE + pynutil.delete("\"") + delete_space, 0, 1
+            pynutil.delete('negative: "') + DAMO_NOT_QUOTE + pynutil.delete('"') + delete_space,
+            0,
+            1,
         )
         optional_integer = pynini.closure(tn_decimal_verbalizer.integer, 0, 1)
         optional_fractional = pynini.closure(
             delete_space + pynutil.insert(",") + tn_decimal_verbalizer.fractional_default, 0, 1
         )
-        graph = (optional_integer + optional_fractional + tn_decimal_verbalizer.optional_quantity).optimize()
+        graph = (
+            optional_integer + optional_fractional + tn_decimal_verbalizer.optional_quantity
+        ).optimize()
         self.numbers = optional_sign + graph
         graph = self.numbers + delete_preserve_order
         delete_tokens = self.delete_tokens(graph)

@@ -1,4 +1,3 @@
-
 import pynini
 from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
 from pynini.lib import pynutil
@@ -18,13 +17,13 @@ class MeasureFst(GraphFst):
 
     def __init__(self, decimal: GraphFst, cardinal: GraphFst, deterministic: bool = True):
         super().__init__(name="measure", kind="verbalize", deterministic=deterministic)
-        optional_sign = pynini.closure(pynini.cross("negative: \"true\"", "-"), 0, 1)
+        optional_sign = pynini.closure(pynini.cross('negative: "true"', "-"), 0, 1)
         unit = (
             pynutil.delete("units:")
             + delete_space
-            + pynutil.delete("\"")
+            + pynutil.delete('"')
             + pynini.closure(DAMO_CHAR - " ", 1)
-            + pynutil.delete("\"")
+            + pynutil.delete('"')
             + delete_space
         )
         graph_decimal = (

@@ -1,4 +1,3 @@
-
 import pynini
 from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
 from pynini.lib import pynutil
@@ -18,9 +17,9 @@ class MoneyFst(GraphFst):
         unit = (
             pynutil.delete("currency:")
             + delete_space
-            + pynutil.delete("\"")
+            + pynutil.delete('"')
             + pynini.closure(DAMO_CHAR - " ", 1)
-            + pynutil.delete("\"")
+            + pynutil.delete('"')
         )
         graph = unit + delete_space + decimal.numbers
         delete_tokens = self.delete_tokens(graph)

@@ -1,6 +1,10 @@
 
 import pynini
-from fun_text_processing.text_normalization.en.graph_utils import DAMO_ALPHA, DAMO_DIGIT, GraphFst, delete_space
+from fun_text_processing.text_normalization.en.graph_utils import (
+    DAMO_ALPHA,
+    DAMO_DIGIT,
+    GraphFst,
+    delete_space,
+)
 from pynini.lib import pynutil
 
 
@@ -9,26 +13,35 @@ class TimeFst(GraphFst):
     Finite state transducer for verbalizing time, e.g.
     time { hours: "8" minutes: "30" zone: "e s t" } -> 08:30 Uhr est
     time { hours: "8" } -> 8 Uhr
-    time { hours: "8" minutes: "30" seconds: "10" } -> 08:30:10 Uhr
+    time { hours: "8" minutes: "30" seconds: "10" } -> 08:30:10 Uhr
     """
 
     def __init__(self, deterministic: bool = True):
         super().__init__(name="time", kind="verbalize", deterministic=deterministic)
 
-        add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
-        hour = pynutil.delete("hours: \"") + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete("\"")
-        minute = pynutil.delete("minutes: \"") + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete("\"")
+        add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
+            pynutil.insert("0") + DAMO_DIGIT
+        )
+        hour = pynutil.delete('hours: "') + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete('"')
+        minute = pynutil.delete('minutes: "') + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete('"')
 
-        second = pynutil.delete("seconds: \"") + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete("\"")
+        second = pynutil.delete('seconds: "') + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete('"')
         zone = (
-            pynutil.delete("zone: \"") + pynini.closure(DAMO_ALPHA + delete_space) + DAMO_ALPHA + pynutil.delete("\"")
+            pynutil.delete('zone: "')
+            + pynini.closure(DAMO_ALPHA + delete_space)
+            + DAMO_ALPHA
+            + pynutil.delete('"')
         )
         optional_zone = pynini.closure(pynini.accep(" ") + zone, 0, 1)
         graph = (
             delete_space
             + pynutil.insert(":")
             + (minute @ add_leading_zero_to_double_digit)
-            + pynini.closure(delete_space + pynutil.insert(":") + (second @ add_leading_zero_to_double_digit), 0, 1)
+            + pynini.closure(
+                delete_space + pynutil.insert(":") + (second @ add_leading_zero_to_double_digit),
+                0,
+                1,
+            )
             + pynutil.insert(" Uhr")
             + optional_zone
         )
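The add_leading_zero_to_double_digit idiom that Black re-wraps above is worth reading on its own: two digits pass through unchanged, while a single digit gets "0" prepended, which is what turns hours: "8" into the 08 of 08:30 Uhr. A standalone sketch, with a plain digit acceptor standing in for DAMO_DIGIT:

```python
import pynini
from pynini.lib import pynutil

DIGIT = pynini.union(*"0123456789")  # stand-in for DAMO_DIGIT
pad = (DIGIT + DIGIT) | (pynutil.insert("0") + DIGIT)
print(pynini.shortestpath("8" @ pad).string())   # 08
print(pynini.shortestpath("30" @ pad).string())  # 30
```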
@@ -1,18 +1,21 @@
 
 from fun_text_processing.inverse_text_normalization.de.verbalizers.cardinal import CardinalFst
 from fun_text_processing.inverse_text_normalization.de.verbalizers.decimal import DecimalFst
 from fun_text_processing.inverse_text_normalization.de.verbalizers.measure import MeasureFst
 from fun_text_processing.inverse_text_normalization.de.verbalizers.money import MoneyFst
 from fun_text_processing.inverse_text_normalization.de.verbalizers.time import TimeFst
-from fun_text_processing.text_normalization.de.verbalizers.cardinal import CardinalFst as TNCardinalVerbalizer
-from fun_text_processing.text_normalization.de.verbalizers.decimal import DecimalFst as TNDecimalVerbalizer
+from fun_text_processing.text_normalization.de.verbalizers.cardinal import (
+    CardinalFst as TNCardinalVerbalizer,
+)
+from fun_text_processing.text_normalization.de.verbalizers.decimal import (
+    DecimalFst as TNDecimalVerbalizer,
+)
 from fun_text_processing.text_normalization.en.graph_utils import GraphFst
 
 
 class VerbalizeFst(GraphFst):
     """
     Composes other verbalizer grammars.
-    For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File.
+    For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File.
     More details to deployment at NeMo/tools/text_processing_deployment.
     """

@@ -1,14 +1,17 @@
 
 import pynini
 from fun_text_processing.inverse_text_normalization.de.verbalizers.verbalize import VerbalizeFst
 from fun_text_processing.inverse_text_normalization.en.verbalizers.word import WordFst
-from fun_text_processing.text_normalization.en.graph_utils import GraphFst, delete_extra_space, delete_space
+from fun_text_processing.text_normalization.en.graph_utils import (
+    GraphFst,
+    delete_extra_space,
+    delete_space,
+)
 from pynini.lib import pynutil
 
 
 class VerbalizeFinalFst(GraphFst):
     """
-    Finite state transducer that verbalizes an entire sentence, e.g.
+    Finite state transducer that verbalizes an entire sentence, e.g.
     tokens { name: "jetzt" } tokens { name: "ist" } tokens { time { hours: "12" minutes: "30" } } -> jetzt ist 12:30 Uhr
     """

@@ -1,5 +1,7 @@
 
 
-from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
+from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import (
+    ClassifyFst,
+)
 from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize import VerbalizeFst
-from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
+from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import (
+    VerbalizeFinalFst,
+)

@@ -1,5 +1,3 @@
-
-
 from argparse import ArgumentParser
 from typing import List
 
@@ -55,7 +53,7 @@ class Filter:
 
         Args:
             processes given instance with process function
-
+
         Returns: processed instance if instance belongs to expected class type or original instance
         """
         if instance.token_type != self.class_type:
@@ -73,7 +71,9 @@ def process_cardinal_1(instance: Instance) -> Instance:
     normalized = instance.normalized
     un_normalized = re.sub(r"[^0-9]", "", un_normalized)
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_ordinal_1(instance: Instance) -> bool:
@@ -86,7 +86,9 @@ def process_ordinal_1(instance: Instance) -> Instance:
     normalized = instance.normalized
     un_normalized = re.sub(r"[,\s]", "", un_normalized)
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_decimal_1(instance: Instance) -> bool:
@@ -99,7 +101,9 @@ def process_decimal_1(instance: Instance) -> Instance:
     un_normalized = re.sub(r",", "", un_normalized)
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_measure_1(instance: Instance) -> bool:
@@ -116,7 +120,9 @@ def process_measure_1(instance: Instance) -> Instance:
     normalized = re.sub(r"[^a-z\s]", "", normalized)
     normalized = re.sub(r"per ([a-z\s]*)s$", r"per \1", normalized)
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_money_1(instance: Instance) -> bool:
@@ -133,7 +139,9 @@ def process_money_1(instance: Instance) -> Instance:
     un_normalized = re.sub(r"(\d)m\s*$", r"\1 million", un_normalized)
     un_normalized = re.sub(r"(\d)bn?\s*$", r"\1 billion", un_normalized)
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_time_1(instance: Instance) -> bool:
@@ -148,7 +156,9 @@ def process_time_1(instance: Instance) -> Instance:
     un_normalized = re.sub(r"(\d)\s?p\s?m\s?", r"\1 p.m.", un_normalized)
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_plain_1(instance: Instance) -> bool:
@@ -159,7 +169,9 @@ def filter_plain_1(instance: Instance) -> bool:
 def process_plain_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_punct_1(instance: Instance) -> bool:
@@ -170,7 +182,9 @@ def filter_punct_1(instance: Instance) -> bool:
 def process_punct_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_date_1(instance: Instance) -> bool:
@@ -183,7 +197,9 @@ def process_date_1(instance: Instance) -> Instance:
     un_normalized = re.sub(r",", "", un_normalized)
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_letters_1(instance: Instance) -> bool:
@@ -195,7 +211,9 @@ def process_letters_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_verbatim_1(instance: Instance) -> bool:
@@ -206,7 +224,9 @@ def filter_verbatim_1(instance: Instance) -> bool:
 def process_verbatim_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_digit_1(instance: Instance) -> bool:
@@ -218,7 +238,9 @@ def process_digit_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_telephone_1(instance: Instance) -> bool:
@@ -230,7 +252,9 @@ def process_telephone_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_electronic_1(instance: Instance) -> bool:
@@ -242,7 +266,9 @@ def process_electronic_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_fraction_1(instance: Instance) -> bool:
@@ -254,7 +280,9 @@ def process_fraction_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 def filter_address_1(instance: Instance) -> bool:
@@ -266,27 +294,51 @@ def process_address_1(instance: Instance) -> Instance:
     un_normalized = instance.un_normalized
     normalized = instance.normalized
     normalized = re.sub(r"[^a-z ]", "", normalized)
-    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
+    return Instance(
+        token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
+    )
 
 
 filters = []
-filters.append(Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1))
-filters.append(Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1))
-filters.append(Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1))
-filters.append(Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1))
+filters.append(
+    Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1)
+)
+filters.append(
+    Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1)
+)
+filters.append(
+    Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1)
+)
+filters.append(
+    Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1)
+)
 filters.append(Filter(class_type="MONEY", process_func=process_money_1, filter_func=filter_money_1))
 filters.append(Filter(class_type="TIME", process_func=process_time_1, filter_func=filter_time_1))
 
 filters.append(Filter(class_type="DATE", process_func=process_date_1, filter_func=filter_date_1))
 filters.append(Filter(class_type="PLAIN", process_func=process_plain_1, filter_func=filter_plain_1))
 filters.append(Filter(class_type="PUNCT", process_func=process_punct_1, filter_func=filter_punct_1))
-filters.append(Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1))
-filters.append(Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1))
+filters.append(
+    Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1)
+)
+filters.append(
+    Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1)
+)
 filters.append(Filter(class_type="DIGIT", process_func=process_digit_1, filter_func=filter_digit_1))
-filters.append(Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1))
-filters.append(Filter(class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1))
-filters.append(Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1))
-filters.append(Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1))
+filters.append(
+    Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1)
+)
+filters.append(
+    Filter(
+        class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1
+    )
+)
+filters.append(
+    Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1)
+)
+filters.append(
+    Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1)
+)
 filters.append(Filter(class_type=EOS_TYPE, process_func=lambda x: x, filter_func=lambda x: True))
 
 
@@ -315,8 +367,10 @@ def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Inst
 
 def parse_args():
     parser = ArgumentParser()
-    parser.add_argument("--input", help="input file path", type=str, default='./en_with_types/output-00001-of-00100')
-    parser.add_argument("--verbose", help="print filtered instances", action='store_true')
+    parser.add_argument(
+        "--input", help="input file path", type=str, default="./en_with_types/output-00001-of-00100"
+    )
+    parser.add_argument("--verbose", help="print filtered instances", action="store_true")
    return parser.parse_args()
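Every process_*_1 helper in this script ends the same way: the un-normalized side is cleaned with a class-specific regex, and the normalized side is reduced to lowercase letters and spaces so the two strings can be compared fairly. The shared step in isolation:

```python
import re

normalized = "twelve dollars, fifty cents!"
# Same substitution the process_*_1 helpers apply before rebuilding the Instance.
print(re.sub(r"[^a-z ]", "", normalized))  # twelve dollars fifty cents
```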
@@ -1,5 +1,3 @@
-
-
 import pynini
 from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path, num_to_word
 from fun_text_processing.text_normalization.en.graph_utils import (
@@ -17,7 +15,7 @@ class CardinalFst(GraphFst):
     """
     Finite state transducer for classifying cardinals
     e.g. minus twenty three -> cardinal { integer: "23" negative: "-" } }
-    Numbers below thirteen are not converted.
+    Numbers below thirteen are not converted.
     """
 
     def __init__(self):
@@ -29,7 +27,9 @@ class CardinalFst(GraphFst):
 
         graph_hundred = pynini.cross("hundred", "")
 
-        graph_hundred_component = pynini.union(graph_digit + delete_space + graph_hundred, pynutil.insert("0"))
+        graph_hundred_component = pynini.union(
+            graph_digit + delete_space + graph_hundred, pynutil.insert("0")
+        )
         graph_hundred_component += delete_space
         graph_hundred_component += pynini.union(
             graph_teen | pynutil.insert("00"),
@@ -44,32 +44,46 @@ class CardinalFst(GraphFst):
         )
 
         graph_thousands = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("thousand"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("thousand"),
             pynutil.insert("000", weight=0.1),
         )
 
         graph_million = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("million"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("million"),
             pynutil.insert("000", weight=0.1),
         )
         graph_billion = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("billion"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("billion"),
             pynutil.insert("000", weight=0.1),
         )
         graph_trillion = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("trillion"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("trillion"),
            pynutil.insert("000", weight=0.1),
        )
        graph_quadrillion = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("quadrillion"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("quadrillion"),
            pynutil.insert("000", weight=0.1),
        )
        graph_quintillion = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("quintillion"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("quintillion"),
            pynutil.insert("000", weight=0.1),
        )
        graph_sextillion = pynini.union(
-            graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("sextillion"),
+            graph_hundred_component_at_least_one_none_zero_digit
+            + delete_space
+            + pynutil.delete("sextillion"),
            pynutil.insert("000", weight=0.1),
        )
 
@@ -93,7 +107,10 @@ class CardinalFst(GraphFst):
         )
 
         graph = graph @ pynini.union(
-            pynutil.delete(pynini.closure("0")) + pynini.difference(DAMO_DIGIT, "0") + pynini.closure(DAMO_DIGIT), "0"
+            pynutil.delete(pynini.closure("0"))
+            + pynini.difference(DAMO_DIGIT, "0")
+            + pynini.closure(DAMO_DIGIT),
+            "0",
         )
 
         labels_exception = [num_to_word(x) for x in range(0, 13)]
@@ -110,10 +127,12 @@ class CardinalFst(GraphFst):
         self.graph = (pynini.project(graph, "input") - graph_exception.arcsort()) @ graph
 
         optional_minus_graph = pynini.closure(
-            pynutil.insert("negative: ") + pynini.cross("minus", "\"-\"") + DAMO_SPACE, 0, 1
+            pynutil.insert("negative: ") + pynini.cross("minus", '"-"') + DAMO_SPACE, 0, 1
         )
 
-        final_graph = optional_minus_graph + pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
+        final_graph = (
+            optional_minus_graph + pynutil.insert('integer: "') + self.graph + pynutil.insert('"')
+        )
 
         final_graph = self.add_tokens(final_graph)
         self.fst = final_graph.optimize()
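The final_graph that Black re-wraps at the end of CardinalFst produces exactly the tagged form quoted in the class docstring. A self-contained sketch of that output shape, with a hypothetical toy number graph standing in for the full cardinal grammar and a plain space acceptor for DAMO_SPACE:

```python
import pynini
from pynini.lib import pynutil

DAMO_SPACE = pynini.accep(" ")               # stand-in for the imported constant
number = pynini.cross("twenty three", "23")  # stand-in for self.graph
optional_minus = pynini.closure(
    pynutil.insert("negative: ") + pynini.cross("minus", '"-"') + DAMO_SPACE, 0, 1
)
final = optional_minus + pynutil.insert('integer: "') + number + pynutil.insert('"')
print(pynini.shortestpath("minus twenty three" @ final).string())
# negative: "-" integer: "23"
```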
@ -1,4 +1,3 @@
|
||||
|
||||
import pynini
|
||||
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
|
||||
from fun_text_processing.text_normalization.en.graph_utils import (
|
||||
@ -63,7 +62,9 @@ def _get_year_graph():
|
||||
|
||||
def _get_thousands_graph():
|
||||
graph_ties = _get_ties_graph()
|
||||
graph_hundred_component = (graph_digit + delete_space + pynutil.delete("hundred")) | pynutil.insert("0")
|
||||
graph_hundred_component = (
|
||||
graph_digit + delete_space + pynutil.delete("hundred")
|
||||
) | pynutil.insert("0")
|
||||
graph = (
|
||||
graph_digit
|
||||
+ delete_space
|
||||
@ -90,7 +91,7 @@ def _get_year_graph():
|
||||
|
||||
class DateFst(GraphFst):
|
||||
"""
|
||||
Finite state transducer for classifying date,
|
||||
Finite state transducer for classifying date,
|
||||
e.g. january fifth twenty twelve -> date { month: "january" day: "5" year: "2012" preserve_order: true }
|
||||
e.g. the fifth of january twenty twelve -> date { day: "5" month: "january" year: "2012" preserve_order: true }
|
||||
e.g. twenty twenty -> date { year: "2012" preserve_order: true }
|
||||
@ -108,18 +109,26 @@ class DateFst(GraphFst):
|
||||
year_graph = pynutil.add_weight(year_graph, YEAR_WEIGHT)
|
||||
month_graph = _get_month_graph()
|
||||
|
||||
month_graph = pynutil.insert("month: \"") + month_graph + pynutil.insert("\"")
|
||||
month_graph = pynutil.insert('month: "') + month_graph + pynutil.insert('"')
|
||||
|
||||
day_graph = pynutil.insert("day: \"") + pynutil.add_weight(ordinal_graph, -0.7) + pynutil.insert("\"")
|
||||
day_graph = (
|
||||
pynutil.insert('day: "') + pynutil.add_weight(ordinal_graph, -0.7) + pynutil.insert('"')
|
||||
)
|
||||
graph_year = (
|
||||
delete_extra_space
|
||||
+ pynutil.insert("year: \"")
|
||||
+ pynutil.insert('year: "')
|
||||
+ pynutil.add_weight(year_graph, -YEAR_WEIGHT)
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
optional_graph_year = pynini.closure(
|
||||
graph_year,
|
||||
0,
|
||||
1,
|
||||
)
|
||||
optional_graph_year = pynini.closure(graph_year, 0, 1,)
|
||||
graph_mdy = month_graph + (
|
||||
(delete_extra_space + day_graph) | graph_year | (delete_extra_space + day_graph + graph_year)
|
||||
(delete_extra_space + day_graph)
|
||||
| graph_year
|
||||
| (delete_extra_space + day_graph + graph_year)
|
||||
)
|
||||
graph_dmy = (
|
||||
pynutil.delete("the")
|
||||
@ -131,7 +140,9 @@ class DateFst(GraphFst):
|
||||
+ month_graph
|
||||
+ optional_graph_year
|
||||
)
|
||||
graph_year = pynutil.insert("year: \"") + (year_graph | _get_range_graph()) + pynutil.insert("\"")
|
||||
graph_year = (
|
||||
pynutil.insert('year: "') + (year_graph | _get_range_graph()) + pynutil.insert('"')
|
||||
)
|
||||
|
||||
final_graph = graph_mdy | graph_dmy | graph_year
|
||||
final_graph += pynutil.insert(" preserve_order: true")
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import pynini
|
||||
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
|
||||
from fun_text_processing.text_normalization.en.graph_utils import (
|
||||
@ -10,30 +9,42 @@ from fun_text_processing.text_normalization.en.graph_utils import (
|
||||
from pynini.lib import pynutil
|
||||
|
||||
|
||||
def get_quantity(decimal: 'pynini.FstLike', cardinal_up_to_hundred: 'pynini.FstLike') -> 'pynini.FstLike':
|
||||
def get_quantity(
|
||||
decimal: "pynini.FstLike", cardinal_up_to_hundred: "pynini.FstLike"
|
||||
) -> "pynini.FstLike":
|
||||
"""
|
||||
Returns FST that transforms either a cardinal or decimal followed by a quantity into a numeral,
|
||||
e.g. one million -> integer_part: "1" quantity: "million"
|
||||
e.g. one point five million -> integer_part: "1" fractional_part: "5" quantity: "million"
|
||||
|
||||
Args:
|
||||
Args:
|
||||
decimal: decimal FST
|
||||
cardinal_up_to_hundred: cardinal FST
|
||||
"""
|
||||
numbers = cardinal_up_to_hundred @ (
|
||||
pynutil.delete(pynini.closure("0")) + pynini.difference(DAMO_DIGIT, "0") + pynini.closure(DAMO_DIGIT)
|
||||
pynutil.delete(pynini.closure("0"))
|
||||
+ pynini.difference(DAMO_DIGIT, "0")
|
||||
+ pynini.closure(DAMO_DIGIT)
|
||||
)
|
||||
suffix = pynini.union(
|
||||
"million", "billion", "trillion", "quadrillion", "quintillion", "sextillion"
|
||||
)
|
||||
suffix = pynini.union("million", "billion", "trillion", "quadrillion", "quintillion", "sextillion")
|
||||
res = (
|
||||
pynutil.insert("integer_part: \"")
|
||||
pynutil.insert('integer_part: "')
|
||||
+ numbers
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
+ delete_extra_space
|
||||
+ pynutil.insert("quantity: \"")
|
||||
+ pynutil.insert('quantity: "')
|
||||
+ suffix
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
res |= (
|
||||
decimal
|
||||
+ delete_extra_space
|
||||
+ pynutil.insert('quantity: "')
|
||||
+ (suffix | "thousand")
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
res |= decimal + delete_extra_space + pynutil.insert("quantity: \"") + (suffix | "thousand") + pynutil.insert("\"")
|
||||
return res
|
||||
|
||||
|
||||
@ -52,7 +63,9 @@ class DecimalFst(GraphFst):
|
||||
cardinal_graph = cardinal.graph_no_exception
|
||||
|
||||
graph_decimal = pynini.string_file(get_abs_path("data/numbers/digit.tsv"))
|
||||
graph_decimal |= pynini.string_file(get_abs_path("data/numbers/zero.tsv")) | pynini.cross("o", "0")
|
||||
graph_decimal |= pynini.string_file(get_abs_path("data/numbers/zero.tsv")) | pynini.cross(
|
||||
"o", "0"
|
||||
)
|
||||
|
||||
graph_decimal = pynini.closure(graph_decimal + delete_space) + graph_decimal
|
||||
self.graph = graph_decimal
|
||||
@ -60,13 +73,20 @@ class DecimalFst(GraphFst):
|
||||
point = pynutil.delete("point")
|
||||
|
||||
optional_graph_negative = pynini.closure(
|
||||
pynutil.insert("negative: ") + pynini.cross("minus", "\"true\"") + delete_extra_space, 0, 1
|
||||
pynutil.insert("negative: ") + pynini.cross("minus", '"true"') + delete_extra_space,
|
||||
0,
|
||||
1,
|
||||
)
|
||||
|
||||
graph_fractional = pynutil.insert("fractional_part: \"") + graph_decimal + pynutil.insert("\"")
|
||||
graph_integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")
|
||||
graph_fractional = (
|
||||
pynutil.insert('fractional_part: "') + graph_decimal + pynutil.insert('"')
|
||||
)
|
||||
graph_integer = pynutil.insert('integer_part: "') + cardinal_graph + pynutil.insert('"')
|
||||
final_graph_wo_sign = (
|
||||
pynini.closure(graph_integer + delete_extra_space, 0, 1) + point + delete_extra_space + graph_fractional
|
||||
pynini.closure(graph_integer + delete_extra_space, 0, 1)
|
||||
+ point
|
||||
+ delete_extra_space
|
||||
+ graph_fractional
|
||||
)
|
||||
final_graph = optional_graph_negative + final_graph_wo_sign
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import pynini
|
||||
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
|
||||
from fun_text_processing.text_normalization.en.graph_utils import DAMO_ALPHA, GraphFst, insert_space
|
||||
@ -25,35 +24,50 @@ class ElectronicFst(GraphFst):
|
||||
|
||||
accepted_username = alpha_num | symbols
|
||||
process_dot = pynini.cross("dot", ".")
|
||||
username = (alpha_num + pynini.closure(delete_extra_space + accepted_username)) | pynutil.add_weight(
|
||||
pynini.closure(DAMO_ALPHA, 1), weight=0.0001
|
||||
)
|
||||
username = pynutil.insert("username: \"") + username + pynutil.insert("\"")
|
||||
username = (
|
||||
alpha_num + pynini.closure(delete_extra_space + accepted_username)
|
||||
) | pynutil.add_weight(pynini.closure(DAMO_ALPHA, 1), weight=0.0001)
|
||||
username = pynutil.insert('username: "') + username + pynutil.insert('"')
|
||||
single_alphanum = pynini.closure(alpha_num + delete_extra_space) + alpha_num
|
||||
server = single_alphanum | pynini.string_file(get_abs_path("data/electronic/server_name.tsv"))
|
||||
server = single_alphanum | pynini.string_file(
|
||||
get_abs_path("data/electronic/server_name.tsv")
|
||||
)
|
||||
domain = single_alphanum | pynini.string_file(get_abs_path("data/electronic/domain.tsv"))
|
||||
domain_graph = (
|
||||
pynutil.insert("domain: \"")
|
||||
pynutil.insert('domain: "')
|
||||
+ server
|
||||
+ delete_extra_space
|
||||
+ process_dot
|
||||
+ delete_extra_space
|
||||
+ domain
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
graph = (
|
||||
username
|
||||
+ delete_extra_space
|
||||
+ pynutil.delete("at")
|
||||
+ insert_space
|
||||
+ delete_extra_space
|
||||
+ domain_graph
|
||||
)
|
||||
graph = username + delete_extra_space + pynutil.delete("at") + insert_space + delete_extra_space + domain_graph
|
||||
|
||||
############# url ###
|
||||
protocol_end = pynini.cross(pynini.union("w w w", "www"), "www")
|
||||
protocol_start = (pynini.cross("h t t p", "http") | pynini.cross("h t t p s", "https")) + pynini.cross(
|
||||
" colon slash slash ", "://"
|
||||
)
|
||||
protocol_start = (
|
||||
pynini.cross("h t t p", "http") | pynini.cross("h t t p s", "https")
|
||||
) + pynini.cross(" colon slash slash ", "://")
|
||||
# .com,
|
||||
ending = (
|
||||
delete_extra_space
|
||||
+ symbols
|
||||
+ delete_extra_space
|
||||
+ (domain | pynini.closure(accepted_username + delete_extra_space,) + accepted_username)
|
||||
+ (
|
||||
domain
|
||||
| pynini.closure(
|
||||
accepted_username + delete_extra_space,
|
||||
)
|
||||
+ accepted_username
|
||||
)
|
||||
)
|
||||
|
||||
protocol_default = (
|
||||
@ -64,12 +78,18 @@ class ElectronicFst(GraphFst):
|
||||
+ pynini.closure(ending, 1)
|
||||
).optimize()
|
||||
protocol = (
|
||||
pynini.closure(protocol_start, 0, 1) + protocol_end + delete_extra_space + process_dot + protocol_default
|
||||
pynini.closure(protocol_start, 0, 1)
|
||||
+ protocol_end
|
||||
+ delete_extra_space
|
||||
+ process_dot
|
||||
+ protocol_default
|
||||
).optimize()
|
||||
|
||||
protocol |= pynini.closure(protocol_end + delete_extra_space + process_dot, 0, 1) + protocol_default
|
||||
protocol |= (
|
||||
pynini.closure(protocol_end + delete_extra_space + process_dot, 0, 1) + protocol_default
|
||||
)
|
||||
|
||||
protocol = pynutil.insert("protocol: \"") + protocol.optimize() + pynutil.insert("\"")
|
||||
protocol = pynutil.insert('protocol: "') + protocol.optimize() + pynutil.insert('"')
|
||||
graph |= protocol
|
||||
########
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import pynini
|
||||
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
|
||||
from fun_text_processing.text_normalization.en.graph_utils import (
|
||||
@ -32,22 +31,37 @@ class MeasureFst(GraphFst):
|
||||
graph_unit_plural = get_singulars(graph_unit_singular) # plural -> abbr
|
||||
|
||||
optional_graph_negative = pynini.closure(
|
||||
pynutil.insert("negative: ") + pynini.cross("minus", "\"true\"") + delete_extra_space, 0, 1
|
||||
pynutil.insert("negative: ") + pynini.cross("minus", '"true"') + delete_extra_space,
|
||||
0,
|
||||
1,
|
||||
)
|
||||
|
||||
unit_singular = convert_space(graph_unit_singular)
|
||||
unit_plural = convert_space(graph_unit_plural)
|
||||
unit_misc = pynutil.insert("/") + pynutil.delete("per") + delete_space + convert_space(graph_unit_singular)
|
||||
unit_misc = (
|
||||
pynutil.insert("/")
|
||||
+ pynutil.delete("per")
|
||||
+ delete_space
|
||||
+ convert_space(graph_unit_singular)
|
||||
)
|
||||
|
||||
unit_singular = (
|
||||
pynutil.insert("units: \"")
|
||||
+ (unit_singular | unit_misc | pynutil.add_weight(unit_singular + delete_space + unit_misc, 0.01))
|
||||
+ pynutil.insert("\"")
|
||||
pynutil.insert('units: "')
|
||||
+ (
|
||||
unit_singular
|
||||
| unit_misc
|
||||
| pynutil.add_weight(unit_singular + delete_space + unit_misc, 0.01)
|
||||
)
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
unit_plural = (
|
||||
pynutil.insert("units: \"")
|
||||
+ (unit_plural | unit_misc | pynutil.add_weight(unit_plural + delete_space + unit_misc, 0.01))
|
||||
+ pynutil.insert("\"")
|
||||
pynutil.insert('units: "')
|
||||
+ (
|
||||
unit_plural
|
||||
| unit_misc
|
||||
| pynutil.add_weight(unit_plural + delete_space + unit_misc, 0.01)
|
||||
)
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
|
||||
subgraph_decimal = (
|
||||
@ -61,9 +75,9 @@ class MeasureFst(GraphFst):
|
||||
subgraph_cardinal = (
|
||||
pynutil.insert("cardinal { ")
|
||||
+ optional_graph_negative
|
||||
+ pynutil.insert("integer: \"")
|
||||
+ pynutil.insert('integer: "')
|
||||
+ ((DAMO_SIGMA - "one") @ cardinal_graph)
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
+ pynutil.insert(" }")
|
||||
+ delete_extra_space
|
||||
+ unit_plural
|
||||
@ -71,9 +85,9 @@ class MeasureFst(GraphFst):
|
||||
subgraph_cardinal |= (
|
||||
pynutil.insert("cardinal { ")
|
||||
+ optional_graph_negative
|
||||
+ pynutil.insert("integer: \"")
|
||||
+ pynutil.insert('integer: "')
|
||||
+ pynini.cross("one", "1")
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
+ pynutil.insert(" }")
|
||||
+ delete_extra_space
|
||||
+ unit_singular
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import pynini
|
||||
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
|
||||
from fun_text_processing.text_normalization.en.graph_utils import (
|
||||
@ -33,8 +32,11 @@ class MoneyFst(GraphFst):
|
||||
# add support for missing hundred (only for 3 digit numbers)
|
||||
# "one fifty" -> "one hundred fifty"
|
||||
with_hundred = pynini.compose(
|
||||
pynini.closure(DAMO_NOT_SPACE) + pynini.accep(" ") + pynutil.insert("hundred ") + DAMO_SIGMA,
|
||||
pynini.compose(cardinal_graph, DAMO_DIGIT ** 3),
|
||||
pynini.closure(DAMO_NOT_SPACE)
|
||||
+ pynini.accep(" ")
|
||||
+ pynutil.insert("hundred ")
|
||||
+ DAMO_SIGMA,
|
||||
pynini.compose(cardinal_graph, DAMO_DIGIT**3),
|
||||
)
|
||||
cardinal_graph |= with_hundred
|
||||
graph_decimal_final = decimal.final_graph_wo_negative
|
||||
@ -43,20 +45,27 @@ class MoneyFst(GraphFst):
|
||||
unit_singular = pynini.invert(unit)
|
||||
unit_plural = get_singulars(unit_singular)
|
||||
|
||||
graph_unit_singular = pynutil.insert("currency: \"") + convert_space(unit_singular) + pynutil.insert("\"")
|
||||
graph_unit_plural = pynutil.insert("currency: \"") + convert_space(unit_plural) + pynutil.insert("\"")
|
||||
graph_unit_singular = (
|
||||
pynutil.insert('currency: "') + convert_space(unit_singular) + pynutil.insert('"')
|
||||
)
|
||||
graph_unit_plural = (
|
||||
pynutil.insert('currency: "') + convert_space(unit_plural) + pynutil.insert('"')
|
||||
)
|
||||
|
||||
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
|
||||
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
|
||||
pynutil.insert("0") + DAMO_DIGIT
|
||||
)
|
||||
# twelve dollars (and) fifty cents, zero cents
|
||||
cents_standalone = (
|
||||
pynutil.insert("fractional_part: \"")
|
||||
pynutil.insert('fractional_part: "')
|
||||
+ pynini.union(
|
||||
pynutil.add_weight(((DAMO_SIGMA - "one") @ cardinal_graph), -0.7) @ add_leading_zero_to_double_digit
|
||||
pynutil.add_weight(((DAMO_SIGMA - "one") @ cardinal_graph), -0.7)
|
||||
@ add_leading_zero_to_double_digit
|
||||
+ delete_space
|
||||
+ (pynutil.delete("cents") | pynutil.delete("cent")),
|
||||
pynini.cross("one", "01") + delete_space + pynutil.delete("cent"),
|
||||
)
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
)
|
||||
|
||||
optional_cents_standalone = pynini.closure(
|
||||
@ -70,31 +79,31 @@ class MoneyFst(GraphFst):
|
||||
# twelve dollars fifty, only after integer
|
||||
optional_cents_suffix = pynini.closure(
|
||||
delete_extra_space
|
||||
+ pynutil.insert("fractional_part: \"")
|
||||
+ pynutil.insert('fractional_part: "')
|
||||
+ pynutil.add_weight(cardinal_graph @ add_leading_zero_to_double_digit, -0.7)
|
||||
+ pynutil.insert("\""),
|
||||
+ pynutil.insert('"'),
|
||||
0,
|
||||
1,
|
||||
)
|
||||
|
||||
graph_integer = (
|
||||
pynutil.insert("integer_part: \"")
|
||||
pynutil.insert('integer_part: "')
|
||||
+ ((DAMO_SIGMA - "one") @ cardinal_graph)
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
+ delete_extra_space
|
||||
+ graph_unit_plural
|
||||
+ (optional_cents_standalone | optional_cents_suffix)
|
||||
)
|
||||
graph_integer |= (
|
||||
pynutil.insert("integer_part: \"")
|
||||
pynutil.insert('integer_part: "')
|
||||
+ pynini.cross("one", "1")
|
||||
+ pynutil.insert("\"")
|
||||
+ pynutil.insert('"')
|
||||
+ delete_extra_space
|
||||
+ graph_unit_singular
|
||||
+ (optional_cents_standalone | optional_cents_suffix)
|
||||
)
|
||||
graph_decimal = graph_decimal_final + delete_extra_space + graph_unit_plural
|
||||
graph_decimal |= pynutil.insert("currency: \"$\" integer_part: \"0\" ") + cents_standalone
|
||||
graph_decimal |= pynutil.insert('currency: "$" integer_part: "0" ') + cents_standalone
|
||||
final_graph = graph_integer | graph_decimal
|
||||
|
||||
final_graph = self.add_tokens(final_graph)
|
||||
|
||||
@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst
@@ -25,6 +24,6 @@ class OrdinalFst(GraphFst):
)

self.graph = graph @ cardinal_graph
final_graph = pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
final_graph = pynutil.insert('integer: "') + self.graph + pynutil.insert('"')
final_graph = self.add_tokens(final_graph)
self.fst = final_graph.optimize()

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
from pynini.lib import pynutil
@@ -13,9 +12,9 @@ class PunctuationFst(GraphFst):
def __init__(self):
super().__init__(name="punctuation", kind="classify")

s = "!#$%&\'()*+,-./:;<=>?@^_`{|}~"
s = "!#$%&'()*+,-./:;<=>?@^_`{|}~"
punct = pynini.union(*s)

graph = pynutil.insert("name: \"") + punct + pynutil.insert("\"")
graph = pynutil.insert('name: "') + punct + pynutil.insert('"')

self.fst = graph.optimize()

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import (
@@ -24,7 +23,7 @@ def get_serial_number(cardinal):

class TelephoneFst(GraphFst):
"""
Finite state transducer for classifying telephone numbers, e.g.
one two three one two three five six seven eight -> { number_part: "123-123-5678" }

This class also supports card numbers and the IP format.
@@ -61,61 +60,89 @@ class TelephoneFst(GraphFst):
double_digit.invert()

# to handle cases like "one twenty three"
two_digit_cardinal = pynini.compose(cardinal.graph_no_exception, DAMO_DIGIT ** 2)
two_digit_cardinal = pynini.compose(cardinal.graph_no_exception, DAMO_DIGIT**2)
double_digit_to_digit = (
pynini.compose(double_digit, str_to_digit + pynutil.delete(" ") + str_to_digit) | two_digit_cardinal
pynini.compose(double_digit, str_to_digit + pynutil.delete(" ") + str_to_digit)
| two_digit_cardinal
)

single_or_double_digit = (pynutil.add_weight(double_digit_to_digit, -0.0001) | str_to_digit).optimize()
single_or_double_digit = (
pynutil.add_weight(double_digit_to_digit, -0.0001) | str_to_digit
).optimize()
single_or_double_digit |= (
single_or_double_digit
+ pynini.closure(pynutil.add_weight(pynutil.delete(" ") + single_or_double_digit, 0.0001))
+ pynini.closure(
pynutil.add_weight(pynutil.delete(" ") + single_or_double_digit, 0.0001)
)
).optimize()

number_part = pynini.compose(
single_or_double_digit,
DAMO_DIGIT ** 3 + pynutil.insert("-") + DAMO_DIGIT ** 3 + pynutil.insert("-") + DAMO_DIGIT ** 4,
DAMO_DIGIT**3
+ pynutil.insert("-")
+ DAMO_DIGIT**3
+ pynutil.insert("-")
+ DAMO_DIGIT**4,
).optimize()
number_part = pynutil.insert("number_part: \"") + number_part.optimize() + pynutil.insert("\"")
number_part = (
pynutil.insert('number_part: "') + number_part.optimize() + pynutil.insert('"')
)

cardinal_option = pynini.compose(single_or_double_digit, DAMO_DIGIT ** (2, 3))

country_code = (
pynutil.insert("country_code: \"")
pynutil.insert('country_code: "')
+ pynini.closure(pynini.cross("plus ", "+"), 0, 1)
+ ((pynini.closure(str_to_digit + pynutil.delete(" "), 0, 2) + str_to_digit) | cardinal_option)
+ pynutil.insert("\"")
+ (
(pynini.closure(str_to_digit + pynutil.delete(" "), 0, 2) + str_to_digit)
| cardinal_option
)
+ pynutil.insert('"')
)

optional_country_code = pynini.closure(country_code + pynutil.delete(" ") + insert_space, 0, 1).optimize()
optional_country_code = pynini.closure(
country_code + pynutil.delete(" ") + insert_space, 0, 1
).optimize()
graph = optional_country_code + number_part

# credit card number
space_four_digits = insert_space + DAMO_DIGIT ** 4
credit_card_graph = pynini.compose(single_or_double_digit, DAMO_DIGIT ** 4 + space_four_digits ** 3).optimize()
graph |= pynutil.insert("number_part: \"") + credit_card_graph.optimize() + pynutil.insert("\"")
space_four_digits = insert_space + DAMO_DIGIT**4
credit_card_graph = pynini.compose(
single_or_double_digit, DAMO_DIGIT**4 + space_four_digits**3
).optimize()
graph |= (
pynutil.insert('number_part: "') + credit_card_graph.optimize() + pynutil.insert('"')
)

# SSN
ssn_graph = pynini.compose(
single_or_double_digit,
DAMO_DIGIT ** 3 + pynutil.insert("-") + DAMO_DIGIT ** 2 + pynutil.insert("-") + DAMO_DIGIT ** 4,
DAMO_DIGIT**3
+ pynutil.insert("-")
+ DAMO_DIGIT**2
+ pynutil.insert("-")
+ DAMO_DIGIT**4,
).optimize()
graph |= pynutil.insert("number_part: \"") + ssn_graph.optimize() + pynutil.insert("\"")
graph |= pynutil.insert('number_part: "') + ssn_graph.optimize() + pynutil.insert('"')

# ip
digit_or_double = pynini.closure(str_to_digit + pynutil.delete(" "), 0, 1) + double_digit_to_digit
digit_or_double |= double_digit_to_digit + pynini.closure(pynutil.delete(" ") + str_to_digit, 0, 1)
digit_or_double = (
pynini.closure(str_to_digit + pynutil.delete(" "), 0, 1) + double_digit_to_digit
)
digit_or_double |= double_digit_to_digit + pynini.closure(
pynutil.delete(" ") + str_to_digit, 0, 1
)
digit_or_double |= str_to_digit + (pynutil.delete(" ") + str_to_digit) ** (0, 2)
digit_or_double |= cardinal_option
digit_or_double = digit_or_double.optimize()

ip_graph = digit_or_double + (pynini.cross(" dot ", ".") + digit_or_double) ** 3

graph |= pynutil.insert("number_part: \"") + ip_graph.optimize() + pynutil.insert("\"")
graph |= pynutil.insert('number_part: "') + ip_graph.optimize() + pynutil.insert('"')
graph |= (
pynutil.insert("number_part: \"")
pynutil.insert('number_part: "')
+ pynutil.add_weight(get_serial_number(cardinal=cardinal), weight=0.0001)
+ pynutil.insert("\"")
+ pynutil.insert('"')
)

final_graph = self.add_tokens(graph)

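The number_part composition above illustrates the general grouping trick used throughout this tagger: first collapse the spoken digits to a flat digit string, then compose with a pattern that inserts separators at fixed offsets. A standalone sketch, assuming plain pynini with a DIGIT union in place of DAMO_DIGIT:

# Minimal sketch (not FunASR code): insert dashes into a 10-digit string.
import pynini
from pynini.lib import pynutil

DIGIT = pynini.union(*"0123456789")
phone = DIGIT**3 + pynutil.insert("-") + DIGIT**3 + pynutil.insert("-") + DIGIT**4
print(pynini.shortestpath("1231235678" @ phone).string())  # 123-123-5678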
@@ -1,5 +1,3 @@


import pynini
from fun_text_processing.inverse_text_normalization.en.taggers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path, num_to_word
@@ -47,21 +45,23 @@ class TimeFst(GraphFst):
graph_minute_verbose = pynini.cross("half", "30") | pynini.cross("quarter", "15")
oclock = pynini.cross(pynini.union("o' clock", "o clock", "o'clock", "oclock"), "")

final_graph_hour = pynutil.insert("hours: \"") + graph_hour + pynutil.insert("\"")
final_graph_hour = pynutil.insert('hours: "') + graph_hour + pynutil.insert('"')
graph_minute = (
oclock + pynutil.insert("00")
| pynutil.delete("o") + delete_space + graph_minute_single
| graph_minute_double
)
final_suffix = pynutil.insert("suffix: \"") + convert_space(suffix_graph) + pynutil.insert("\"")
final_suffix = (
pynutil.insert('suffix: "') + convert_space(suffix_graph) + pynutil.insert('"')
)
final_suffix = delete_space + insert_space + final_suffix
final_suffix_optional = pynini.closure(final_suffix, 0, 1)
final_time_zone_optional = pynini.closure(
delete_space
+ insert_space
+ pynutil.insert("zone: \"")
+ pynutil.insert('zone: "')
+ convert_space(time_zone_graph)
+ pynutil.insert("\""),
+ pynutil.insert('"'),
0,
1,
)
@@ -70,13 +70,17 @@ class TimeFst(GraphFst):
# two o eight, two thirty five (am/pm)
# two pm/am
graph_hm = (
final_graph_hour + delete_extra_space + pynutil.insert("minutes: \"") + graph_minute + pynutil.insert("\"")
final_graph_hour
+ delete_extra_space
+ pynutil.insert('minutes: "')
+ graph_minute
+ pynutil.insert('"')
)
# 10 past four, quarter past four, half past four
graph_m_past_h = (
pynutil.insert("minutes: \"")
pynutil.insert('minutes: "')
+ pynini.union(graph_minute_single, graph_minute_double, graph_minute_verbose)
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_space
+ pynutil.delete("past")
+ delete_extra_space
@@ -84,42 +88,48 @@ class TimeFst(GraphFst):
)

graph_quarter_time = (
pynutil.insert("minutes: \"")
pynutil.insert('minutes: "')
+ pynini.cross("quarter", "45")
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_space
+ pynutil.delete(pynini.union("to", "till"))
+ delete_extra_space
+ pynutil.insert("hours: \"")
+ pynutil.insert('hours: "')
+ to_hour_graph
+ pynutil.insert("\"")
+ pynutil.insert('"')
)

graph_m_to_h_suffix_time = (
pynutil.insert("minutes: \"")
pynutil.insert('minutes: "')
+ ((graph_minute_single | graph_minute_double).optimize() @ minute_to_graph)
+ pynutil.insert("\"")
+ pynini.closure(delete_space + pynutil.delete(pynini.union("min", "mins", "minute", "minutes")), 0, 1)
+ pynutil.insert('"')
+ pynini.closure(
delete_space + pynutil.delete(pynini.union("min", "mins", "minute", "minutes")),
0,
1,
)
+ delete_space
+ pynutil.delete(pynini.union("to", "till"))
+ delete_extra_space
+ pynutil.insert("hours: \"")
+ pynutil.insert('hours: "')
+ to_hour_graph
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ final_suffix
)

graph_h = (
final_graph_hour
+ delete_extra_space
+ pynutil.insert("minutes: \"")
+ pynutil.insert('minutes: "')
+ (pynutil.insert("00") | graph_minute)
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ final_suffix
+ final_time_zone_optional
)
final_graph = (
(graph_hm | graph_m_past_h | graph_quarter_time) + final_suffix_optional + final_time_zone_optional
(graph_hm | graph_m_past_h | graph_quarter_time)
+ final_suffix_optional
+ final_time_zone_optional
)
final_graph |= graph_h
final_graph |= graph_m_to_h_suffix_time

@@ -1,4 +1,3 @@

import os

import pynini
@@ -28,7 +27,7 @@ import logging
class ClassifyFst(GraphFst):
"""
Final class that composes all other classification grammars. This class can process an entire sentence that is lower-cased.
For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
More details on deployment at NeMo/tools/text_processing_deployment.

Args:
@@ -81,10 +80,16 @@ class ClassifyFst(GraphFst):
| pynutil.add_weight(word_graph, 100)
)

punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=1.1) + pynutil.insert(" }")
punct = (
pynutil.insert("tokens { ")
+ pynutil.add_weight(punct_graph, weight=1.1)
+ pynutil.insert(" }")
)
token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
token_plus_punct = (
pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
pynini.closure(punct + pynutil.insert(" "))
+ token
+ pynini.closure(pynutil.insert(" ") + punct)
)

graph = token_plus_punct + pynini.closure(delete_extra_space + token_plus_punct)

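To make the tokens { ... } wrapping concrete: each classified item is serialized as a tokens block, and adjacent blocks are joined by spaces. A minimal sketch, assuming plain pynini rather than the FunASR graph_utils helpers:

# Minimal sketch (not FunASR code): wrap classified words in tokens { } blocks.
import pynini
from pynini.lib import pynutil

name = pynutil.insert('name: "') + pynini.union("hello", "world") + pynutil.insert('"')
token = pynutil.insert("tokens { ") + name + pynutil.insert(" }")
graph = token + pynini.accep(" ") + token
print(pynini.shortestpath("hello world" @ graph).string())
# tokens { name: "hello" } tokens { name: "world" }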
@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space
@@ -16,5 +15,5 @@ class WhiteListFst(GraphFst):
super().__init__(name="whitelist", kind="classify")

whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv")).invert()
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
graph = pynutil.insert('name: "') + convert_space(whitelist) + pynutil.insert('"')
self.fst = graph.optimize()

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_SPACE, GraphFst
from pynini.lib import pynutil
@@ -12,5 +11,5 @@ class WordFst(GraphFst):

def __init__(self):
super().__init__(name="word", kind="classify")
word = pynutil.insert("name: \"") + pynini.closure(DAMO_NOT_SPACE, 1) + pynutil.insert("\"")
word = pynutil.insert('name: "') + pynini.closure(DAMO_NOT_SPACE, 1) + pynutil.insert('"')
self.fst = word.optimize()

@@ -1,5 +1,3 @@


import os
from typing import Union

@@ -15,7 +13,7 @@ def num_to_word(x: Union[str, int]):
Args
x: integer

Returns: spoken representation
"""
if isinstance(x, int):
x = str(x)
@@ -29,7 +27,7 @@ def get_abs_path(rel_path):

Args:
rel_path: relative path to this file


Returns absolute path
"""
return os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
return os.path.dirname(os.path.abspath(__file__)) + "/" + rel_path

@@ -1,7 +1,9 @@


import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_space,
)
from pynini.lib import pynutil


@@ -16,9 +18,9 @@ class CardinalFst(GraphFst):
optional_sign = pynini.closure(
pynutil.delete("negative:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ DAMO_NOT_QUOTE
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space,
0,
1,
@@ -26,9 +28,9 @@ class CardinalFst(GraphFst):
graph = (
pynutil.delete("integer:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
self.numbers = graph
graph = optional_sign + graph

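The verbalizers mirror the taggers: a tagger inserts a field such as integer: "12", and the matching verbalizer deletes the field name and surrounding quotes, keeping only the value. A round-trip sketch, assuming plain pynini with a small NOT_QUOTE union in place of DAMO_NOT_QUOTE:

# Minimal sketch (not FunASR code): tag a value, then verbalize it back.
import pynini
from pynini.lib import pynutil

NOT_QUOTE = pynini.union(*"0123456789")  # enough for this example
tag = pynutil.insert('integer: "') + pynini.accep("12") + pynutil.insert('"')
untag = (
    pynutil.delete('integer: "')
    + pynini.closure(NOT_QUOTE, 1)
    + pynutil.delete('"')
)
tagged = pynini.shortestpath("12" @ tag).string()
print(tagged)                                        # integer: "12"
print(pynini.shortestpath(tagged @ untag).string())  # 12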
@@ -1,5 +1,3 @@


import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
@@ -22,43 +20,47 @@ class DateFst(GraphFst):
month = (
pynutil.delete("month:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
day = (
pynutil.delete("day:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
year = (
pynutil.delete("year:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
)

# month (day) year
graph_mdy = (
month + pynini.closure(delete_extra_space + day, 0, 1) + pynini.closure(delete_extra_space + year, 0, 1)
month
+ pynini.closure(delete_extra_space + day, 0, 1)
+ pynini.closure(delete_extra_space + year, 0, 1)
)

# (day) month year
graph_dmy = (
pynini.closure(day + delete_extra_space, 0, 1) + month + pynini.closure(delete_extra_space + year, 0, 1)
pynini.closure(day + delete_extra_space, 0, 1)
+ month
+ pynini.closure(delete_extra_space + year, 0, 1)
)

optional_preserve_order = pynini.closure(
pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space
| pynutil.delete("field_order:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ DAMO_NOT_QUOTE
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space
)


@@ -1,7 +1,9 @@


import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_space,
)
from pynini.lib import pynutil


@@ -13,30 +15,30 @@ class DecimalFst(GraphFst):

def __init__(self):
super().__init__(name="decimal", kind="verbalize")
optionl_sign = pynini.closure(pynini.cross("negative: \"true\"", "-") + delete_space, 0, 1)
optionl_sign = pynini.closure(pynini.cross('negative: "true"', "-") + delete_space, 0, 1)
integer = (
pynutil.delete("integer_part:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_integer = pynini.closure(integer + delete_space, 0, 1)
fractional = (
pynutil.insert(".")
+ pynutil.delete("fractional_part:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_fractional = pynini.closure(fractional + delete_space, 0, 1)
quantity = (
pynutil.delete("quantity:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_quantity = pynini.closure(pynutil.insert(" ") + quantity + delete_space, 0, 1)
graph = optional_integer + optional_fractional + optional_quantity

@@ -1,7 +1,9 @@


import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_space,
)
from pynini.lib import pynutil


@@ -16,24 +18,24 @@ class ElectronicFst(GraphFst):
user_name = (
pynutil.delete("username:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
domain = (
pynutil.delete("domain:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)

protocol = (
pynutil.delete("protocol:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)

graph = user_name + delete_space + pynutil.insert("@") + domain

@@ -1,10 +1,9 @@

from fun_text_processing.text_normalization.en.graph_utils import GraphFst


class FractionFst(GraphFst):
"""
Finite state transducer for verbalizing fraction,
"""

def __init__(self):

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
from pynini.lib import pynutil
@@ -16,13 +15,13 @@ class MeasureFst(GraphFst):

def __init__(self, decimal: GraphFst, cardinal: GraphFst):
super().__init__(name="measure", kind="verbalize")
optional_sign = pynini.closure(pynini.cross("negative: \"true\"", "-"), 0, 1)
optional_sign = pynini.closure(pynini.cross('negative: "true"', "-"), 0, 1)
unit = (
pynutil.delete("units:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space
)
graph_decimal = (

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
from pynini.lib import pynutil
@@ -18,9 +17,9 @@ class MoneyFst(GraphFst):
unit = (
pynutil.delete("currency:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
graph = unit + delete_space + decimal.numbers
delete_tokens = self.delete_tokens(graph)

@@ -1,6 +1,10 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, DAMO_SIGMA, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
DAMO_SIGMA,
GraphFst,
delete_space,
)
from pynini.lib import pynutil


@@ -15,9 +19,9 @@ class OrdinalFst(GraphFst):
graph = (
pynutil.delete("integer:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
convert_eleven = pynini.cross("11", "11th")
convert_twelve = pynini.cross("12", "12th")

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil
@@ -8,17 +7,21 @@ class TelephoneFst(GraphFst):
"""
Finite state transducer for verbalizing telephone, e.g.
telephone { number_part: "123-123-5678" }
-> 123-123-5678
"""

def __init__(self):
super().__init__(name="telephone", kind="verbalize")

number_part = pynutil.delete("number_part: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
optional_country_code = pynini.closure(
pynutil.delete("country_code: \"")
number_part = (
pynutil.delete('number_part: "')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_country_code = pynini.closure(
pynutil.delete('country_code: "')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
+ pynini.accep(" "),
0,
1,

@@ -1,4 +1,3 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_CHAR,
@@ -20,29 +19,31 @@ class TimeFst(GraphFst):

def __init__(self):
super().__init__(name="time", kind="verbalize")
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
pynutil.insert("0") + DAMO_DIGIT
)
hour = (
pynutil.delete("hours:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_DIGIT, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
minute = (
pynutil.delete("minutes:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_DIGIT, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
suffix = (
delete_space
+ insert_space
+ pynutil.delete("suffix:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_suffix = pynini.closure(suffix, 0, 1)
zone = (
@@ -50,9 +51,9 @@ class TimeFst(GraphFst):
+ insert_space
+ pynutil.delete("zone:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_zone = pynini.closure(zone, 0, 1)
graph = (

@@ -1,4 +1,3 @@

from fun_text_processing.inverse_text_normalization.en.verbalizers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.date import DateFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.decimal import DecimalFst
@@ -15,7 +14,7 @@ from fun_text_processing.text_normalization.en.graph_utils import GraphFst
class VerbalizeFst(GraphFst):
"""
Composes other verbalizer grammars.
For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
More details on deployment at NeMo/tools/text_processing_deployment.
"""


@@ -1,14 +1,17 @@

import pynini
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.word import WordFst
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, delete_extra_space, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
GraphFst,
delete_extra_space,
delete_space,
)
from pynini.lib import pynutil


class VerbalizeFinalFst(GraphFst):
"""
Finite state transducer that verbalizes an entire sentence, e.g.
tokens { name: "its" } tokens { time { hours: "12" minutes: "30" } } tokens { name: "now" } -> its 12:30 now
"""


@@ -1,7 +1,10 @@


import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, DAMO_SIGMA, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_CHAR,
DAMO_SIGMA,
GraphFst,
delete_space,
)
from pynini.lib import pynutil


@@ -16,9 +19,9 @@ class WhiteListFst(GraphFst):
graph = (
pynutil.delete("name:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
graph = graph @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", DAMO_SIGMA)
graph = graph @ pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", DAMO_SIGMA)
self.fst = graph.optimize()

@@ -1,6 +1,10 @@

import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, DAMO_SIGMA, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_CHAR,
DAMO_SIGMA,
GraphFst,
delete_space,
)
from pynini.lib import pynutil


@@ -13,7 +17,13 @@ class WordFst(GraphFst):
def __init__(self):
super().__init__(name="word", kind="verbalize")
chars = pynini.closure(DAMO_CHAR - " ", 1)
char = pynutil.delete("name:") + delete_space + pynutil.delete("\"") + chars + pynutil.delete("\"")
graph = char @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", DAMO_SIGMA)
char = (
pynutil.delete("name:")
+ delete_space
+ pynutil.delete('"')
+ chars
+ pynutil.delete('"')
)
graph = char @ pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", DAMO_SIGMA)

self.fst = graph.optimize()

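The cdrewrite composition above rewrites non-breaking spaces (U+00A0) to plain spaces anywhere in the verbalized output. A standalone sketch, assuming plain pynini with a tiny alphabet closure in place of DAMO_SIGMA:

# Minimal sketch (not FunASR code): rewrite U+00A0 to a plain space.
import pynini

SIGMA = pynini.closure(pynini.union(*"ab ", "\u00A0"))  # tiny stand-in alphabet
fix_nbsp = pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", SIGMA)
print(pynini.shortestpath("a\u00A0b" @ fix_nbsp).string())  # "a b"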
@@ -1,4 +1,7 @@

from fun_text_processing.inverse_text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
from fun_text_processing.inverse_text_normalization.es.taggers.tokenize_and_classify import (
ClassifyFst,
)
from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize_final import (
VerbalizeFinalFst,
)

@@ -1,3 +1 @@




@@ -1,3 +1 @@


