Dev gzf exp (#1654)

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* update with main (#1638)

* update seaco finetune

* v1.0.24

* update rwkv template

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sense voice

* whisper

* update style

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
zhifu gao 2024-04-24 16:03:38 +08:00 committed by GitHub
parent 7c3ba91f67
commit 861147c730
1015 changed files with 25665 additions and 30646 deletions

.pre-commit-config.yaml

@ -0,0 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 24.4.0
hooks:
- id: black
args: ['--line-length=100'] # example argument; Black uses 4-space indentation by default
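The hook above pins Black 24.4.0 with a 100-character line limit, which matches the mechanical quote and wrapping changes throughout the diffs below. A minimal sketch of the formatting it enforces, via Black's Python API (the sample source string is hypothetical):

import black

# Format a snippet the same way the pre-commit hook would (line length 100).
src = "x = {'a':1,'b':2}"
print(black.format_str(src, mode=black.Mode(line_length=100)))
# -> x = {"a": 1, "b": 2}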


@ -17,9 +17,9 @@
# -- Project information -----------------------------------------------------
project = 'FunASR'
copyright = '2022, Speech Lab, Alibaba Group'
author = 'Speech Lab, Alibaba Grou'
project = "FunASR"
copyright = "2022, Speech Lab, Alibaba Group"
author = "Speech Lab, Alibaba Group"
# -- General configuration ---------------------------------------------------
@ -30,18 +30,18 @@ author = 'Speech Lab, Alibaba Grou'
extensions = [
"nbsphinx",
"sphinx.ext.autodoc",
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.mathjax",
"sphinx.ext.todo",
# "sphinxarg.ext",
"sphinx_markdown_tables",
'recommonmark',
'sphinx_rtd_theme',
"recommonmark",
"sphinx_rtd_theme",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
source_suffix = [".rst", ".md"]
@ -64,4 +64,4 @@ html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
# html_static_path = ['_static']


@ -6,23 +6,23 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = 'MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0'
copyright = '2023, Speech Lab, Alibaba Group; ASLP Group, Northwestern Polytechnical University'
author = 'Speech Lab, Alibaba Group; Audio, Speech and Language Processing Group, Northwestern Polytechnical University'
project = "MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0"
copyright = "2023, Speech Lab, Alibaba Group; ASLP Group, Northwestern Polytechnical University"
author = "Speech Lab, Alibaba Group; Audio, Speech and Language Processing Group, Northwestern Polytechnical University"
extensions = [
"nbsphinx",
"sphinx.ext.autodoc",
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.mathjax",
"sphinx.ext.todo",
# "sphinxarg.ext",
"sphinx_markdown_tables",
# 'recommonmark',
'sphinx_rtd_theme',
'myst_parser',
"sphinx_rtd_theme",
"myst_parser",
]
myst_enable_extensions = [
@ -33,13 +33,12 @@ myst_enable_extensions = [
]
myst_heading_anchors = 2
myst_highlight_code_blocks=True
myst_update_mathjax=False
myst_highlight_code_blocks = True
myst_update_mathjax = False
templates_path = ['_templates']
templates_path = ["_templates"]
source_suffix = [".rst", ".md"]
pygments_style = "sphinx"
html_theme = "sphinx_rtd_theme"


@ -2,66 +2,98 @@ import os
import numpy as np
import sys
def compute_wer(ref_file,
hyp_file,
cer_detail_file):
def compute_wer(ref_file, hyp_file, cer_detail_file):
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
"Wrd": 0,
"Corr": 0,
"Ins": 0,
"Del": 0,
"Sub": 0,
"Snt": 0,
"Err": 0.0,
"S.Err": 0.0,
"wrong_words": 0,
"wrong_sentences": 0,
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, 'r') as hyp_reader:
with open(hyp_file, "r") as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, 'r') as ref_reader:
with open(ref_file, "r") as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, 'w')
cer_detail_writer = open(cer_detail_file, "w")
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst["Wrd"] += out_item["nwords"]
rst["Corr"] += out_item["cor"]
rst["wrong_words"] += out_item["wrong"]
rst["Ins"] += out_item["ins"]
rst["Del"] += out_item["del"]
rst["Sub"] += out_item["sub"]
rst["Snt"] += 1
if out_item["wrong"] > 0:
rst["wrong_sentences"] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
cer_detail_writer.write(
"ref:" + "\t" + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + "\n"
)
cer_detail_writer.write(
"hyp:" + "\t" + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + "\n"
)
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
if rst["Wrd"] > 0:
rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
if rst["Snt"] > 0:
rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)
cer_detail_writer.write('\n')
cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
cer_detail_writer.write("\n")
cer_detail_writer.write(
"%WER "
+ str(rst["Err"])
+ " [ "
+ str(rst["wrong_words"])
+ " / "
+ str(rst["Wrd"])
+ ", "
+ str(rst["Ins"])
+ " ins, "
+ str(rst["Del"])
+ " del, "
+ str(rst["Sub"])
+ " sub ]"
+ "\n"
)
cer_detail_writer.write(
"%SER "
+ str(rst["S.Err"])
+ " [ "
+ str(rst["wrong_sentences"])
+ " / "
+ str(rst["Snt"])
+ " ]"
+ "\n"
)
cer_detail_writer.write(
"Scored "
+ str(len(hyp_dict))
+ " sentences, "
+ str(len(hyp_dict) - rst["Snt"])
+ " not present in hyp."
+ "\n"
)
def compute_wer_by_line(hyp,
ref):
def compute_wer_by_line(hyp, ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
@ -96,14 +128,7 @@ def compute_wer_by_line(hyp,
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
@ -111,42 +136,57 @@ def compute_wer_by_line(hyp,
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
rst["cor"] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
rst["ins"] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
rst["del"] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
rst["sub"] += 1
if i < 0 and j >= 0:
rst['del'] += 1
rst["del"] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
rst["ins"] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
rst["wrong"] = wrong_cnt
return rst
def print_cer_detail(rst):
return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
if __name__ == '__main__':
def print_cer_detail(rst):
return (
"("
+ "nwords="
+ str(rst["nwords"])
+ ",cor="
+ str(rst["cor"])
+ ",ins="
+ str(rst["ins"])
+ ",del="
+ str(rst["del"])
+ ",sub="
+ str(rst["sub"])
+ ") corr:"
+ "{:.2%}".format(rst["cor"] / rst["nwords"])
+ ",cer:"
+ "{:.2%}".format(rst["wrong"] / rst["nwords"])
)
if __name__ == "__main__":
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
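For reference, both inputs are Kaldi-style text files, one utterance per line with the ID in the first column. A minimal usage sketch (file names hypothetical); with one substituted word out of three reference words, the detail file ends with a 33.33% WER summary:

# Hypothetical ref/hyp files in Kaldi text format: "<utt-id> word1 word2 ...".
with open("test.ref", "w") as f:
    f.write("utt1 hello world today\n")
with open("test.hyp", "w") as f:
    f.write("utt1 hello word today\n")  # "world" -> "word": one substitution

compute_wer("test.ref", "test.hyp", "test.wer")
# test.wer then ends with summary lines of the form:
# %WER 33.33 [ 1 / 3, 0 ins, 0 del, 1 sub ]
# %SER 100.0 [ 1 / 1 ]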


@ -5,6 +5,7 @@ import os
import torch
from kaldiio import WriteHelper
import re
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
@ -16,17 +17,17 @@ model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
with open(text_file_json, 'r') as f:
with open(text_file_json, "r") as f:
js = f.readlines()
f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
with WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer:
with torch.no_grad():
for idx, line in enumerate(js):
id, tokens = line.strip().split(" ", 1)
tokens = re.sub(" ", "", tokens.strip())
tokens = ' '.join([j for j in tokens])
tokens = " ".join([j for j in tokens])
token_num = len(tokens.split(" "))
outputs = extractor(tokens)
outputs = np.array(outputs)
@ -38,10 +39,11 @@ with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
f_shape.write(shape_line)
else:
print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
print(
"{}, size has changed, {}, {}, {}".format(
id, token_num, token_num_embeds, tokens
)
)
f_shape.close()


@ -1,4 +1,3 @@
import sys
import re
@ -7,25 +6,25 @@ out_f = sys.argv[2]
with open(in_f, "r", encoding="utf-8") as f:
lines = f.readlines()
lines = f.readlines()
with open(out_f, "w", encoding="utf-8") as f:
for line in lines:
outs = line.strip().split(" ", 1)
if len(outs) == 2:
idx, text = outs
text = re.sub("</s>", "", text)
text = re.sub("<s>", "", text)
text = re.sub("@@", "", text)
text = re.sub("@", "", text)
text = re.sub("<unk>", "", text)
text = re.sub(" ", "", text)
text = text.lower()
else:
idx = outs[0]
text = " "
for line in lines:
outs = line.strip().split(" ", 1)
if len(outs) == 2:
idx, text = outs
text = re.sub("</s>", "", text)
text = re.sub("<s>", "", text)
text = re.sub("@@", "", text)
text = re.sub("@", "", text)
text = re.sub("<unk>", "", text)
text = re.sub(" ", "", text)
text = text.lower()
else:
idx = outs[0]
text = " "
text = [x for x in text]
text = " ".join(text)
out = "{} {}\n".format(idx, text)
f.write(out)
text = [x for x in text]
text = " ".join(text)
out = "{} {}\n".format(idx, text)
f.write(out)


@ -37,9 +37,7 @@ def get_parser():
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
@ -80,9 +78,7 @@ def main():
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
line = f.readline()
n = args.nchar
while line:


@ -4,7 +4,7 @@ import argparse
def load_dict(seg_file):
seg_dict = {}
with open(seg_file, 'r') as infile:
with open(seg_file, "r") as infile:
for line in infile:
s = line.strip().split()
key = s[0]
@ -28,8 +28,7 @@ def forward_segment(text, dic):
return word_list
def tokenize(txt,
seg_dict):
def tokenize(txt, seg_dict):
out_txt = ""
pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
for word in txt:
@ -87,20 +86,19 @@ def main():
parser = get_parser()
args = parser.parse_args()
txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), "w")
shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), "w")
seg_dict = load_dict(args.seg_file)
with open(args.text_file, 'r') as infile:
with open(args.text_file, "r") as infile:
for line in infile:
s = line.strip().split()
text_id = s[0]
text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
text = tokenize(text_list, seg_dict)
lens = len(text.strip().split())
txt_writer.write(text_id + " " + text + '\n')
shape_writer.write(text_id + " " + str(lens) + '\n')
txt_writer.write(text_id + " " + text + "\n")
shape_writer.write(text_id + " " + str(lens) + "\n")
if __name__ == '__main__':
if __name__ == "__main__":
main()
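Mirroring the calls in main() above, a hypothetical end-to-end use of these helpers (the dictionary path and input text are made up, and forward_segment is assumed to do greedy longest-match lookup in seg_dict, as its name suggests):

seg_dict = load_dict("seg_dict.txt")  # hypothetical file: one "word sub-units..." entry per line
text_list = forward_segment("funasr语音识别".lower(), seg_dict)  # assumed longest-match word split
text = tokenize(text_list, seg_dict)  # re-join as space-separated modeling units
print(text, len(text.strip().split()))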


@ -14,50 +14,59 @@ import sys, os, argparse, codecs, string, re
# ================================================================================ #
# basic constant
# ================================================================================ #
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
CHINESE_DIGIS = "零一二三四五六七八九"
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]
POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
'胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
'儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
'佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
FILLER_CHARS = ["呃", "啊"]
ER_WHITELIST = (
"(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
"胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
"儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
"佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
)
# types of Chinese numbering systems
NUMBERING_TYPES = ['low', 'mid', 'high']
NUMBERING_TYPES = ["low", "mid", "high"]
CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
'里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
'砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
'针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
'毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
'盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
'纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
CURRENCY_NAMES = (
"(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
"里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
CURRENCY_UNITS = (
"((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
)
COM_QUANTIFIERS = (
"(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
"砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
"针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
"毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
"盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
"纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
)
# punctuation information is based on the Zhon project (https://github.com/tsroten/zhon.git)
CHINESE_PUNC_STOP = '！？。。'
CHINESE_PUNC_NON_STOP = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
CHINESE_PUNC_STOP = "！？。。"
CHINESE_PUNC_NON_STOP = "＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
# basic class
# ================================================================================ #
@ -72,7 +81,7 @@ class ChineseChar(object):
def __init__(self, simplified, traditional):
self.simplified = simplified
self.traditional = traditional
#self.__repr__ = self.__str__
# self.__repr__ = self.__str__
def __str__(self):
return self.simplified or self.traditional or None
@ -95,26 +104,49 @@ class ChineseNumberUnit(ChineseChar):
self.big_t = big_t
def __str__(self):
return '10^{}'.format(self.power)
return "10^{}".format(self.power)
@classmethod
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
if small_unit:
return ChineseNumberUnit(power=index + 1,
simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
return ChineseNumberUnit(
power=index + 1,
simplified=value[0],
traditional=value[1],
big_s=value[1],
big_t=value[1],
)
elif numbering_type == NUMBERING_TYPES[0]:
return ChineseNumberUnit(power=index + 8,
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
return ChineseNumberUnit(
power=index + 8,
simplified=value[0],
traditional=value[1],
big_s=value[0],
big_t=value[1],
)
elif numbering_type == NUMBERING_TYPES[1]:
return ChineseNumberUnit(power=(index + 2) * 4,
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
return ChineseNumberUnit(
power=(index + 2) * 4,
simplified=value[0],
traditional=value[1],
big_s=value[0],
big_t=value[1],
)
elif numbering_type == NUMBERING_TYPES[2]:
return ChineseNumberUnit(power=pow(2, index + 3),
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
return ChineseNumberUnit(
power=pow(2, index + 3),
simplified=value[0],
traditional=value[1],
big_s=value[0],
big_t=value[1],
)
else:
raise ValueError(
'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
"Counting type should be in {0} ({1} provided).".format(
NUMBERING_TYPES, numbering_type
)
)
class ChineseNumberDigit(ChineseChar):
@ -158,6 +190,7 @@ class NumberSystem(object):
"""
Chinese number system
"""
pass
@ -207,27 +240,27 @@ def create_system(numbering_type=NUMBERING_TYPES[1]):
# chinese number units of '亿' and larger
all_larger_units = zip(
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
larger_units = [CNU.create(i, v, numbering_type, False)
for i, v in enumerate(all_larger_units)]
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
)
larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
# chinese number units of '十, 百, 千, 万'
all_smaller_units = zip(
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
smaller_units = [CNU.create(i, v, small_unit=True)
for i, v in enumerate(all_smaller_units)]
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
)
smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
# digis
chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
chinese_digis = zip(
CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
)
digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
# symbols
positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
point_cn = CM(POINT[0], POINT[1], '.', lambda x,
y: float(str(x) + '.' + str(y)))
positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
# sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
system = NumberSystem()
system.units = smaller_units + larger_units
@ -251,13 +284,14 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
return m
def string2symbols(chinese_string, system):
int_string, dec_string = chinese_string, ''
int_string, dec_string = chinese_string, ""
for p in [system.math.point.simplified, system.math.point.traditional]:
if p in chinese_string:
int_string, dec_string = chinese_string.split(p)
break
return [get_symbol(c, system) for c in int_string], \
[get_symbol(c, system) for c in dec_string]
return [get_symbol(c, system) for c in int_string], [
get_symbol(c, system) for c in dec_string
]
def correct_symbols(integer_symbols, system):
"""
@ -271,8 +305,7 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
if len(integer_symbols) > 1:
if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
integer_symbols.append(
CNU(integer_symbols[-2].power - 1, None, None, None, None))
integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
result = []
unit_count = 0
@ -288,9 +321,13 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
result.append(current_unit)
elif unit_count > 1:
for i in range(len(result)):
if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
result[-i - 1] = CNU(result[-i - 1].power +
current_unit.power, None, None, None, None)
if (
isinstance(result[-i - 1], CNU)
and result[-i - 1].power < current_unit.power
):
result[-i - 1] = CNU(
result[-i - 1].power + current_unit.power, None, None, None, None
)
return result
def compute_value(integer_symbols):
@ -307,8 +344,7 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
elif isinstance(s, CNU):
value[-1] *= pow(10, s.power)
if s.power > last_power:
value[:-1] = list(map(lambda v: v *
pow(10, s.power), value[:-1]))
value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
last_power = s.power
value.append(0)
return sum(value)
@ -317,20 +353,28 @@ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
int_part, dec_part = string2symbols(chinese_string, system)
int_part = correct_symbols(int_part, system)
int_str = str(compute_value(int_part))
dec_str = ''.join([str(d.value) for d in dec_part])
dec_str = "".join([str(d.value) for d in dec_part])
if dec_part:
return '{0}.{1}'.format(int_str, dec_str)
return "{0}.{1}".format(int_str, dec_str)
else:
return int_str
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
traditional=False, alt_zero=False, alt_one=False, alt_two=True,
use_zeros=True, use_units=True):
def num2chn(
number_string,
numbering_type=NUMBERING_TYPES[1],
big=False,
traditional=False,
alt_zero=False,
alt_one=False,
alt_two=True,
use_zeros=True,
use_units=True,
):
def get_value(value_string, use_zeros=True):
striped_string = value_string.lstrip('0')
striped_string = value_string.lstrip("0")
# record nothing if all zeros
if not striped_string:
@ -345,14 +389,17 @@ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
# recursively record multiple digits
else:
result_unit = next(u for u in reversed(
system.units) if u.power < len(striped_string))
result_string = value_string[:-result_unit.power]
return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
result_string = value_string[: -result_unit.power]
return (
get_value(result_string)
+ [result_unit]
+ get_value(striped_string[-result_unit.power :])
)
system = create_system(numbering_type)
int_dec = number_string.split('.')
int_dec = number_string.split(".")
if len(int_dec) == 1:
int_string = int_dec[0]
dec_string = ""
@ -361,7 +408,8 @@ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
dec_string = int_dec[1]
else:
raise ValueError(
"invalid input num string with more than one dot: {}".format(number_string))
"invalid input num string with more than one dot: {}".format(number_string)
)
if use_units and len(int_string) > 1:
result_symbols = get_value(int_string)
@ -372,51 +420,62 @@ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
result_symbols += [system.math.point] + dec_symbols
if alt_two:
liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
system.digits[2].big_s, system.digits[2].big_t)
liang = CND(
2,
system.digits[2].alt_s,
system.digits[2].alt_t,
system.digits[2].big_s,
system.digits[2].big_t,
)
for i, v in enumerate(result_symbols):
if isinstance(v, CND) and v.value == 2:
next_symbol = result_symbols[i +
1] if i < len(result_symbols) - 1 else None
next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
previous_symbol = result_symbols[i - 1] if i > 0 else None
if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
if next_symbol.power != 1 and (
(previous_symbol is None) or (previous_symbol.power != 1)
):
result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output
if big:
attr_name = 'big_'
attr_name = "big_"
if traditional:
attr_name += 't'
attr_name += "t"
else:
attr_name += 's'
attr_name += "s"
else:
if traditional:
attr_name = 'traditional'
attr_name = "traditional"
else:
attr_name = 'simplified'
attr_name = "simplified"
result = ''.join([getattr(s, attr_name) for s in result_symbols])
result = "".join([getattr(s, attr_name) for s in result_symbols])
# if not use_zeros:
# result = result.strip(getattr(system.digits[0], attr_name))
if alt_zero:
result = result.replace(
getattr(system.digits[0], attr_name), system.digits[0].alt_s)
result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
if alt_one:
result = result.replace(
getattr(system.digits[1], attr_name), system.digits[1].alt_s)
result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
for i, p in enumerate(POINT):
if result.startswith(p):
return CHINESE_DIGIS[0] + result
# ^10, 11, .., 19
if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
if (
len(result) >= 2
and result[1]
in [
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
]
and result[0]
in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
):
result = result[1:]
return result
@ -440,6 +499,7 @@ class Cardinal:
def cardinal2chntext(self):
return num2chn(self.cardinal)
class Digit:
"""
DIGIT类
@ -476,17 +536,17 @@ class TelePhone:
def telephone2chntext(self, fixed=False):
if fixed:
sil_parts = self.telephone.split('-')
self.raw_chntext = '<SIL>'.join([
num2chn(part, alt_two=False, use_units=False) for part in sil_parts
])
self.chntext = self.raw_chntext.replace('<SIL>', '')
sil_parts = self.telephone.split("-")
self.raw_chntext = "<SIL>".join(
[num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
)
self.chntext = self.raw_chntext.replace("<SIL>", "")
else:
sp_parts = self.telephone.strip('+').split()
self.raw_chntext = '<SP>'.join([
num2chn(part, alt_two=False, use_units=False) for part in sp_parts
])
self.chntext = self.raw_chntext.replace('<SP>', '')
sp_parts = self.telephone.strip("+").split()
self.raw_chntext = "<SP>".join(
[num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
)
self.chntext = self.raw_chntext.replace("<SP>", "")
return self.chntext
@ -500,12 +560,12 @@ class Fraction:
self.chntext = chntext
def chntext2fraction(self):
denominator, numerator = self.chntext.split('分之')
return chn2num(numerator) + '/' + chn2num(denominator)
denominator, numerator = self.chntext.split("分之")
return chn2num(numerator) + "/" + chn2num(denominator)
def fraction2chntext(self):
numerator, denominator = self.fraction.split('/')
return num2chn(denominator) + '分之' + num2chn(numerator)
numerator, denominator = self.fraction.split("/")
return num2chn(denominator) + "分之" + num2chn(numerator)
class Date:
@ -544,23 +604,23 @@ class Date:
def date2chntext(self):
date = self.date
try:
year, other = date.strip().split('年', 1)
year = Digit(digit=year).digit2chntext() + '年'
year, other = date.strip().split("年", 1)
year = Digit(digit=year).digit2chntext() + "年"
except ValueError:
other = date
year = ''
year = ""
if other:
try:
month, day = other.strip().split('月', 1)
month = Cardinal(cardinal=month).cardinal2chntext() + '月'
month, day = other.strip().split("月", 1)
month = Cardinal(cardinal=month).cardinal2chntext() + "月"
except ValueError:
day = date
month = ''
month = ""
if day:
day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
else:
month = ''
day = ''
month = ""
day = ""
chntext = year + month + day
self.chntext = chntext
return self.chntext
@ -580,7 +640,7 @@ class Money:
def money2chntext(self):
money = self.money
pattern = re.compile(r'(\d+(\.\d+)?)')
pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(money)
if matchers:
for matcher in matchers:
@ -599,10 +659,10 @@ class Percentage:
self.chntext = chntext
def chntext2percentage(self):
return chn2num(self.chntext.strip().strip('百分之')) + '%'
return chn2num(self.chntext.strip().strip("百分之")) + "%"
def percentage2chntext(self):
return '百分之' + num2chn(self.percentage.strip().strip('%'))
return "百分之" + num2chn(self.percentage.strip().strip("%"))
def remove_erhua(text, er_whitelist):
@ -612,9 +672,9 @@ def remove_erhua(text, er_whitelist):
"""
er_pattern = re.compile(er_whitelist)
new_str=''
while re.search('儿',text):
a = re.search('儿',text).span()
new_str = ""
while re.search("儿", text):
a = re.search("儿", text).span()
remove_er_flag = 0
if er_pattern.search(text):
@ -622,23 +682,24 @@ def remove_erhua(text, er_whitelist):
if b[0] <= a[0]:
remove_er_flag = 1
if remove_er_flag == 0 :
new_str = new_str + text[0:a[0]]
text = text[a[1]:]
if remove_er_flag == 0:
new_str = new_str + text[0 : a[0]]
text = text[a[1] :]
else:
new_str = new_str + text[0:b[1]]
text = text[b[1]:]
new_str = new_str + text[0 : b[1]]
text = text[b[1] :]
text = new_str + text
return text
# ================================================================================ #
# NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
def __init__(self, raw_text):
self.raw_text = '^' + raw_text + '$'
self.norm_text = ''
self.raw_text = "^" + raw_text + "$"
self.norm_text = ""
def _particular(self):
text = self.norm_text
@ -647,7 +708,7 @@ class NSWNormalizer:
if matchers:
# print('particular')
for matcher in matchers:
text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
self.norm_text = text
return self.norm_text
@ -658,15 +719,17 @@ class NSWNormalizer:
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
matchers = pattern.findall(text)
if matchers:
#print('date')
# print('date')
for matcher in matchers:
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
# normalize money amounts
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
pattern = re.compile(
r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
)
matchers = pattern.findall(text)
if matchers:
#print('money')
# print('money')
for matcher in matchers:
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
@ -679,39 +742,45 @@ class NSWNormalizer:
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
matchers = pattern.findall(text)
if matchers:
#print('telephone')
# print('telephone')
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
text = text.replace(
matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
)
# landline phone numbers
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
matchers = pattern.findall(text)
if matchers:
# print('fixed telephone')
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
text = text.replace(
matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
)
# normalize fractions
pattern = re.compile(r"(\d+/\d+)")
matchers = pattern.findall(text)
if matchers:
#print('fraction')
# print('fraction')
for matcher in matchers:
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
# normalize percentages
text = text.replace('％', '%')
text = text.replace("％", "%")
pattern = re.compile(r"(\d+(\.\d+)?%)")
matchers = pattern.findall(text)
if matchers:
#print('percentage')
# print('percentage')
for matcher in matchers:
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
text = text.replace(
matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
)
# normalize bare numbers + quantifiers
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
matchers = pattern.findall(text)
if matchers:
#print('cardinal+quantifier')
# print('cardinal+quantifier')
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -719,7 +788,7 @@ class NSWNormalizer:
pattern = re.compile(r"(\d{4,32})")
matchers = pattern.findall(text)
if matchers:
#print('digit')
# print('digit')
for matcher in matchers:
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
@ -727,74 +796,82 @@ class NSWNormalizer:
pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(text)
if matchers:
#print('cardinal')
# print('cardinal')
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
self.norm_text = text
self._particular()
return self.norm_text.lstrip('^').rstrip('$')
return self.norm_text.lstrip("^").rstrip("$")
def nsw_test_case(raw_text):
print('I:' + raw_text)
print('O:' + NSWNormalizer(raw_text).normalize())
print('')
print("I:" + raw_text)
print("O:" + NSWNormalizer(raw_text).normalize())
print("")
def nsw_test():
nsw_test_case('固话0595-23865596或23880880。')
nsw_test_case('固话0595-23865596或23880880。')
nsw_test_case('手机:+86 19859213959或15659451527。')
nsw_test_case('分数32477/76391。')
nsw_test_case('百分数80.03%')
nsw_test_case('编号31520181154418。')
nsw_test_case('纯数2983.07克或12345.60米。')
nsw_test_case('日期1999年2月20日或09年3月15号。')
nsw_test_case('金钱12块534.5元20.1万')
nsw_test_case('特殊O2O或B2C。')
nsw_test_case('3456万吨')
nsw_test_case('2938个')
nsw_test_case('938')
nsw_test_case('今天吃了115个小笼包231个馒头')
nsw_test_case('有62％的概率')
nsw_test_case("固话0595-23865596或23880880。")
nsw_test_case("固话0595-23865596或23880880。")
nsw_test_case("手机:+86 19859213959或15659451527。")
nsw_test_case("分数32477/76391。")
nsw_test_case("百分数80.03%")
nsw_test_case("编号31520181154418。")
nsw_test_case("纯数2983.07克或12345.60米。")
nsw_test_case("日期1999年2月20日或09年3月15号。")
nsw_test_case("金钱12块534.5元20.1万")
nsw_test_case("特殊O2O或B2C。")
nsw_test_case("3456万吨")
nsw_test_case("2938个")
nsw_test_case("938")
nsw_test_case("今天吃了115个小笼包231个馒头")
nsw_test_case("有62％的概率")
if __name__ == '__main__':
#nsw_test()
if __name__ == "__main__":
# nsw_test()
p = argparse.ArgumentParser()
p.add_argument('ifile', help='input filename, assume utf-8 encoding')
p.add_argument('ofile', help='output filename')
p.add_argument('--to_upper', action='store_true', help='convert to upper case')
p.add_argument('--to_lower', action='store_true', help='convert to lower case')
p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
p.add_argument("ifile", help="input filename, assume utf-8 encoding")
p.add_argument("ofile", help="output filename")
p.add_argument("--to_upper", action="store_true", help="convert to upper case")
p.add_argument("--to_lower", action="store_true", help="convert to lower case")
p.add_argument(
"--has_key", action="store_true", help="input text has Kaldi's key as first field."
)
p.add_argument(
"--remove_fillers", type=bool, default=True, help='remove filler chars such as "呃, 啊"'
)
p.add_argument(
"--remove_erhua", type=bool, default=True, help='remove erhua chars such as "这儿"'
)
p.add_argument(
"--log_interval", type=int, default=10000, help="log interval in number of processed lines"
)
args = p.parse_args()
ifile = codecs.open(args.ifile, 'r', 'utf8')
ofile = codecs.open(args.ofile, 'w+', 'utf8')
ifile = codecs.open(args.ifile, "r", "utf8")
ofile = codecs.open(args.ofile, "w+", "utf8")
n = 0
for l in ifile:
key = ''
text = ''
key = ""
text = ""
if args.has_key:
cols = l.split(maxsplit=1)
key = cols[0]
if len(cols) == 2:
text = cols[1].strip()
else:
text = ''
text = ""
else:
text = l.strip()
# cases
if args.to_upper and args.to_lower:
sys.stderr.write('text norm: to_upper OR to_lower?')
sys.stderr.write("text norm: to_upper OR to_lower?")
exit(1)
if args.to_upper:
text = text.upper()
@ -804,7 +881,7 @@ if __name__ == '__main__':
# Filler chars removal
if args.remove_fillers:
for ch in FILLER_CHARS:
text = text.replace(ch, '')
text = text.replace(ch, "")
if args.remove_erhua:
text = remove_erhua(text, ER_WHITELIST)
@ -813,16 +890,16 @@ if __name__ == '__main__':
text = NSWNormalizer(text).normalize()
# Punctuations removal
old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
new_chars = ' ' * len(old_chars)
del_chars = ''
old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
new_chars = " " * len(old_chars)
del_chars = ""
text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
#
if args.has_key:
ofile.write(key + '\t' + text + '\n')
ofile.write(key + "\t" + text + "\n")
else:
ofile.write(text + '\n')
ofile.write(text + "\n")
n += 1
if n % args.log_interval == 0:

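Taken together, the module exposes chn2num/num2chn for the two conversion directions plus the NSWNormalizer wrapper used by the CLI above. A quick orientation sketch; the outputs in the comments are illustrative, inferred from the test cases:

print(num2chn("305"))                           # 三百零五
print(chn2num("三百零五"))                        # 305
print(NSWNormalizer("百分数80.03%").normalize())  # 百分数百分之八十点零三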

@ -16,4 +16,4 @@ model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch")
for wav_id in multilingual_wavs:
wav_file = f"{model.model_path}/examples/{wav_id}"
res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
print("detect sample {}: {}".format(wav_id, res))
print("detect sample {}: {}".format(wav_id, res))


@ -6,7 +6,7 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
multilingual_wavs=[
multilingual_wavs = [
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
@ -14,9 +14,9 @@ multilingual_wavs=[
]
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='iic/speech_whisper-large_lid_multilingual_pytorch')
task=Tasks.auto_speech_recognition, model="iic/speech_whisper-large_lid_multilingual_pytorch"
)
for wav in multilingual_wavs:
rec_result = inference_pipeline(input=wav, inference_clip_length=250)
print(rec_result)
print(rec_result)


@ -5,11 +5,16 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav",
batch_size_s=300,
batch_size_threshold_s=60,
)
print(res)


@ -7,7 +7,10 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", device="cpu")
model = AutoModel(
model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
device="cpu",
)
res = model.export(type="onnx", quantize=False)
print(res)
@ -16,7 +19,10 @@ print(res)
# method2, inference from local path
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", device="cpu")
model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
device="cpu",
)
res = model.export(type="onnx", quantize=False)
print(res)
print(res)


@ -5,8 +5,9 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common"
)
model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)


@ -7,6 +7,7 @@ from funasr import AutoModel
model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch")
res = model.generate(input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav")
res = model.generate(
input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
)
print(res)


@ -7,6 +7,8 @@ from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 魔搭')
print(res)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword="达摩院 魔搭",
)
print(res)


@ -7,7 +7,9 @@ from funasr import AutoModel
model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
)
print(res)
@ -15,5 +17,7 @@ from funasr import AutoModel
model = AutoModel(model="iic/punc_ct-transformer_cn-en-common-vocab471067-large")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
)
print(res)


@ -7,8 +7,9 @@
from funasr import AutoModel
model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
)
model = AutoModel(
model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
)
res = model.export(type="onnx", quantize=False)
print(res)
@ -17,7 +18,9 @@ print(res)
# method2, inference from local path
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
)
res = model.export(type="onnx", quantize=False)
print(res)
print(res)


@ -13,6 +13,6 @@ rec_result_all = "outputs: "
cache = {}
for vad in vads:
rec_result = model.generate(input=vad, cache=cache)
rec_result_all += rec_result[0]['text']
rec_result_all += rec_result[0]["text"]
print(rec_result_all)


@ -7,8 +7,9 @@
from funasr import AutoModel
model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
)
model = AutoModel(
model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
)
res = model.export(type="onnx", quantize=False)
print(res)
@ -17,7 +18,9 @@ print(res)
# method2, inference from local path
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
)
res = model.export(type="onnx", quantize=False)
print(res)
print(res)


@ -6,12 +6,15 @@
from funasr import AutoModel
# model="iic/emotion2vec_base"
model = AutoModel(model="iic/emotion2vec_base_finetuned",
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_model_revision="master",
# vad_kwargs={"max_single_segment_time": 2000},
)
model = AutoModel(
model="iic/emotion2vec_base_finetuned",
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_model_revision="master",
# vad_kwargs={"max_single_segment_time": 2000},
)
wav_file = f"{model.model_path}/example/test.wav"
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
res = model.generate(
wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False
)
print(res)


@ -4,6 +4,7 @@
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
@ -14,28 +15,28 @@ print(res)
# beg/end: ms
import soundfile
import os
wav_file = os.path.join(model.model_path, "example/vad_example.wav")
speech, sample_rate = soundfile.read(wav_file)
chunk_size = 200 # ms
chunk_size = 200 # ms
chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
disable_pbar=True,
)
res = model.generate(
input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
disable_pbar=True,
)
# print(res)
if len(res[0]["value"]):
print(res)
@ -44,4 +45,4 @@ for i in range(total_chunk_num):
# 1. [[beg1, end1], [beg2, end2], .., [begN, endN]]; [[beg, end]]; [[beg1, end1], [beg2, end2]]
# 2. [[beg, -1]]
# 3. [[-1, end]]
# beg/end: ms
# beg/end: ms
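Given the three result shapes listed above, a caller can stitch the streaming outputs back into closed segments. A minimal sketch (not part of the script):

segments = []  # accumulated [beg, end] pairs, in ms

def merge_vad_result(res):
    # res[0]["value"] is a list of [beg, end] pairs; -1 marks a boundary not seen yet
    for beg, end in res[0]["value"]:
        if beg == -1:
            segments[-1][1] = end        # case 3, [[-1, end]]: close the open segment
        elif end == -1:
            segments.append([beg, -1])   # case 2, [[beg, -1]]: segment opened, end pending
        else:
            segments.append([beg, end])  # case 1: a complete segment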


@ -17,7 +17,9 @@ print(res)
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
)
res = model.export(type="onnx", quantize=False)
print(res)


@ -9,6 +9,7 @@ import argparse
from tqdm import tqdm
import os
import pdb
remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
@ -51,9 +52,9 @@ class WordError(object):
def get_wer(self):
assert self.ref_words != 0
errors = (
self.errors[Code.substitution]
+ self.errors[Code.insertion]
+ self.errors[Code.deletion]
self.errors[Code.substitution]
+ self.errors[Code.insertion]
+ self.errors[Code.deletion]
)
return 100.0 * errors / self.ref_words
@ -299,30 +300,30 @@ def default_cluster(word):
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith("DIGIT"): # 1
unicode_names[i] = "Number" # 'DIGIT'
elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
i
].startswith("CJK COMPATIBILITY IDEOGRAPH"):
elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[i].startswith(
"CJK COMPATIBILITY IDEOGRAPH"
):
# 明 / 郎
unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
i
].startswith("LATIN SMALL LETTER"):
elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[i].startswith(
"LATIN SMALL LETTER"
):
# A / a
unicode_names[i] = "English" # 'LATIN LETTER'
elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め
unicode_names[i] = "Japanese" # 'GANA LETTER'
elif (
unicode_names[i].startswith("AMPERSAND")
or unicode_names[i].startswith("APOSTROPHE")
or unicode_names[i].startswith("COMMERCIAL AT")
or unicode_names[i].startswith("DEGREE CELSIUS")
or unicode_names[i].startswith("EQUALS SIGN")
or unicode_names[i].startswith("FULL STOP")
or unicode_names[i].startswith("HYPHEN-MINUS")
or unicode_names[i].startswith("LOW LINE")
or unicode_names[i].startswith("NUMBER SIGN")
or unicode_names[i].startswith("PLUS SIGN")
or unicode_names[i].startswith("SEMICOLON")
unicode_names[i].startswith("AMPERSAND")
or unicode_names[i].startswith("APOSTROPHE")
or unicode_names[i].startswith("COMMERCIAL AT")
or unicode_names[i].startswith("DEGREE CELSIUS")
or unicode_names[i].startswith("EQUALS SIGN")
or unicode_names[i].startswith("FULL STOP")
or unicode_names[i].startswith("HYPHEN-MINUS")
or unicode_names[i].startswith("LOW LINE")
or unicode_names[i].startswith("NUMBER SIGN")
or unicode_names[i].startswith("PLUS SIGN")
or unicode_names[i].startswith("SEMICOLON")
):
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
@ -411,11 +412,13 @@ def main(args):
if len(array) == 0:
continue
fid = array[0]
rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
rec_sets[rec_names[i]][fid] = normalize(
array[1:], ignore_words, case_sensitive, split
)
calculators_dict[rec_names[i]] = Calculator()
ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
hotwords_related_dict[rec_names[i]] = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
# tp: hotword appears in the label and in the rec
# tn: hotword appears in neither the label nor the rec
# fp: hotword is not in the label but appears in the rec
@ -431,21 +434,22 @@ def main(args):
_file_total_len = int(pipe.read().strip())
# compute error rate on the interaction of reference file and hyp file
for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
for line in tqdm(open(ref_file, "r", encoding="utf-8"), total=_file_total_len):
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array) == 0: continue
array = line.rstrip("\n").split()
if len(array) == 0:
continue
fid = array[0]
lab = normalize(array[1:], ignore_words, case_sensitive, split)
if verbose:
print('\nutt: %s' % fid)
print("\nutt: %s" % fid)
ocr_text = ref_ocr_dict[fid]
ocr_set = set(ocr_text)
print('ocr: {}'.format(" ".join(ocr_text)))
print("ocr: {}".format(" ".join(ocr_text)))
list_match = []  # label tokens that also appear in the ocr text
list_not_mathch = []
tmp_error = 0
@ -458,7 +462,7 @@ def main(args):
else:
tmp_match += 1
list_match.append(lab[index])
print('label in ocr: {}'.format(" ".join(list_match)))
print("label in ocr: {}".format(" ".join(list_match)))
# for each reco file
base_wrong_ocr_wer = None
@ -482,33 +486,44 @@ def main(args):
result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
if result["all"] != 0:
wer = (
float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
)
else:
wer = 0.0
print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
print("WER(%s): %4.2f %%" % (rec_name, wer), end=" ")
print(
"N=%d C=%d S=%d D=%d I=%d"
% (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
)
# print(result['rec'])
wrong_rec_but_in_ocr = []
for idx in range(len(result['lab'])):
if result['lab'][idx] != "":
if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
if result['lab'][idx] in list_match:
wrong_rec_but_in_ocr.append(result['lab'][idx])
for idx in range(len(result["lab"])):
if result["lab"][idx] != "":
if result["lab"][idx] != result["rec"][idx].replace("<BIAS>", ""):
if result["lab"][idx] in list_match:
wrong_rec_but_in_ocr.append(result["lab"][idx])
wrong_rec_but_in_ocr_dict[rec_name] += 1
print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
print("wrong_rec_but_in_ocr: {}".format(" ".join(wrong_rec_but_in_ocr)))
if rec_name == "base":
base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
if "ocr" in rec_name or "hot" in rec_name:
ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
print(
"{} {} helps, {} -> {}".format(
fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
)
)
elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
print(
"{} {} hurts, {} -> {}".format(
fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
)
)
# recall = 0
# false_alarm = 0
@ -537,11 +552,11 @@ def main(args):
# if badhotword == word:
# count += 1
if count == 0:
hotwords_related_dict[rec_name]['tn'] += 1
hotwords_related_dict[rec_name]["tn"] += 1
_tn += 1
# fp: 0
else:
hotwords_related_dict[rec_name]['fp'] += count
hotwords_related_dict[rec_name]["fp"] += count
_fp += count
# tn: 0
# if badhotword in _rec_list:
@ -553,23 +568,30 @@ def main(args):
rec_count = len([word for word in _rec_list if hotword == word])
# print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
if rec_count == true_count:
hotwords_related_dict[rec_name]['tp'] += true_count
hotwords_related_dict[rec_name]["tp"] += true_count
_tp += true_count
elif rec_count > true_count:
hotwords_related_dict[rec_name]['tp'] += true_count
hotwords_related_dict[rec_name]["tp"] += true_count
# fp: not in the label but in the rec
hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
hotwords_related_dict[rec_name]["fp"] += rec_count - true_count
_tp += true_count
_fp += rec_count - true_count
else:
hotwords_related_dict[rec_name]['tp'] += rec_count
hotwords_related_dict[rec_name]["tp"] += rec_count
# fn: hotword in the label but not in the rec
hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
hotwords_related_dict[rec_name]["fn"] += true_count - rec_count
_tp += rec_count
_fn += true_count - rec_count
print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
_tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
))
print(
"hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
_tp,
_tn,
_fp,
_fn,
sum([_tp, _tn, _fp, _fn]),
_tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0,
)
)
# if hotword in _rec_list:
# hotwords_related_dict[rec_name]['tp'] += 1
@ -612,77 +634,89 @@ def main(args):
ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])):
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
space["lab"] = []
space["rec"] = []
for idx in range(len(result["lab"])):
len_lab = width(result["lab"][idx])
len_rec = width(result["rec"][idx])
length = max(len_lab, len_rec)
space['lab'].append(length - len_lab)
space['rec'].append(length - len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
space["lab"].append(length - len_lab)
space["rec"].append(length - len_rec)
upper_lab = len(result["lab"])
upper_rec = len(result["rec"])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end=' ')
print("lab(%s):" % fid.encode("utf-8"), end=" ")
else:
print('lab:', end=' ')
print("lab:", end=" ")
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['lab'][idx]):
print(padding_symbol, end='')
print(' ', end='')
token = result["lab"][idx]
print("{token}".format(token=token), end="")
for n in range(space["lab"][idx]):
print(padding_symbol, end="")
print(" ", end="")
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end=' ')
print("rec(%s):" % fid.encode("utf-8"), end=" ")
else:
print('rec:', end=' ')
print("rec:", end=" ")
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['rec'][idx]):
print(padding_symbol, end='')
print(' ', end='')
token = result["rec"][idx]
print("{token}".format(token=token), end="")
for n in range(space["rec"][idx]):
print(padding_symbol, end="")
print(" ", end="")
print()
# print('\n', end='\n')
lab1 = lab2
rec1 = rec2
print('\n', end='\n')
print("\n", end="\n")
# break
if verbose:
print('===========================================================================')
print("===========================================================================")
print()
print(wrong_rec_but_in_ocr_dict)
for rec_name in rec_names:
result = calculators_dict[rec_name].overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
if result["all"] != 0:
wer = float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
else:
wer = 0.0
print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
print("{} Overall -> {:4.2f} %".format(rec_name, wer), end=" ")
print(
"N=%d C=%d S=%d D=%d I=%d"
% (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
)
print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
hotwords_related_dict[rec_name]['tp'],
hotwords_related_dict[rec_name]['tn'],
hotwords_related_dict[rec_name]['fp'],
hotwords_related_dict[rec_name]['fn'],
sum([v for k, v in hotwords_related_dict[rec_name].items()]),
hotwords_related_dict[rec_name]['tp'] / (
hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
))
print(
"hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
hotwords_related_dict[rec_name]["tp"],
hotwords_related_dict[rec_name]["tn"],
hotwords_related_dict[rec_name]["fp"],
hotwords_related_dict[rec_name]["fn"],
sum([v for k, v in hotwords_related_dict[rec_name].items()]),
(
hotwords_related_dict[rec_name]["tp"]
/ (
hotwords_related_dict[rec_name]["tp"]
+ hotwords_related_dict[rec_name]["fn"]
)
* 100
if hotwords_related_dict[rec_name]["tp"] + hotwords_related_dict[rec_name]["fn"]
!= 0
else 0
),
)
)
    # tp: hotword in both the label and the rec
    # tn: hotword in neither the label nor the rec
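
For reference, the recall printed above follows directly from these counts. A minimal standalone sketch of the same bookkeeping (hypothetical counts; the script itself reports recall only, precision is added here purely for illustration):

# Hypothetical counts, same shape as hotwords_related_dict[rec_name]
counts = {"tp": 42, "tn": 10, "fp": 3, "fn": 8}
recall = counts["tp"] / (counts["tp"] + counts["fn"]) * 100 if (counts["tp"] + counts["fn"]) else 0
precision = counts["tp"] / (counts["tp"] + counts["fp"]) * 100 if (counts["tp"] + counts["fp"]) else 0
print("recall: %.2f%%, precision: %.2f%%" % (recall, precision))  # recall: 84.00%, precision: 93.33%
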
@ -695,8 +729,7 @@ def main(args):
if __name__ == "__main__":
args = get_args()
# print("")
print(args)
main(args)
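
The side-by-side lab/rec printing earlier in this file pads each token with width() so that full-width and half-width characters line up. The helper is defined elsewhere in the script; a plausible reimplementation, offered only as an illustration of the idea:

import unicodedata

def width(s: str) -> int:
    # Full-width/wide characters (e.g. CJK) occupy two display columns, others one.
    return sum(2 if unicodedata.east_asian_width(ch) in ("F", "W") else 1 for ch in s)

assert width("abc") == 3
assert width("魔搭") == 4
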

View File

@ -5,11 +5,15 @@
from funasr import AutoModel
model = AutoModel(model="iic/LCB-NET",
model_revision="v1.0.0")
model = AutoModel(model="iic/LCB-NET", model_revision="v1.0.0")
res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text"))
res = model.generate(
input=(
"https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav",
"https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt",
),
data_type=("sound", "text"),
)
print(res)

View File

@ -7,9 +7,12 @@ from funasr import AutoModel
model = AutoModel(model="iic/speech_timestamp_prediction-v1-16k-offline")
res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
"欢迎大家来到魔搭社区进行体验"),
data_type=("sound", "text"),
batch_size=2,
)
res = model.generate(
input=(
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
"欢迎大家来到魔搭社区进行体验",
),
data_type=("sound", "text"),
batch_size=2,
)
print(res)

View File

@ -5,12 +5,15 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 磨搭')
print(res)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword="达摩院 磨搭",
)
print(res)

View File

@ -5,22 +5,29 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 60000},
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 60000},
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)
''' cannot be used currently
""" cannot be used currently
from funasr import AutoFrontend
frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
@ -30,4 +37,4 @@ fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/
for batch_idx, fbank_dict in enumerate(fbanks):
res = model.generate(**fbank_dict)
print(res)
'''
"""

View File

@ -9,7 +9,9 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",)
model = AutoModel(
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)
res = model.export(type="onnx", quantize=False)
print(res)
@ -18,7 +20,9 @@ print(res)
# method2, inference from local path
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
)
res = model.export(type="onnx", quantize=False)
print(res)

View File

@ -7,17 +7,18 @@ import os
from funasr import AutoModel
chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
chunk_size = [0, 10, 5] # [0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 4 # number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 # number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
res = model.generate(input=wav_file,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(
input=wav_file,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)
@ -27,18 +28,19 @@ import soundfile
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(
input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)
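
For reference, the chunk arithmetic above works out as follows (a minimal sketch; the 16 kHz sample rate and 60 ms frame shift are assumptions read off the model name and the "600ms" comment):

sample_rate = 16000                 # assumed from the 16k model name
frame_shift_ms = 60                 # assumed from the 600ms / 10-frame comment
chunk_frames = 10                   # chunk_size[1]
chunk_stride = chunk_frames * frame_shift_ms * sample_rate // 1000
assert chunk_stride == 10 * 960 == 9600  # 600 ms of audio per streaming chunk

Note that len((speech) - 1) equals len(speech) when speech is a NumPy array as returned by soundfile.read (the "- 1" is applied element-wise), so the loop runs floor(len(speech) / chunk_stride) + 1 times; when the length is an exact multiple of the stride, the extra final iteration sends an empty chunk whose only job is to flush the cache with is_final=True.
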

View File

@ -9,7 +9,9 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", )
model = AutoModel(
model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online",
)
res = model.export(type="onnx", quantize=False)
print(res)
@ -19,7 +21,9 @@ print(res)
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
)
res = model.export(type="onnx", quantize=False)
print(res)

View File

@ -12,7 +12,7 @@ model = AutoModel(model="Qwen-Audio-Chat")
audio_in = "https://github.com/QwenLM/Qwen-Audio/raw/main/assets/audio/1272-128104-0000.flac"
# 1st dialogue turn
prompt = 'what does the person say?'
prompt = "what does the person say?"
cache = {"history": None}
res = model.generate(input=audio_in, prompt=prompt, cache=cache)
print(res)
@ -22,4 +22,3 @@ print(res)
prompt = 'Find the start time and end time of the word "middle classes"'
res = model.generate(input=None, prompt=prompt, cache=cache)
print(res)

View File

@ -7,14 +7,17 @@
from funasr import AutoModel
model = AutoModel(model="Qwen-Audio-Chat",
model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio-Chat",
)
model = AutoModel(
model="Qwen-Audio-Chat",
model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio-Chat",
)
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
audio_in = (
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
# 1st dialogue turn
prompt = 'what does the person say?'
prompt = "what does the person say?"
cache = {"history": None}
res = model.generate(input=audio_in, prompt=prompt, cache=cache)
print(res)
@ -24,4 +27,3 @@ print(res)
prompt = 'Find the start time and end time of the word "middle classes"'
res = model.generate(input=None, prompt=prompt, cache=cache)
print(res)

View File

@ -7,9 +7,10 @@
from funasr import AutoModel
model = AutoModel(model="Qwen-Audio",
model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio",
)
model = AutoModel(
model="Qwen-Audio",
model_path="/nfs/zhifu.gzf/init_model/qwen/Qwen-Audio",
)
audio_in = "https://github.com/QwenLM/Qwen-Audio/raw/main/assets/audio/1272-128104-0000.flac"
prompt = "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"

View File

@ -5,17 +5,20 @@
from funasr import AutoModel
chunk_size = [5, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
chunk_size = [5, 10, 5] # [0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0 # number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0 # number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming")
model = AutoModel(
model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming"
)
cache = {}
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)
@ -25,18 +28,19 @@ import os
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
cache = {}
total_chunk_num = int(len((speech)-1)/chunk_stride+1)
total_chunk_num = int(len((speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
is_final = i == total_chunk_num - 1
res = model.generate(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(
input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)

View File

@ -5,24 +5,26 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(
model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
# example1
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 魔搭',
                    # return_raw_text=True, # return raw recognition text split by spaces, equal in length to the timestamps
# preset_spk_num=2, # preset speaker num for speaker cluster model
# sentence_timestamp=True, # return sentence level information when spk_model is not given
)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword="达摩院 魔搭",
    # return_raw_text=True,  # return raw recognition text split by spaces, equal in length to the timestamps
# preset_spk_num=2, # preset speaker num for speaker cluster model
# sentence_timestamp=True, # return sentence level information when spk_model is not given
)
print(res)
'''
"""
# tensor or numpy as input
# example2
import torchaudio
@ -39,4 +41,4 @@ import soundfile
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
'''
"""

View File

@ -5,20 +5,23 @@
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000},
)
model = AutoModel(
model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000},
)
input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
input_wav = (
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
DecodingOptions = {
"task": ("ASR", "AED", "SER"),
"language": "auto",
"fp16": True,
"gain_event": True,
}
"task": ("ASR", "AED", "SER"),
"language": "auto",
"fp16": True,
"gain_event": True,
}
res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions)
print(res)

View File

@ -7,8 +7,11 @@ from funasr import AutoModel
# Transducer, BAT and RWKV_BAT models are just same to use, use the correct model_revision
# https://modelscope.cn/models?name=transducer&page=1&tasks=auto-speech-recognition&type=audio
model = AutoModel(model="iic/speech_bat_asr-zh-cn-16k-aishell1-vocab4234-pytorch",
)
model = AutoModel(
model="iic/speech_bat_asr-zh-cn-16k-aishell1-vocab4234-pytorch",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)

View File

@ -6,14 +6,18 @@
from funasr import AutoModel
model = AutoModel(model="iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",)
model = AutoModel(
model="iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)
''' cannot be used currently
""" cannot be used currently
from funasr import AutoFrontend
frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
@ -23,4 +27,4 @@ fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/
for batch_idx, fbank_dict in enumerate(fbanks):
res = model.generate(**fbank_dict)
print(res)
'''
"""

View File

@ -7,22 +7,24 @@
from funasr import AutoModel
model = AutoModel(model="iic/Whisper-large-v3",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000},
)
model = AutoModel(
model="iic/Whisper-large-v3",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000},
)
DecodingOptions = {
"task": "transcribe",
"language": None,
"beam_size": None,
"fp16": True,
"without_timestamps": False,
"prompt": None,
}
"task": "transcribe",
"language": None,
"beam_size": None,
"fp16": True,
"without_timestamps": False,
"prompt": None,
}
res = model.generate(
DecodingOptions=DecodingOptions,
batch_size_s=0,
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
DecodingOptions=DecodingOptions,
batch_size_s=0,
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
)
print(res)

View File

@ -10,15 +10,25 @@ from funasr import AutoModel
# model = AutoModel(model="Whisper-small", hub="openai")
# model = AutoModel(model="Whisper-medium", hub="openai")
# model = AutoModel(model="Whisper-large-v2", hub="openai")
model = AutoModel(model="Whisper-large-v3",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000},
hub="openai",
)
model = AutoModel(
model="Whisper-large-v3",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000},
hub="openai",
)
DecodingOptions = {
"task": "transcribe",
"language": None,
"beam_size": None,
"fp16": True,
"without_timestamps": False,
"prompt": None,
}
res = model.generate(
language=None,
task="transcribe",
batch_size_s=0,
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
DecodingOptions=DecodingOptions,
batch_size_s=0,
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
)
print(res)

View File

@ -1,5 +1,7 @@
from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import (
ClassifyFst,
)
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import (
VerbalizeFinalFst,
)

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_SIGMA, GraphFst
from pynini.lib import pynutil
@ -6,7 +5,7 @@ from pynini.lib import pynutil
class CardinalFst(GraphFst):
"""
    Finite state transducer for classifying cardinals. Numbers below ten are not converted.
    Allows both compound numeral strings and numerals separated by whitespace.
    "und" (en: "and") can be inserted between "hundert" and a following number, or between "tausend" and a following single- or double-digit number.
@ -18,7 +17,7 @@ class CardinalFst(GraphFst):
e.g. ein tausend -> cardinal { integer: "1000" } }
e.g. eintausend -> cardinal { integer: "1000" } }
e.g. ein tausend zwanzig -> cardinal { integer: "1020" } }
Args:
tn_cardinal_tagger: TN cardinal tagger
"""
@ -31,24 +30,36 @@ class CardinalFst(GraphFst):
graph = (tn_cardinal_tagger.graph @ optional_delete_space).invert().optimize()
self.graph_hundred_component_at_least_one_none_zero_digit = (
(tn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit @ optional_delete_space)
(
tn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit
@ optional_delete_space
)
.invert()
.optimize()
)
self.graph_ties = (tn_cardinal_tagger.two_digit_non_zero @ optional_delete_space).invert().optimize()
self.graph_ties = (
(tn_cardinal_tagger.two_digit_non_zero @ optional_delete_space).invert().optimize()
)
        # this is to make sure that, if there is an ambiguity with decimal, decimal is chosen, e.g. 1000000 vs. 1 million
graph = pynutil.add_weight(graph, weight=0.001)
self.graph_no_exception = graph
self.digit = pynini.arcmap(tn_cardinal_tagger.digit, map_type="rmweight").invert().optimize()
graph_exception = pynini.project(self.digit, 'input')
self.digit = (
pynini.arcmap(tn_cardinal_tagger.digit, map_type="rmweight").invert().optimize()
)
graph_exception = pynini.project(self.digit, "input")
self.graph = (pynini.project(graph, "input") - graph_exception.arcsort()) @ graph
self.optional_minus_graph = pynini.closure(
pynutil.insert("negative: ") + pynini.cross("minus ", "\"-\" "), 0, 1
pynutil.insert("negative: ") + pynini.cross("minus ", '"-" '), 0, 1
)
final_graph = self.optional_minus_graph + pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
final_graph = (
self.optional_minus_graph
+ pynutil.insert('integer: "')
+ self.graph
+ pynutil.insert('"')
)
final_graph = self.add_tokens(final_graph)
self.fst = final_graph.optimize()
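
The tagger above maps spoken German numerals onto the tagged form cardinal { integer: "..." }. A tiny self-contained pynini sketch of the same insert/cross pattern (illustrative only; the toy single-word mapping below is not FunASR's actual grammar):

import pynini
from pynini.lib import pynutil

# Map the spoken form "zehn" onto the tagged form integer: "10".
graph = pynutil.insert('integer: "') + pynini.cross("zehn", "10") + pynutil.insert('"')
lattice = "zehn" @ graph
print(pynini.project(pynini.shortestpath(lattice), "output").string())  # -> integer: "10"
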

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_DIGIT,
@ -35,20 +34,28 @@ class DateFst(GraphFst):
):
super().__init__(name="date", kind="classify", deterministic=deterministic)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
pynutil.insert("0") + DAMO_DIGIT
)
optional_delete_space = pynini.closure(DAMO_SIGMA | pynutil.delete(" ", weight=0.0001))
tagger = tn_date_verbalizer.graph.invert().optimize()
delete_day_marker = (
pynutil.delete("day: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
pynutil.delete('day: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
) @ itn_cardinal_tagger.graph_no_exception
month_as_number = pynutil.delete("month: \"") + itn_cardinal_tagger.graph_no_exception + pynutil.delete("\"")
month_as_string = pynutil.delete("month: \"") + tn_date_tagger.month_abbr.invert() + pynutil.delete("\"")
month_as_number = (
pynutil.delete('month: "')
+ itn_cardinal_tagger.graph_no_exception
+ pynutil.delete('"')
)
month_as_string = (
pynutil.delete('month: "') + tn_date_tagger.month_abbr.invert() + pynutil.delete('"')
)
convert_year = (tn_date_tagger.year @ optional_delete_space).invert().optimize()
delete_year_marker = (
pynutil.delete("year: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
pynutil.delete('year: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
) @ convert_year
# day. month as string (year)
@ -73,5 +80,5 @@ class DateFst(GraphFst):
final_graph = tagger @ verbalizer
graph = pynutil.insert("name: \"") + convert_space(final_graph) + pynutil.insert("\"")
graph = pynutil.insert('name: "') + convert_space(final_graph) + pynutil.insert('"')
self.fst = graph.optimize()

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.de.taggers.decimal import get_quantity, quantities
from fun_text_processing.text_normalization.en.graph_utils import DAMO_SIGMA, GraphFst
@ -15,18 +14,24 @@ class DecimalFst(GraphFst):
tn_decimal_tagger: TN decimal tagger
"""
def __init__(self, itn_cardinal_tagger: GraphFst, tn_decimal_tagger: GraphFst, deterministic: bool = True):
def __init__(
self, itn_cardinal_tagger: GraphFst, tn_decimal_tagger: GraphFst, deterministic: bool = True
):
super().__init__(name="decimal", kind="classify", deterministic=deterministic)
self.graph = tn_decimal_tagger.graph.invert().optimize()
delete_point = pynutil.delete(" komma")
allow_spelling = pynini.cdrewrite(pynini.cross("eine ", "eins ") + quantities, "[BOS]", "[EOS]", DAMO_SIGMA)
allow_spelling = pynini.cdrewrite(
pynini.cross("eine ", "eins ") + quantities, "[BOS]", "[EOS]", DAMO_SIGMA
)
graph_fractional = pynutil.insert("fractional_part: \"") + self.graph + pynutil.insert("\"")
graph_fractional = pynutil.insert('fractional_part: "') + self.graph + pynutil.insert('"')
graph_integer = (
pynutil.insert("integer_part: \"") + itn_cardinal_tagger.graph_no_exception + pynutil.insert("\"")
pynutil.insert('integer_part: "')
+ itn_cardinal_tagger.graph_no_exception
+ pynutil.insert('"')
)
final_graph_wo_sign = graph_integer + delete_point + pynini.accep(" ") + graph_fractional
@ -35,7 +40,8 @@ class DecimalFst(GraphFst):
@ (
final_graph_wo_sign
| get_quantity(
final_graph_wo_sign, itn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit
final_graph_wo_sign,
itn_cardinal_tagger.graph_hundred_component_at_least_one_none_zero_digit,
)
).optimize()
)

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
from pynini.lib import pynutil
@ -8,18 +7,23 @@ class ElectronicFst(GraphFst):
"""
Finite state transducer for classifying electronic: email addresses, etc.
e.g. c d f eins at a b c punkt e d u -> tokens { name: "cdf1.abc.edu" }
Args:
        tn_electronic_tagger: TN electronic tagger
        tn_electronic_verbalizer: TN electronic verbalizer
"""
def __init__(self, tn_electronic_tagger: GraphFst, tn_electronic_verbalizer: GraphFst, deterministic: bool = True):
def __init__(
self,
tn_electronic_tagger: GraphFst,
tn_electronic_verbalizer: GraphFst,
deterministic: bool = True,
):
super().__init__(name="electronic", kind="classify", deterministic=deterministic)
tagger = pynini.invert(tn_electronic_verbalizer.graph).optimize()
verbalizer = pynini.invert(tn_electronic_tagger.graph).optimize()
final_graph = tagger @ verbalizer
graph = pynutil.insert("name: \"") + final_graph + pynutil.insert("\"")
graph = pynutil.insert('name: "') + final_graph + pynutil.insert('"')
self.fst = graph.optimize()

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
@ -15,28 +14,41 @@ class FractionFst(GraphFst):
e.g. ein halb -> tokens { name: "1/2" }
e.g. ein ein halb -> tokens { name: "1 1/2" }
e.g. drei zwei ein hundertstel -> tokens { name: "3 2/100" }
Args:
itn_cardinal_tagger: ITN cardinal tagger
tn_fraction_verbalizer: TN fraction verbalizer
"""
def __init__(self, itn_cardinal_tagger: GraphFst, tn_fraction_verbalizer: GraphFst, deterministic: bool = True):
def __init__(
self,
itn_cardinal_tagger: GraphFst,
tn_fraction_verbalizer: GraphFst,
deterministic: bool = True,
):
super().__init__(name="fraction", kind="classify", deterministic=deterministic)
tagger = tn_fraction_verbalizer.graph.invert().optimize()
delete_optional_sign = pynini.closure(pynutil.delete("negative: ") + pynini.cross("\"true\" ", "-"), 0, 1)
delete_optional_sign = pynini.closure(
pynutil.delete("negative: ") + pynini.cross('"true" ', "-"), 0, 1
)
delete_integer_marker = (
pynutil.delete("integer_part: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
pynutil.delete('integer_part: "')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
) @ itn_cardinal_tagger.graph_no_exception
delete_numerator_marker = (
pynutil.delete("numerator: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
pynutil.delete('numerator: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
) @ itn_cardinal_tagger.graph_no_exception
delete_denominator_marker = (
pynutil.insert('/')
+ (pynutil.delete("denominator: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\""))
pynutil.insert("/")
+ (
pynutil.delete('denominator: "')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
)
@ itn_cardinal_tagger.graph_no_exception
)
@ -50,5 +62,5 @@ class FractionFst(GraphFst):
self.graph = tagger @ verbalizer
graph = pynutil.insert("name: \"") + convert_space(self.graph) + pynutil.insert("\"")
graph = pynutil.insert('name: "') + convert_space(self.graph) + pynutil.insert('"')
self.fst = graph.optimize()

View File

@ -1,6 +1,8 @@
import pynini
from fun_text_processing.text_normalization.de.taggers.measure import singular_to_plural, unit_singular
from fun_text_processing.text_normalization.de.taggers.measure import (
singular_to_plural,
unit_singular,
)
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_SIGMA,
GraphFst,
@ -35,25 +37,31 @@ class MeasureFst(GraphFst):
super().__init__(name="measure", kind="classify", deterministic=deterministic)
cardinal_graph = (
pynini.cdrewrite(pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA)
pynini.cdrewrite(
pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA
)
@ itn_cardinal_tagger.graph_no_exception
)
graph_unit_singular = pynini.invert(unit_singular) # singular -> abbr
unit = (pynini.invert(singular_to_plural()) @ graph_unit_singular) | graph_unit_singular # plural -> abbr
unit = (
pynini.invert(singular_to_plural()) @ graph_unit_singular
) | graph_unit_singular # plural -> abbr
unit = convert_space(unit)
graph_unit_singular = convert_space(graph_unit_singular)
optional_graph_negative = pynini.closure(
pynutil.insert("negative: ") + pynini.cross("minus", "\"true\"") + delete_extra_space, 0, 1
pynutil.insert("negative: ") + pynini.cross("minus", '"true"') + delete_extra_space,
0,
1,
)
unit_misc = pynutil.insert("/") + pynutil.delete("pro") + delete_space + graph_unit_singular
unit = (
pynutil.insert("units: \"")
pynutil.insert('units: "')
+ (unit | unit_misc | pynutil.add_weight(unit + delete_space + unit_misc, 0.01))
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
subgraph_decimal = (
@ -68,9 +76,9 @@ class MeasureFst(GraphFst):
subgraph_fraction = (
pynutil.insert("decimal { ")
+ optional_graph_negative
+ pynutil.insert("integer_part: \"")
+ pynutil.insert('integer_part: "')
+ itn_fraction_tagger.graph
+ pynutil.insert("\" }")
+ pynutil.insert('" }')
+ delete_extra_space
+ unit
)
@ -78,9 +86,9 @@ class MeasureFst(GraphFst):
subgraph_cardinal = (
pynutil.insert("cardinal { ")
+ optional_graph_negative
+ pynutil.insert("integer: \"")
+ pynutil.insert('integer: "')
+ cardinal_graph
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ pynutil.insert(" }")
+ delete_extra_space
+ unit

View File

@ -1,6 +1,9 @@
import pynini
from fun_text_processing.text_normalization.de.taggers.money import maj_singular, min_plural, min_singular
from fun_text_processing.text_normalization.de.taggers.money import (
maj_singular,
min_plural,
min_singular,
)
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_DIGIT,
DAMO_SIGMA,
@ -23,26 +26,35 @@ class MoneyFst(GraphFst):
itn_decimal_tagger: ITN Decimal Tagger
"""
def __init__(self, itn_cardinal_tagger: GraphFst, itn_decimal_tagger: GraphFst, deterministic: bool = True):
def __init__(
self,
itn_cardinal_tagger: GraphFst,
itn_decimal_tagger: GraphFst,
deterministic: bool = True,
):
super().__init__(name="money", kind="classify", deterministic=deterministic)
cardinal_graph = (
pynini.cdrewrite(pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA)
pynini.cdrewrite(
pynini.cross(pynini.union("ein", "eine"), "eins"), "[BOS]", "[EOS]", DAMO_SIGMA
)
@ itn_cardinal_tagger.graph_no_exception
)
graph_decimal_final = itn_decimal_tagger.final_graph_wo_negative
graph_unit = pynini.invert(maj_singular)
graph_unit = pynutil.insert("currency: \"") + convert_space(graph_unit) + pynutil.insert("\"")
graph_unit = pynutil.insert('currency: "') + convert_space(graph_unit) + pynutil.insert('"')
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
pynutil.insert("0") + DAMO_DIGIT
)
min_unit = pynini.project(min_singular | min_plural, "output")
# elf euro (und) vier cent, vier cent
cents_standalone = (
pynutil.insert("fractional_part: \"")
pynutil.insert('fractional_part: "')
+ cardinal_graph @ add_leading_zero_to_double_digit
+ delete_space
+ pynutil.delete(min_unit)
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
optional_cents_standalone = pynini.closure(
@ -56,23 +68,23 @@ class MoneyFst(GraphFst):
# elf euro vierzig, only after integer
optional_cents_suffix = pynini.closure(
delete_extra_space
+ pynutil.insert("fractional_part: \"")
+ pynutil.insert('fractional_part: "')
+ pynutil.add_weight(cardinal_graph @ add_leading_zero_to_double_digit, -0.7)
+ pynutil.insert("\""),
+ pynutil.insert('"'),
0,
1,
)
graph_integer = (
pynutil.insert("integer_part: \"")
pynutil.insert('integer_part: "')
+ cardinal_graph
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_extra_space
+ graph_unit
+ (optional_cents_standalone | optional_cents_suffix)
)
graph_decimal = graph_decimal_final + delete_extra_space + graph_unit
graph_decimal |= pynutil.insert("currency: \"\" integer_part: \"0\" ") + cents_standalone
graph_decimal |= pynutil.insert('currency: "" integer_part: "0" ') + cents_standalone
final_graph = graph_integer | graph_decimal
final_graph = self.add_tokens(final_graph)
self.fst = final_graph.optimize()

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil
@ -14,16 +13,21 @@ class OrdinalFst(GraphFst):
tn_ordinal_verbalizer: TN Ordinal Verbalizer
"""
def __init__(self, itn_cardinal_tagger: GraphFst, tn_ordinal_verbalizer: GraphFst, deterministic: bool = True):
def __init__(
self,
itn_cardinal_tagger: GraphFst,
tn_ordinal_verbalizer: GraphFst,
deterministic: bool = True,
):
super().__init__(name="ordinal", kind="classify", deterministic=deterministic)
tagger = tn_ordinal_verbalizer.graph.invert().optimize()
graph = (
pynutil.delete("integer: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
pynutil.delete('integer: "') + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete('"')
) @ itn_cardinal_tagger.graph
final_graph = tagger @ graph + pynutil.insert(".")
graph = pynutil.insert("name: \"") + final_graph + pynutil.insert("\"")
graph = pynutil.insert('name: "') + final_graph + pynutil.insert('"')
self.fst = graph.optimize()

View File

@ -1,14 +1,17 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space, insert_space
from fun_text_processing.text_normalization.en.graph_utils import (
GraphFst,
convert_space,
insert_space,
)
from pynini.lib import pynutil
class TelephoneFst(GraphFst):
"""
    Finite state transducer for classifying telephone numbers, e.g.
null vier eins eins eins zwei drei vier eins zwei drei vier -> tokens { name: "(0411) 1234-1234" }
Args:
tn_cardinal_tagger: TN Cardinal Tagger
"""
@ -35,6 +38,6 @@ class TelephoneFst(GraphFst):
+ digit
)
graph = convert_space(pynini.invert(number_part))
final_graph = pynutil.insert("name: \"") + graph + pynutil.insert("\"")
final_graph = pynutil.insert('name: "') + graph + pynutil.insert('"')
self.fst = final_graph.optimize()

View File

@ -1,5 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_SIGMA, GraphFst
from pynini.lib import pynutil
@ -17,7 +15,7 @@ class TimeFst(GraphFst):
e.g. drei vor zwölf -> time { minutes: "57" hours: "11" }
e.g. drei nach zwölf -> time { minutes: "3" hours: "12" }
    e.g. drei uhr zehn minuten zehn sekunden -> time { hours: "3" minutes: "10" seconds: "10" }
Args:
tn_time_verbalizer: TN time verbalizer
"""

View File

@ -1,4 +1,3 @@
import os
import pynini
@ -15,15 +14,27 @@ from fun_text_processing.inverse_text_normalization.de.taggers.time import TimeF
from fun_text_processing.inverse_text_normalization.de.taggers.whitelist import WhiteListFst
from fun_text_processing.inverse_text_normalization.en.taggers.punctuation import PunctuationFst
from fun_text_processing.inverse_text_normalization.en.taggers.word import WordFst
from fun_text_processing.text_normalization.de.taggers.cardinal import CardinalFst as TNCardinalTagger
from fun_text_processing.text_normalization.de.taggers.cardinal import (
CardinalFst as TNCardinalTagger,
)
from fun_text_processing.text_normalization.de.taggers.date import DateFst as TNDateTagger
from fun_text_processing.text_normalization.de.taggers.decimal import DecimalFst as TNDecimalTagger
from fun_text_processing.text_normalization.de.taggers.electronic import ElectronicFst as TNElectronicTagger
from fun_text_processing.text_normalization.de.taggers.whitelist import WhiteListFst as TNWhitelistTagger
from fun_text_processing.text_normalization.de.taggers.electronic import (
ElectronicFst as TNElectronicTagger,
)
from fun_text_processing.text_normalization.de.taggers.whitelist import (
WhiteListFst as TNWhitelistTagger,
)
from fun_text_processing.text_normalization.de.verbalizers.date import DateFst as TNDateVerbalizer
from fun_text_processing.text_normalization.de.verbalizers.electronic import ElectronicFst as TNElectronicVerbalizer
from fun_text_processing.text_normalization.de.verbalizers.fraction import FractionFst as TNFractionVerbalizer
from fun_text_processing.text_normalization.de.verbalizers.ordinal import OrdinalFst as TNOrdinalVerbalizer
from fun_text_processing.text_normalization.de.verbalizers.electronic import (
ElectronicFst as TNElectronicVerbalizer,
)
from fun_text_processing.text_normalization.de.verbalizers.fraction import (
FractionFst as TNFractionVerbalizer,
)
from fun_text_processing.text_normalization.de.verbalizers.ordinal import (
OrdinalFst as TNOrdinalVerbalizer,
)
from fun_text_processing.text_normalization.de.verbalizers.time import TimeFst as TNTimeVerbalizer
from fun_text_processing.text_normalization.en.graph_utils import (
GraphFst,
@ -39,7 +50,7 @@ import logging
class ClassifyFst(GraphFst):
"""
    Final class that composes all other classification grammars. This class can process an entire sentence that is lower-cased.
    For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
    More details on deployment at NeMo/tools/text_processing_deployment.
Args:
@ -47,11 +58,13 @@ class ClassifyFst(GraphFst):
overwrite_cache: set to True to overwrite .far files
"""
def __init__(self, cache_dir: str = None, overwrite_cache: bool = False, deterministic: bool = True):
def __init__(
self, cache_dir: str = None, overwrite_cache: bool = False, deterministic: bool = True
):
super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)
far_file = None
if cache_dir is not None and cache_dir != 'None':
if cache_dir is not None and cache_dir != "None":
os.makedirs(cache_dir, exist_ok=True)
far_file = os.path.join(cache_dir, "_de_itn.far")
if not overwrite_cache and far_file and os.path.exists(far_file):
@ -63,9 +76,15 @@ class ClassifyFst(GraphFst):
tn_date_tagger = TNDateTagger(cardinal=tn_cardinal_tagger, deterministic=False)
tn_decimal_tagger = TNDecimalTagger(cardinal=tn_cardinal_tagger, deterministic=False)
tn_ordinal_verbalizer = TNOrdinalVerbalizer(deterministic=False)
tn_fraction_verbalizer = TNFractionVerbalizer(ordinal=tn_ordinal_verbalizer, deterministic=False)
tn_time_verbalizer = TNTimeVerbalizer(cardinal_tagger=tn_cardinal_tagger, deterministic=False)
tn_date_verbalizer = TNDateVerbalizer(ordinal=tn_ordinal_verbalizer, deterministic=False)
tn_fraction_verbalizer = TNFractionVerbalizer(
ordinal=tn_ordinal_verbalizer, deterministic=False
)
tn_time_verbalizer = TNTimeVerbalizer(
cardinal_tagger=tn_cardinal_tagger, deterministic=False
)
tn_date_verbalizer = TNDateVerbalizer(
ordinal=tn_ordinal_verbalizer, deterministic=False
)
tn_electronic_tagger = TNElectronicTagger(deterministic=False)
tn_electronic_verbalizer = TNElectronicVerbalizer(deterministic=False)
tn_whitelist_tagger = TNWhitelistTagger(input_case="cased", deterministic=False)
@ -73,19 +92,27 @@ class ClassifyFst(GraphFst):
cardinal = CardinalFst(tn_cardinal_tagger=tn_cardinal_tagger)
cardinal_graph = cardinal.fst
ordinal = OrdinalFst(itn_cardinal_tagger=cardinal, tn_ordinal_verbalizer=tn_ordinal_verbalizer)
ordinal = OrdinalFst(
itn_cardinal_tagger=cardinal, tn_ordinal_verbalizer=tn_ordinal_verbalizer
)
ordinal_graph = ordinal.fst
decimal = DecimalFst(itn_cardinal_tagger=cardinal, tn_decimal_tagger=tn_decimal_tagger)
decimal_graph = decimal.fst
fraction = FractionFst(itn_cardinal_tagger=cardinal, tn_fraction_verbalizer=tn_fraction_verbalizer)
fraction = FractionFst(
itn_cardinal_tagger=cardinal, tn_fraction_verbalizer=tn_fraction_verbalizer
)
fraction_graph = fraction.fst
measure_graph = MeasureFst(
itn_cardinal_tagger=cardinal, itn_decimal_tagger=decimal, itn_fraction_tagger=fraction
itn_cardinal_tagger=cardinal,
itn_decimal_tagger=decimal,
itn_fraction_tagger=fraction,
).fst
date_graph = DateFst(
itn_cardinal_tagger=cardinal, tn_date_verbalizer=tn_date_verbalizer, tn_date_tagger=tn_date_tagger
itn_cardinal_tagger=cardinal,
tn_date_verbalizer=tn_date_verbalizer,
tn_date_tagger=tn_date_tagger,
).fst
word_graph = WordFst().fst
time_graph = TimeFst(tn_time_verbalizer=tn_time_verbalizer).fst
@ -93,7 +120,8 @@ class ClassifyFst(GraphFst):
whitelist_graph = WhiteListFst(tn_whitelist_tagger=tn_whitelist_tagger).fst
punct_graph = PunctuationFst().fst
electronic_graph = ElectronicFst(
tn_electronic_tagger=tn_electronic_tagger, tn_electronic_verbalizer=tn_electronic_verbalizer
tn_electronic_tagger=tn_electronic_tagger,
tn_electronic_verbalizer=tn_electronic_verbalizer,
).fst
telephone_graph = TelephoneFst(tn_cardinal_tagger=tn_cardinal_tagger).fst
@ -112,10 +140,16 @@ class ClassifyFst(GraphFst):
| pynutil.add_weight(word_graph, 100)
)
punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=1.1) + pynutil.insert(" }")
punct = (
pynutil.insert("tokens { ")
+ pynutil.add_weight(punct_graph, weight=1.1)
+ pynutil.insert(" }")
)
token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
token_plus_punct = (
pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
pynini.closure(punct + pynutil.insert(" "))
+ token
+ pynini.closure(pynutil.insert(" ") + punct)
)
graph = token_plus_punct + pynini.closure(delete_extra_space + token_plus_punct)

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space
from pynini.lib import pynutil
@ -16,5 +15,5 @@ class WhiteListFst(GraphFst):
super().__init__(name="whitelist", kind="classify", deterministic=deterministic)
whitelist = pynini.invert(tn_whitelist_tagger.graph)
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
graph = pynutil.insert('name: "') + convert_space(whitelist) + pynutil.insert('"')
self.fst = graph.optimize()

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil
@ -16,7 +15,9 @@ class CardinalFst(GraphFst):
def __init__(self, tn_cardinal_verbalizer: GraphFst, deterministic: bool = True):
super().__init__(name="cardinal", kind="verbalize", deterministic=deterministic)
self.numbers = tn_cardinal_verbalizer.numbers
optional_sign = pynini.closure(pynutil.delete("negative: \"") + DAMO_NOT_QUOTE + pynutil.delete("\" "), 0, 1)
optional_sign = pynini.closure(
pynutil.delete('negative: "') + DAMO_NOT_QUOTE + pynutil.delete('" '), 0, 1
)
graph = optional_sign + self.numbers
delete_tokens = self.delete_tokens(graph)
self.fst = delete_tokens.optimize()

View File

@ -1,6 +1,9 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_preserve_order
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_preserve_order,
)
from pynini.lib import pynutil
@ -17,13 +20,17 @@ class DecimalFst(GraphFst):
super().__init__(name="decimal", kind="verbalize", deterministic=deterministic)
delete_space = pynutil.delete(" ")
optional_sign = pynini.closure(
pynutil.delete("negative: \"") + DAMO_NOT_QUOTE + pynutil.delete("\"") + delete_space, 0, 1
pynutil.delete('negative: "') + DAMO_NOT_QUOTE + pynutil.delete('"') + delete_space,
0,
1,
)
optional_integer = pynini.closure(tn_decimal_verbalizer.integer, 0, 1)
optional_fractional = pynini.closure(
delete_space + pynutil.insert(",") + tn_decimal_verbalizer.fractional_default, 0, 1
)
graph = (optional_integer + optional_fractional + tn_decimal_verbalizer.optional_quantity).optimize()
graph = (
optional_integer + optional_fractional + tn_decimal_verbalizer.optional_quantity
).optimize()
self.numbers = optional_sign + graph
graph = self.numbers + delete_preserve_order
delete_tokens = self.delete_tokens(graph)

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
from pynini.lib import pynutil
@ -18,13 +17,13 @@ class MeasureFst(GraphFst):
def __init__(self, decimal: GraphFst, cardinal: GraphFst, deterministic: bool = True):
super().__init__(name="measure", kind="verbalize", deterministic=deterministic)
optional_sign = pynini.closure(pynini.cross("negative: \"true\"", "-"), 0, 1)
optional_sign = pynini.closure(pynini.cross('negative: "true"', "-"), 0, 1)
unit = (
pynutil.delete("units:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space
)
graph_decimal = (

View File

@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
from pynini.lib import pynutil
@ -18,9 +17,9 @@ class MoneyFst(GraphFst):
unit = (
pynutil.delete("currency:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
graph = unit + delete_space + decimal.numbers
delete_tokens = self.delete_tokens(graph)

View File

@ -1,6 +1,10 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_ALPHA, DAMO_DIGIT, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_ALPHA,
DAMO_DIGIT,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -9,26 +13,35 @@ class TimeFst(GraphFst):
Finite state transducer for verbalizing time, e.g.
time { hours: "8" minutes: "30" zone: "e s t" } -> 08:30 Uhr est
time { hours: "8" } -> 8 Uhr
    time { hours: "8" minutes: "30" seconds: "10" } -> 08:30:10 Uhr
"""
def __init__(self, deterministic: bool = True):
super().__init__(name="time", kind="verbalize", deterministic=deterministic)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
hour = pynutil.delete("hours: \"") + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete("\"")
minute = pynutil.delete("minutes: \"") + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete("\"")
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
pynutil.insert("0") + DAMO_DIGIT
)
hour = pynutil.delete('hours: "') + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete('"')
minute = pynutil.delete('minutes: "') + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete('"')
second = pynutil.delete("seconds: \"") + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete("\"")
second = pynutil.delete('seconds: "') + pynini.closure(DAMO_DIGIT, 1) + pynutil.delete('"')
zone = (
pynutil.delete("zone: \"") + pynini.closure(DAMO_ALPHA + delete_space) + DAMO_ALPHA + pynutil.delete("\"")
pynutil.delete('zone: "')
+ pynini.closure(DAMO_ALPHA + delete_space)
+ DAMO_ALPHA
+ pynutil.delete('"')
)
optional_zone = pynini.closure(pynini.accep(" ") + zone, 0, 1)
graph = (
delete_space
+ pynutil.insert(":")
+ (minute @ add_leading_zero_to_double_digit)
+ pynini.closure(delete_space + pynutil.insert(":") + (second @ add_leading_zero_to_double_digit), 0, 1)
+ pynini.closure(
delete_space + pynutil.insert(":") + (second @ add_leading_zero_to_double_digit),
0,
1,
)
+ pynutil.insert(" Uhr")
+ optional_zone
)

View File

@ -1,18 +1,21 @@
from fun_text_processing.inverse_text_normalization.de.verbalizers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.de.verbalizers.decimal import DecimalFst
from fun_text_processing.inverse_text_normalization.de.verbalizers.measure import MeasureFst
from fun_text_processing.inverse_text_normalization.de.verbalizers.money import MoneyFst
from fun_text_processing.inverse_text_normalization.de.verbalizers.time import TimeFst
from fun_text_processing.text_normalization.de.verbalizers.cardinal import CardinalFst as TNCardinalVerbalizer
from fun_text_processing.text_normalization.de.verbalizers.decimal import DecimalFst as TNDecimalVerbalizer
from fun_text_processing.text_normalization.de.verbalizers.cardinal import (
CardinalFst as TNCardinalVerbalizer,
)
from fun_text_processing.text_normalization.de.verbalizers.decimal import (
DecimalFst as TNDecimalVerbalizer,
)
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
class VerbalizeFst(GraphFst):
"""
Composes other verbalizer grammars.
    For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
    More details on deployment at NeMo/tools/text_processing_deployment.
"""

View File

@ -1,14 +1,17 @@
import pynini
from fun_text_processing.inverse_text_normalization.de.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.word import WordFst
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, delete_extra_space, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
GraphFst,
delete_extra_space,
delete_space,
)
from pynini.lib import pynutil
class VerbalizeFinalFst(GraphFst):
"""
    Finite state transducer that verbalizes an entire sentence, e.g.
tokens { name: "jetzt" } tokens { name: "ist" } tokens { time { hours: "12" minutes: "30" } } -> jetzt ist 12:30 Uhr
"""

View File

@ -1,5 +1,7 @@
from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import (
ClassifyFst,
)
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import (
VerbalizeFinalFst,
)

View File

@ -1,5 +1,3 @@
from argparse import ArgumentParser
from typing import List
@ -55,7 +53,7 @@ class Filter:
        Processes the given instance with the process function.
        Returns: the processed instance if it belongs to the expected class type, otherwise the original instance
"""
if instance.token_type != self.class_type:
@ -73,7 +71,9 @@ def process_cardinal_1(instance: Instance) -> Instance:
normalized = instance.normalized
un_normalized = re.sub(r"[^0-9]", "", un_normalized)
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_ordinal_1(instance: Instance) -> bool:
@ -86,7 +86,9 @@ def process_ordinal_1(instance: Instance) -> Instance:
normalized = instance.normalized
un_normalized = re.sub(r"[,\s]", "", un_normalized)
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_decimal_1(instance: Instance) -> bool:
@ -99,7 +101,9 @@ def process_decimal_1(instance: Instance) -> Instance:
un_normalized = re.sub(r",", "", un_normalized)
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_measure_1(instance: Instance) -> bool:
@ -116,7 +120,9 @@ def process_measure_1(instance: Instance) -> Instance:
normalized = re.sub(r"[^a-z\s]", "", normalized)
normalized = re.sub(r"per ([a-z\s]*)s$", r"per \1", normalized)
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_money_1(instance: Instance) -> bool:
@ -133,7 +139,9 @@ def process_money_1(instance: Instance) -> Instance:
un_normalized = re.sub(r"(\d)m\s*$", r"\1 million", un_normalized)
un_normalized = re.sub(r"(\d)bn?\s*$", r"\1 billion", un_normalized)
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_time_1(instance: Instance) -> bool:
@ -148,7 +156,9 @@ def process_time_1(instance: Instance) -> Instance:
un_normalized = re.sub(r"(\d)\s?p\s?m\s?", r"\1 p.m.", un_normalized)
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_plain_1(instance: Instance) -> bool:
@ -159,7 +169,9 @@ def filter_plain_1(instance: Instance) -> bool:
def process_plain_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_punct_1(instance: Instance) -> bool:
@ -170,7 +182,9 @@ def filter_punct_1(instance: Instance) -> bool:
def process_punct_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_date_1(instance: Instance) -> bool:
@ -183,7 +197,9 @@ def process_date_1(instance: Instance) -> Instance:
un_normalized = re.sub(r",", "", un_normalized)
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_letters_1(instance: Instance) -> bool:
@ -195,7 +211,9 @@ def process_letters_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_verbatim_1(instance: Instance) -> bool:
@ -206,7 +224,9 @@ def filter_verbatim_1(instance: Instance) -> bool:
def process_verbatim_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_digit_1(instance: Instance) -> bool:
@ -218,7 +238,9 @@ def process_digit_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_telephone_1(instance: Instance) -> bool:
@ -230,7 +252,9 @@ def process_telephone_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_electronic_1(instance: Instance) -> bool:
@ -242,7 +266,9 @@ def process_electronic_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_fraction_1(instance: Instance) -> bool:
@ -254,7 +280,9 @@ def process_fraction_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
def filter_address_1(instance: Instance) -> bool:
@ -266,27 +294,51 @@ def process_address_1(instance: Instance) -> Instance:
un_normalized = instance.un_normalized
normalized = instance.normalized
normalized = re.sub(r"[^a-z ]", "", normalized)
return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)
return Instance(
token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized
)
filters = []
filters.append(Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1))
filters.append(Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1))
filters.append(Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1))
filters.append(Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1))
filters.append(
Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1)
)
filters.append(
Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1)
)
filters.append(
Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1)
)
filters.append(
Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1)
)
filters.append(Filter(class_type="MONEY", process_func=process_money_1, filter_func=filter_money_1))
filters.append(Filter(class_type="TIME", process_func=process_time_1, filter_func=filter_time_1))
filters.append(Filter(class_type="DATE", process_func=process_date_1, filter_func=filter_date_1))
filters.append(Filter(class_type="PLAIN", process_func=process_plain_1, filter_func=filter_plain_1))
filters.append(Filter(class_type="PUNCT", process_func=process_punct_1, filter_func=filter_punct_1))
filters.append(Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1))
filters.append(Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1))
filters.append(
Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1)
)
filters.append(
Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1)
)
filters.append(Filter(class_type="DIGIT", process_func=process_digit_1, filter_func=filter_digit_1))
filters.append(Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1))
filters.append(Filter(class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1))
filters.append(Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1))
filters.append(Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1))
filters.append(
Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1)
)
filters.append(
Filter(
class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1
)
)
filters.append(
Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1)
)
filters.append(
Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1)
)
filters.append(Filter(class_type=EOS_TYPE, process_func=lambda x: x, filter_func=lambda x: True))
@ -315,8 +367,10 @@ def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Inst
def parse_args():
parser = ArgumentParser()
parser.add_argument("--input", help="input file path", type=str, default='./en_with_types/output-00001-of-00100')
parser.add_argument("--verbose", help="print filtered instances", action='store_true')
parser.add_argument(
"--input", help="input file path", type=str, default="./en_with_types/output-00001-of-00100"
)
parser.add_argument("--verbose", help="print filtered instances", action="store_true")
return parser.parse_args()


@ -1,5 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path, num_to_word
from fun_text_processing.text_normalization.en.graph_utils import (
@ -17,7 +15,7 @@ class CardinalFst(GraphFst):
"""
Finite state transducer for classifying cardinals
e.g. minus twenty three -> cardinal { integer: "23" negative: "-" }
Numbers below thirteen are not converted.
"""
def __init__(self):
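A rough sketch of running this tagger on the docstring's example (the import path is assumed from the repository layout):

from pynini.lib import rewrite
from fun_text_processing.inverse_text_normalization.en.taggers.cardinal import CardinalFst

cardinal = CardinalFst()
# The grammar emits the optional sign first, then the quoted integer field.
print(rewrite.top_rewrite("minus twenty three", cardinal.fst))
# -> cardinal { negative: "-" integer: "23" }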
@ -29,7 +27,9 @@ class CardinalFst(GraphFst):
graph_hundred = pynini.cross("hundred", "")
graph_hundred_component = pynini.union(graph_digit + delete_space + graph_hundred, pynutil.insert("0"))
graph_hundred_component = pynini.union(
graph_digit + delete_space + graph_hundred, pynutil.insert("0")
)
graph_hundred_component += delete_space
graph_hundred_component += pynini.union(
graph_teen | pynutil.insert("00"),
@ -44,32 +44,46 @@ class CardinalFst(GraphFst):
)
graph_thousands = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("thousand"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("thousand"),
pynutil.insert("000", weight=0.1),
)
graph_million = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("million"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("million"),
pynutil.insert("000", weight=0.1),
)
graph_billion = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("billion"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("billion"),
pynutil.insert("000", weight=0.1),
)
graph_trillion = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("trillion"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("trillion"),
pynutil.insert("000", weight=0.1),
)
graph_quadrillion = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("quadrillion"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("quadrillion"),
pynutil.insert("000", weight=0.1),
)
graph_quintillion = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("quintillion"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("quintillion"),
pynutil.insert("000", weight=0.1),
)
graph_sextillion = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + pynutil.delete("sextillion"),
graph_hundred_component_at_least_one_none_zero_digit
+ delete_space
+ pynutil.delete("sextillion"),
pynutil.insert("000", weight=0.1),
)
@ -93,7 +107,10 @@ class CardinalFst(GraphFst):
)
graph = graph @ pynini.union(
pynutil.delete(pynini.closure("0")) + pynini.difference(DAMO_DIGIT, "0") + pynini.closure(DAMO_DIGIT), "0"
pynutil.delete(pynini.closure("0"))
+ pynini.difference(DAMO_DIGIT, "0")
+ pynini.closure(DAMO_DIGIT),
"0",
)
labels_exception = [num_to_word(x) for x in range(0, 13)]
@ -110,10 +127,12 @@ class CardinalFst(GraphFst):
self.graph = (pynini.project(graph, "input") - graph_exception.arcsort()) @ graph
optional_minus_graph = pynini.closure(
pynutil.insert("negative: ") + pynini.cross("minus", "\"-\"") + DAMO_SPACE, 0, 1
pynutil.insert("negative: ") + pynini.cross("minus", '"-"') + DAMO_SPACE, 0, 1
)
final_graph = optional_minus_graph + pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
final_graph = (
optional_minus_graph + pynutil.insert('integer: "') + self.graph + pynutil.insert('"')
)
final_graph = self.add_tokens(final_graph)
self.fst = final_graph.optimize()


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import (
@ -63,7 +62,9 @@ def _get_year_graph():
def _get_thousands_graph():
graph_ties = _get_ties_graph()
graph_hundred_component = (graph_digit + delete_space + pynutil.delete("hundred")) | pynutil.insert("0")
graph_hundred_component = (
graph_digit + delete_space + pynutil.delete("hundred")
) | pynutil.insert("0")
graph = (
graph_digit
+ delete_space
@ -90,7 +91,7 @@ def _get_year_graph():
class DateFst(GraphFst):
"""
Finite state transducer for classifying date,
e.g. january fifth twenty twelve -> date { month: "january" day: "5" year: "2012" preserve_order: true }
e.g. the fifth of january twenty twelve -> date { day: "5" month: "january" year: "2012" preserve_order: true }
e.g. twenty twenty -> date { year: "2020" preserve_order: true }
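A hedged sketch of tagging one of these examples; the constructor wiring is an assumption inferred from the ordinal graph this class uses:

from pynini.lib import rewrite
from fun_text_processing.inverse_text_normalization.en.taggers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.en.taggers.date import DateFst
from fun_text_processing.inverse_text_normalization.en.taggers.ordinal import OrdinalFst

date = DateFst(ordinal=OrdinalFst(cardinal=CardinalFst()))  # arguments assumed
print(rewrite.top_rewrite("the fifth of january twenty twelve", date.fst))
# -> date { day: "5" month: "january" year: "2012" preserve_order: true }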
@ -108,18 +109,26 @@ class DateFst(GraphFst):
year_graph = pynutil.add_weight(year_graph, YEAR_WEIGHT)
month_graph = _get_month_graph()
month_graph = pynutil.insert("month: \"") + month_graph + pynutil.insert("\"")
month_graph = pynutil.insert('month: "') + month_graph + pynutil.insert('"')
day_graph = pynutil.insert("day: \"") + pynutil.add_weight(ordinal_graph, -0.7) + pynutil.insert("\"")
day_graph = (
pynutil.insert('day: "') + pynutil.add_weight(ordinal_graph, -0.7) + pynutil.insert('"')
)
graph_year = (
delete_extra_space
+ pynutil.insert("year: \"")
+ pynutil.insert('year: "')
+ pynutil.add_weight(year_graph, -YEAR_WEIGHT)
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
optional_graph_year = pynini.closure(
graph_year,
0,
1,
)
optional_graph_year = pynini.closure(graph_year, 0, 1,)
graph_mdy = month_graph + (
(delete_extra_space + day_graph) | graph_year | (delete_extra_space + day_graph + graph_year)
(delete_extra_space + day_graph)
| graph_year
| (delete_extra_space + day_graph + graph_year)
)
graph_dmy = (
pynutil.delete("the")
@ -131,7 +140,9 @@ class DateFst(GraphFst):
+ month_graph
+ optional_graph_year
)
graph_year = pynutil.insert("year: \"") + (year_graph | _get_range_graph()) + pynutil.insert("\"")
graph_year = (
pynutil.insert('year: "') + (year_graph | _get_range_graph()) + pynutil.insert('"')
)
final_graph = graph_mdy | graph_dmy | graph_year
final_graph += pynutil.insert(" preserve_order: true")


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import (
@ -10,30 +9,42 @@ from fun_text_processing.text_normalization.en.graph_utils import (
from pynini.lib import pynutil
def get_quantity(decimal: 'pynini.FstLike', cardinal_up_to_hundred: 'pynini.FstLike') -> 'pynini.FstLike':
def get_quantity(
decimal: "pynini.FstLike", cardinal_up_to_hundred: "pynini.FstLike"
) -> "pynini.FstLike":
"""
Returns FST that transforms either a cardinal or decimal followed by a quantity into a numeral,
e.g. one million -> integer_part: "1" quantity: "million"
e.g. one point five million -> integer_part: "1" fractional_part: "5" quantity: "million"
Args:
decimal: decimal FST
cardinal_up_to_hundred: cardinal FST
"""
numbers = cardinal_up_to_hundred @ (
pynutil.delete(pynini.closure("0")) + pynini.difference(DAMO_DIGIT, "0") + pynini.closure(DAMO_DIGIT)
pynutil.delete(pynini.closure("0"))
+ pynini.difference(DAMO_DIGIT, "0")
+ pynini.closure(DAMO_DIGIT)
)
suffix = pynini.union(
"million", "billion", "trillion", "quadrillion", "quintillion", "sextillion"
)
suffix = pynini.union("million", "billion", "trillion", "quadrillion", "quintillion", "sextillion")
res = (
pynutil.insert("integer_part: \"")
pynutil.insert('integer_part: "')
+ numbers
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_extra_space
+ pynutil.insert("quantity: \"")
+ pynutil.insert('quantity: "')
+ suffix
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
res |= (
decimal
+ delete_extra_space
+ pynutil.insert('quantity: "')
+ (suffix | "thousand")
+ pynutil.insert('"')
)
res |= decimal + delete_extra_space + pynutil.insert("quantity: \"") + (suffix | "thousand") + pynutil.insert("\"")
return res
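A hedged sketch of get_quantity in use; decimal_graph and cardinal_up_to_hundred stand in for the graphs this module builds elsewhere:

from pynini.lib import rewrite

quantity = get_quantity(decimal_graph, cardinal_up_to_hundred)  # inputs assumed prebuilt
print(rewrite.top_rewrite("one point five million", quantity))
# -> integer_part: "1" fractional_part: "5" quantity: "million"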
@ -52,7 +63,9 @@ class DecimalFst(GraphFst):
cardinal_graph = cardinal.graph_no_exception
graph_decimal = pynini.string_file(get_abs_path("data/numbers/digit.tsv"))
graph_decimal |= pynini.string_file(get_abs_path("data/numbers/zero.tsv")) | pynini.cross("o", "0")
graph_decimal |= pynini.string_file(get_abs_path("data/numbers/zero.tsv")) | pynini.cross(
"o", "0"
)
graph_decimal = pynini.closure(graph_decimal + delete_space) + graph_decimal
self.graph = graph_decimal
@ -60,13 +73,20 @@ class DecimalFst(GraphFst):
point = pynutil.delete("point")
optional_graph_negative = pynini.closure(
pynutil.insert("negative: ") + pynini.cross("minus", "\"true\"") + delete_extra_space, 0, 1
pynutil.insert("negative: ") + pynini.cross("minus", '"true"') + delete_extra_space,
0,
1,
)
graph_fractional = pynutil.insert("fractional_part: \"") + graph_decimal + pynutil.insert("\"")
graph_integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")
graph_fractional = (
pynutil.insert('fractional_part: "') + graph_decimal + pynutil.insert('"')
)
graph_integer = pynutil.insert('integer_part: "') + cardinal_graph + pynutil.insert('"')
final_graph_wo_sign = (
pynini.closure(graph_integer + delete_extra_space, 0, 1) + point + delete_extra_space + graph_fractional
pynini.closure(graph_integer + delete_extra_space, 0, 1)
+ point
+ delete_extra_space
+ graph_fractional
)
final_graph = optional_graph_negative + final_graph_wo_sign


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import DAMO_ALPHA, GraphFst, insert_space
@ -25,35 +24,50 @@ class ElectronicFst(GraphFst):
accepted_username = alpha_num | symbols
process_dot = pynini.cross("dot", ".")
username = (alpha_num + pynini.closure(delete_extra_space + accepted_username)) | pynutil.add_weight(
pynini.closure(DAMO_ALPHA, 1), weight=0.0001
)
username = pynutil.insert("username: \"") + username + pynutil.insert("\"")
username = (
alpha_num + pynini.closure(delete_extra_space + accepted_username)
) | pynutil.add_weight(pynini.closure(DAMO_ALPHA, 1), weight=0.0001)
username = pynutil.insert('username: "') + username + pynutil.insert('"')
single_alphanum = pynini.closure(alpha_num + delete_extra_space) + alpha_num
server = single_alphanum | pynini.string_file(get_abs_path("data/electronic/server_name.tsv"))
server = single_alphanum | pynini.string_file(
get_abs_path("data/electronic/server_name.tsv")
)
domain = single_alphanum | pynini.string_file(get_abs_path("data/electronic/domain.tsv"))
domain_graph = (
pynutil.insert("domain: \"")
pynutil.insert('domain: "')
+ server
+ delete_extra_space
+ process_dot
+ delete_extra_space
+ domain
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
graph = (
username
+ delete_extra_space
+ pynutil.delete("at")
+ insert_space
+ delete_extra_space
+ domain_graph
)
graph = username + delete_extra_space + pynutil.delete("at") + insert_space + delete_extra_space + domain_graph
############# url ###
protocol_end = pynini.cross(pynini.union("w w w", "www"), "www")
protocol_start = (pynini.cross("h t t p", "http") | pynini.cross("h t t p s", "https")) + pynini.cross(
" colon slash slash ", "://"
)
protocol_start = (
pynini.cross("h t t p", "http") | pynini.cross("h t t p s", "https")
) + pynini.cross(" colon slash slash ", "://")
# .com,
ending = (
delete_extra_space
+ symbols
+ delete_extra_space
+ (domain | pynini.closure(accepted_username + delete_extra_space,) + accepted_username)
+ (
domain
| pynini.closure(
accepted_username + delete_extra_space,
)
+ accepted_username
)
)
protocol_default = (
@ -64,12 +78,18 @@ class ElectronicFst(GraphFst):
+ pynini.closure(ending, 1)
).optimize()
protocol = (
pynini.closure(protocol_start, 0, 1) + protocol_end + delete_extra_space + process_dot + protocol_default
pynini.closure(protocol_start, 0, 1)
+ protocol_end
+ delete_extra_space
+ process_dot
+ protocol_default
).optimize()
protocol |= pynini.closure(protocol_end + delete_extra_space + process_dot, 0, 1) + protocol_default
protocol |= (
pynini.closure(protocol_end + delete_extra_space + process_dot, 0, 1) + protocol_default
)
protocol = pynutil.insert("protocol: \"") + protocol.optimize() + pynutil.insert("\"")
protocol = pynutil.insert('protocol: "') + protocol.optimize() + pynutil.insert('"')
graph |= protocol
########


@ -1,4 +1,3 @@
from fun_text_processing.text_normalization.en.graph_utils import GraphFst


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import (
@ -32,22 +31,37 @@ class MeasureFst(GraphFst):
graph_unit_plural = get_singulars(graph_unit_singular) # plural -> abbr
optional_graph_negative = pynini.closure(
pynutil.insert("negative: ") + pynini.cross("minus", "\"true\"") + delete_extra_space, 0, 1
pynutil.insert("negative: ") + pynini.cross("minus", '"true"') + delete_extra_space,
0,
1,
)
unit_singular = convert_space(graph_unit_singular)
unit_plural = convert_space(graph_unit_plural)
unit_misc = pynutil.insert("/") + pynutil.delete("per") + delete_space + convert_space(graph_unit_singular)
unit_misc = (
pynutil.insert("/")
+ pynutil.delete("per")
+ delete_space
+ convert_space(graph_unit_singular)
)
unit_singular = (
pynutil.insert("units: \"")
+ (unit_singular | unit_misc | pynutil.add_weight(unit_singular + delete_space + unit_misc, 0.01))
+ pynutil.insert("\"")
pynutil.insert('units: "')
+ (
unit_singular
| unit_misc
| pynutil.add_weight(unit_singular + delete_space + unit_misc, 0.01)
)
+ pynutil.insert('"')
)
unit_plural = (
pynutil.insert("units: \"")
+ (unit_plural | unit_misc | pynutil.add_weight(unit_plural + delete_space + unit_misc, 0.01))
+ pynutil.insert("\"")
pynutil.insert('units: "')
+ (
unit_plural
| unit_misc
| pynutil.add_weight(unit_plural + delete_space + unit_misc, 0.01)
)
+ pynutil.insert('"')
)
subgraph_decimal = (
@ -61,9 +75,9 @@ class MeasureFst(GraphFst):
subgraph_cardinal = (
pynutil.insert("cardinal { ")
+ optional_graph_negative
+ pynutil.insert("integer: \"")
+ pynutil.insert('integer: "')
+ ((DAMO_SIGMA - "one") @ cardinal_graph)
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ pynutil.insert(" }")
+ delete_extra_space
+ unit_plural
@ -71,9 +85,9 @@ class MeasureFst(GraphFst):
subgraph_cardinal |= (
pynutil.insert("cardinal { ")
+ optional_graph_negative
+ pynutil.insert("integer: \"")
+ pynutil.insert('integer: "')
+ pynini.cross("one", "1")
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ pynutil.insert(" }")
+ delete_extra_space
+ unit_singular


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import (
@ -33,8 +32,11 @@ class MoneyFst(GraphFst):
# add support for missing hundred (only for 3 digit numbers)
# "one fifty" -> "one hundred fifty"
with_hundred = pynini.compose(
pynini.closure(DAMO_NOT_SPACE) + pynini.accep(" ") + pynutil.insert("hundred ") + DAMO_SIGMA,
pynini.compose(cardinal_graph, DAMO_DIGIT ** 3),
pynini.closure(DAMO_NOT_SPACE)
+ pynini.accep(" ")
+ pynutil.insert("hundred ")
+ DAMO_SIGMA,
pynini.compose(cardinal_graph, DAMO_DIGIT**3),
)
cardinal_graph |= with_hundred
graph_decimal_final = decimal.final_graph_wo_negative
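The two comments above describe a recovery trick for a dropped "hundred"; a quick sketch of its effect on the graph just built:

from pynini.lib import rewrite

# "one fifty" is first expanded to "one hundred fifty" by the insertion branch,
# then constrained to cardinal readings with exactly three digits.
print(rewrite.top_rewrite("one fifty", with_hundred))  # -> 150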
@ -43,20 +45,27 @@ class MoneyFst(GraphFst):
unit_singular = pynini.invert(unit)
unit_plural = get_singulars(unit_singular)
graph_unit_singular = pynutil.insert("currency: \"") + convert_space(unit_singular) + pynutil.insert("\"")
graph_unit_plural = pynutil.insert("currency: \"") + convert_space(unit_plural) + pynutil.insert("\"")
graph_unit_singular = (
pynutil.insert('currency: "') + convert_space(unit_singular) + pynutil.insert('"')
)
graph_unit_plural = (
pynutil.insert('currency: "') + convert_space(unit_plural) + pynutil.insert('"')
)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
pynutil.insert("0") + DAMO_DIGIT
)
# twelve dollars (and) fifty cents, zero cents
cents_standalone = (
pynutil.insert("fractional_part: \"")
pynutil.insert('fractional_part: "')
+ pynini.union(
pynutil.add_weight(((DAMO_SIGMA - "one") @ cardinal_graph), -0.7) @ add_leading_zero_to_double_digit
pynutil.add_weight(((DAMO_SIGMA - "one") @ cardinal_graph), -0.7)
@ add_leading_zero_to_double_digit
+ delete_space
+ (pynutil.delete("cents") | pynutil.delete("cent")),
pynini.cross("one", "01") + delete_space + pynutil.delete("cent"),
)
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
optional_cents_standalone = pynini.closure(
@ -70,31 +79,31 @@ class MoneyFst(GraphFst):
# twelve dollars fifty, only after integer
optional_cents_suffix = pynini.closure(
delete_extra_space
+ pynutil.insert("fractional_part: \"")
+ pynutil.insert('fractional_part: "')
+ pynutil.add_weight(cardinal_graph @ add_leading_zero_to_double_digit, -0.7)
+ pynutil.insert("\""),
+ pynutil.insert('"'),
0,
1,
)
graph_integer = (
pynutil.insert("integer_part: \"")
pynutil.insert('integer_part: "')
+ ((DAMO_SIGMA - "one") @ cardinal_graph)
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_extra_space
+ graph_unit_plural
+ (optional_cents_standalone | optional_cents_suffix)
)
graph_integer |= (
pynutil.insert("integer_part: \"")
pynutil.insert('integer_part: "')
+ pynini.cross("one", "1")
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_extra_space
+ graph_unit_singular
+ (optional_cents_standalone | optional_cents_suffix)
)
graph_decimal = graph_decimal_final + delete_extra_space + graph_unit_plural
graph_decimal |= pynutil.insert("currency: \"$\" integer_part: \"0\" ") + cents_standalone
graph_decimal |= pynutil.insert('currency: "$" integer_part: "0" ') + cents_standalone
final_graph = graph_integer | graph_decimal
final_graph = self.add_tokens(final_graph)


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst
@ -25,6 +24,6 @@ class OrdinalFst(GraphFst):
)
self.graph = graph @ cardinal_graph
final_graph = pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
final_graph = pynutil.insert('integer: "') + self.graph + pynutil.insert('"')
final_graph = self.add_tokens(final_graph)
self.fst = final_graph.optimize()


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
from pynini.lib import pynutil
@ -13,9 +12,9 @@ class PunctuationFst(GraphFst):
def __init__(self):
super().__init__(name="punctuation", kind="classify")
s = "!#$%&\'()*+,-./:;<=>?@^_`{|}~"
s = "!#$%&'()*+,-./:;<=>?@^_`{|}~"
punct = pynini.union(*s)
graph = pynutil.insert("name: \"") + punct + pynutil.insert("\"")
graph = pynutil.insert('name: "') + punct + pynutil.insert('"')
self.fst = graph.optimize()


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import (
@ -24,7 +23,7 @@ def get_serial_number(cardinal):
class TelephoneFst(GraphFst):
"""
Finite state transducer for classifying telephone numbers, e.g.
one two three one two three five six seven eight -> { number_part: "123-123-5678" }
This class also supports card numbers and IP formats.
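A rough usage sketch; the cardinal constructor argument is inferred from the code below and may not match the real signature:

from pynini.lib import rewrite
from fun_text_processing.inverse_text_normalization.en.taggers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.en.taggers.telephone import TelephoneFst

telephone = TelephoneFst(cardinal=CardinalFst())  # argument assumed
print(rewrite.top_rewrite("one two three one two three five six seven eight", telephone.fst))
# -> telephone { number_part: "123-123-5678" }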
@ -61,61 +60,89 @@ class TelephoneFst(GraphFst):
double_digit.invert()
# to handle cases like "one twenty three"
two_digit_cardinal = pynini.compose(cardinal.graph_no_exception, DAMO_DIGIT ** 2)
two_digit_cardinal = pynini.compose(cardinal.graph_no_exception, DAMO_DIGIT**2)
double_digit_to_digit = (
pynini.compose(double_digit, str_to_digit + pynutil.delete(" ") + str_to_digit) | two_digit_cardinal
pynini.compose(double_digit, str_to_digit + pynutil.delete(" ") + str_to_digit)
| two_digit_cardinal
)
single_or_double_digit = (pynutil.add_weight(double_digit_to_digit, -0.0001) | str_to_digit).optimize()
single_or_double_digit = (
pynutil.add_weight(double_digit_to_digit, -0.0001) | str_to_digit
).optimize()
single_or_double_digit |= (
single_or_double_digit
+ pynini.closure(pynutil.add_weight(pynutil.delete(" ") + single_or_double_digit, 0.0001))
+ pynini.closure(
pynutil.add_weight(pynutil.delete(" ") + single_or_double_digit, 0.0001)
)
).optimize()
number_part = pynini.compose(
single_or_double_digit,
DAMO_DIGIT ** 3 + pynutil.insert("-") + DAMO_DIGIT ** 3 + pynutil.insert("-") + DAMO_DIGIT ** 4,
DAMO_DIGIT**3
+ pynutil.insert("-")
+ DAMO_DIGIT**3
+ pynutil.insert("-")
+ DAMO_DIGIT**4,
).optimize()
number_part = pynutil.insert("number_part: \"") + number_part.optimize() + pynutil.insert("\"")
number_part = (
pynutil.insert('number_part: "') + number_part.optimize() + pynutil.insert('"')
)
cardinal_option = pynini.compose(single_or_double_digit, DAMO_DIGIT ** (2, 3))
country_code = (
pynutil.insert("country_code: \"")
pynutil.insert('country_code: "')
+ pynini.closure(pynini.cross("plus ", "+"), 0, 1)
+ ((pynini.closure(str_to_digit + pynutil.delete(" "), 0, 2) + str_to_digit) | cardinal_option)
+ pynutil.insert("\"")
+ (
(pynini.closure(str_to_digit + pynutil.delete(" "), 0, 2) + str_to_digit)
| cardinal_option
)
+ pynutil.insert('"')
)
optional_country_code = pynini.closure(country_code + pynutil.delete(" ") + insert_space, 0, 1).optimize()
optional_country_code = pynini.closure(
country_code + pynutil.delete(" ") + insert_space, 0, 1
).optimize()
graph = optional_country_code + number_part
# credit card number
space_four_digits = insert_space + DAMO_DIGIT ** 4
credit_card_graph = pynini.compose(single_or_double_digit, DAMO_DIGIT ** 4 + space_four_digits ** 3).optimize()
graph |= pynutil.insert("number_part: \"") + credit_card_graph.optimize() + pynutil.insert("\"")
space_four_digits = insert_space + DAMO_DIGIT**4
credit_card_graph = pynini.compose(
single_or_double_digit, DAMO_DIGIT**4 + space_four_digits**3
).optimize()
graph |= (
pynutil.insert('number_part: "') + credit_card_graph.optimize() + pynutil.insert('"')
)
# SSN
ssn_graph = pynini.compose(
single_or_double_digit,
DAMO_DIGIT ** 3 + pynutil.insert("-") + DAMO_DIGIT ** 2 + pynutil.insert("-") + DAMO_DIGIT ** 4,
DAMO_DIGIT**3
+ pynutil.insert("-")
+ DAMO_DIGIT**2
+ pynutil.insert("-")
+ DAMO_DIGIT**4,
).optimize()
graph |= pynutil.insert("number_part: \"") + ssn_graph.optimize() + pynutil.insert("\"")
graph |= pynutil.insert('number_part: "') + ssn_graph.optimize() + pynutil.insert('"')
# ip
digit_or_double = pynini.closure(str_to_digit + pynutil.delete(" "), 0, 1) + double_digit_to_digit
digit_or_double |= double_digit_to_digit + pynini.closure(pynutil.delete(" ") + str_to_digit, 0, 1)
digit_or_double = (
pynini.closure(str_to_digit + pynutil.delete(" "), 0, 1) + double_digit_to_digit
)
digit_or_double |= double_digit_to_digit + pynini.closure(
pynutil.delete(" ") + str_to_digit, 0, 1
)
digit_or_double |= str_to_digit + (pynutil.delete(" ") + str_to_digit) ** (0, 2)
digit_or_double |= cardinal_option
digit_or_double = digit_or_double.optimize()
ip_graph = digit_or_double + (pynini.cross(" dot ", ".") + digit_or_double) ** 3
graph |= pynutil.insert("number_part: \"") + ip_graph.optimize() + pynutil.insert("\"")
graph |= pynutil.insert('number_part: "') + ip_graph.optimize() + pynutil.insert('"')
graph |= (
pynutil.insert("number_part: \"")
pynutil.insert('number_part: "')
+ pynutil.add_weight(get_serial_number(cardinal=cardinal), weight=0.0001)
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
final_graph = self.add_tokens(graph)


@ -1,5 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.taggers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path, num_to_word
@ -47,21 +45,23 @@ class TimeFst(GraphFst):
graph_minute_verbose = pynini.cross("half", "30") | pynini.cross("quarter", "15")
oclock = pynini.cross(pynini.union("o' clock", "o clock", "o'clock", "oclock"), "")
final_graph_hour = pynutil.insert("hours: \"") + graph_hour + pynutil.insert("\"")
final_graph_hour = pynutil.insert('hours: "') + graph_hour + pynutil.insert('"')
graph_minute = (
oclock + pynutil.insert("00")
| pynutil.delete("o") + delete_space + graph_minute_single
| graph_minute_double
)
final_suffix = pynutil.insert("suffix: \"") + convert_space(suffix_graph) + pynutil.insert("\"")
final_suffix = (
pynutil.insert('suffix: "') + convert_space(suffix_graph) + pynutil.insert('"')
)
final_suffix = delete_space + insert_space + final_suffix
final_suffix_optional = pynini.closure(final_suffix, 0, 1)
final_time_zone_optional = pynini.closure(
delete_space
+ insert_space
+ pynutil.insert("zone: \"")
+ pynutil.insert('zone: "')
+ convert_space(time_zone_graph)
+ pynutil.insert("\""),
+ pynutil.insert('"'),
0,
1,
)
@ -70,13 +70,17 @@ class TimeFst(GraphFst):
# two o eight, two thirty five (am/pm)
# two pm/am
graph_hm = (
final_graph_hour + delete_extra_space + pynutil.insert("minutes: \"") + graph_minute + pynutil.insert("\"")
final_graph_hour
+ delete_extra_space
+ pynutil.insert('minutes: "')
+ graph_minute
+ pynutil.insert('"')
)
# 10 past four, quarter past four, half past four
graph_m_past_h = (
pynutil.insert("minutes: \"")
pynutil.insert('minutes: "')
+ pynini.union(graph_minute_single, graph_minute_double, graph_minute_verbose)
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_space
+ pynutil.delete("past")
+ delete_extra_space
@ -84,42 +88,48 @@ class TimeFst(GraphFst):
)
graph_quarter_time = (
pynutil.insert("minutes: \"")
pynutil.insert('minutes: "')
+ pynini.cross("quarter", "45")
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ delete_space
+ pynutil.delete(pynini.union("to", "till"))
+ delete_extra_space
+ pynutil.insert("hours: \"")
+ pynutil.insert('hours: "')
+ to_hour_graph
+ pynutil.insert("\"")
+ pynutil.insert('"')
)
graph_m_to_h_suffix_time = (
pynutil.insert("minutes: \"")
pynutil.insert('minutes: "')
+ ((graph_minute_single | graph_minute_double).optimize() @ minute_to_graph)
+ pynutil.insert("\"")
+ pynini.closure(delete_space + pynutil.delete(pynini.union("min", "mins", "minute", "minutes")), 0, 1)
+ pynutil.insert('"')
+ pynini.closure(
delete_space + pynutil.delete(pynini.union("min", "mins", "minute", "minutes")),
0,
1,
)
+ delete_space
+ pynutil.delete(pynini.union("to", "till"))
+ delete_extra_space
+ pynutil.insert("hours: \"")
+ pynutil.insert('hours: "')
+ to_hour_graph
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ final_suffix
)
graph_h = (
final_graph_hour
+ delete_extra_space
+ pynutil.insert("minutes: \"")
+ pynutil.insert('minutes: "')
+ (pynutil.insert("00") | graph_minute)
+ pynutil.insert("\"")
+ pynutil.insert('"')
+ final_suffix
+ final_time_zone_optional
)
final_graph = (
(graph_hm | graph_m_past_h | graph_quarter_time) + final_suffix_optional + final_time_zone_optional
(graph_hm | graph_m_past_h | graph_quarter_time)
+ final_suffix_optional
+ final_time_zone_optional
)
final_graph |= graph_h
final_graph |= graph_m_to_h_suffix_time


@ -1,4 +1,3 @@
import os
import pynini
@ -28,7 +27,7 @@ import logging
class ClassifyFst(GraphFst):
"""
Final class that composes all other classification grammars. This class can process an entire lower-cased sentence.
For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
More details on deployment can be found at NeMo/tools/text_processing_deployment.
Args:
@ -81,10 +80,16 @@ class ClassifyFst(GraphFst):
| pynutil.add_weight(word_graph, 100)
)
punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=1.1) + pynutil.insert(" }")
punct = (
pynutil.insert("tokens { ")
+ pynutil.add_weight(punct_graph, weight=1.1)
+ pynutil.insert(" }")
)
token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
token_plus_punct = (
pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
pynini.closure(punct + pynutil.insert(" "))
+ token
+ pynini.closure(pynutil.insert(" ") + punct)
)
graph = token_plus_punct + pynini.closure(delete_extra_space + token_plus_punct)
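Putting the two stages together, a hedged end-to-end sketch of inverse normalization with these grammars (constructor defaults assumed; the sample sentence comes from the VerbalizeFinalFst docstring further below):

from pynini.lib import rewrite
from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import (
    ClassifyFst,
)
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import (
    VerbalizeFinalFst,
)

tagger = ClassifyFst()
verbalizer = VerbalizeFinalFst()

text = "its twelve thirty now"
tagged = rewrite.top_rewrite(text, tagger.fst)         # spoken form -> token markup
written = rewrite.top_rewrite(tagged, verbalizer.fst)  # token markup -> written form
print(written)  # -> its 12:30 now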


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.utils import get_abs_path
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, convert_space
@ -16,5 +15,5 @@ class WhiteListFst(GraphFst):
super().__init__(name="whitelist", kind="classify")
whitelist = pynini.string_file(get_abs_path("data/whitelist.tsv")).invert()
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
graph = pynutil.insert('name: "') + convert_space(whitelist) + pynutil.insert('"')
self.fst = graph.optimize()


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_SPACE, GraphFst
from pynini.lib import pynutil
@ -12,5 +11,5 @@ class WordFst(GraphFst):
def __init__(self):
super().__init__(name="word", kind="classify")
word = pynutil.insert("name: \"") + pynini.closure(DAMO_NOT_SPACE, 1) + pynutil.insert("\"")
word = pynutil.insert('name: "') + pynini.closure(DAMO_NOT_SPACE, 1) + pynutil.insert('"')
self.fst = word.optimize()


@ -1,5 +1,3 @@
import os
from typing import Union
@ -15,7 +13,7 @@ def num_to_word(x: Union[str, int]):
Args:
x: integer
Returns: spoken representation
"""
if isinstance(x, int):
x = str(x)
@ -29,7 +27,7 @@ def get_abs_path(rel_path):
Args:
rel_path: relative path to this file
Returns absolute path
"""
return os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
return os.path.dirname(os.path.abspath(__file__)) + "/" + rel_path


@ -1,7 +1,9 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -16,9 +18,9 @@ class CardinalFst(GraphFst):
optional_sign = pynini.closure(
pynutil.delete("negative:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ DAMO_NOT_QUOTE
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space,
0,
1,
@ -26,9 +28,9 @@ class CardinalFst(GraphFst):
graph = (
pynutil.delete("integer:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
self.numbers = graph
graph = optional_sign + graph


@ -1,5 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
@ -22,43 +20,47 @@ class DateFst(GraphFst):
month = (
pynutil.delete("month:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
day = (
pynutil.delete("day:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
year = (
pynutil.delete("year:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
# month (day) year
graph_mdy = (
month + pynini.closure(delete_extra_space + day, 0, 1) + pynini.closure(delete_extra_space + year, 0, 1)
month
+ pynini.closure(delete_extra_space + day, 0, 1)
+ pynini.closure(delete_extra_space + year, 0, 1)
)
# (day) month year
graph_dmy = (
pynini.closure(day + delete_extra_space, 0, 1) + month + pynini.closure(delete_extra_space + year, 0, 1)
pynini.closure(day + delete_extra_space, 0, 1)
+ month
+ pynini.closure(delete_extra_space + year, 0, 1)
)
optional_preserve_order = pynini.closure(
pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space
| pynutil.delete("field_order:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ DAMO_NOT_QUOTE
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space
)


@ -1,7 +1,9 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -13,30 +15,30 @@ class DecimalFst(GraphFst):
def __init__(self):
super().__init__(name="decimal", kind="verbalize")
optionl_sign = pynini.closure(pynini.cross("negative: \"true\"", "-") + delete_space, 0, 1)
optionl_sign = pynini.closure(pynini.cross('negative: "true"', "-") + delete_space, 0, 1)
integer = (
pynutil.delete("integer_part:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_integer = pynini.closure(integer + delete_space, 0, 1)
fractional = (
pynutil.insert(".")
+ pynutil.delete("fractional_part:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_fractional = pynini.closure(fractional + delete_space, 0, 1)
quantity = (
pynutil.delete("quantity:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_quantity = pynini.closure(pynutil.insert(" ") + quantity + delete_space, 0, 1)
graph = optional_integer + optional_fractional + optional_quantity
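A hedged sketch of this verbalizer's effect on a tagged decimal; the token string is illustrative of what the tagger emits, and the braces are assumed to be stripped by the class's delete_tokens wrapper:

from pynini.lib import rewrite

# `decimal` stands for an instance of the DecimalFst verbalizer above.
tagged = 'decimal { negative: "true" integer_part: "1" fractional_part: "5" quantity: "million" }'
print(rewrite.top_rewrite(tagged, decimal.fst))  # -> -1.5 million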


@ -1,7 +1,9 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -16,24 +18,24 @@ class ElectronicFst(GraphFst):
user_name = (
pynutil.delete("username:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
domain = (
pynutil.delete("domain:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
protocol = (
pynutil.delete("protocol:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
graph = user_name + delete_space + pynutil.insert("@") + domain


@ -1,10 +1,9 @@
from fun_text_processing.text_normalization.en.graph_utils import GraphFst
class FractionFst(GraphFst):
"""
Finite state transducer for verbalizing fraction,
"""
def __init__(self):


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
from pynini.lib import pynutil
@ -16,13 +15,13 @@ class MeasureFst(GraphFst):
def __init__(self, decimal: GraphFst, cardinal: GraphFst):
super().__init__(name="measure", kind="verbalize")
optional_sign = pynini.closure(pynini.cross("negative: \"true\"", "-"), 0, 1)
optional_sign = pynini.closure(pynini.cross('negative: "true"', "-"), 0, 1)
unit = (
pynutil.delete("units:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ delete_space
)
graph_decimal = (


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, GraphFst, delete_space
from pynini.lib import pynutil
@ -18,9 +17,9 @@ class MoneyFst(GraphFst):
unit = (
pynutil.delete("currency:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
graph = unit + delete_space + decimal.numbers
delete_tokens = self.delete_tokens(graph)


@ -1,6 +1,10 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, DAMO_SIGMA, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_NOT_QUOTE,
DAMO_SIGMA,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -15,9 +19,9 @@ class OrdinalFst(GraphFst):
graph = (
pynutil.delete("integer:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
convert_eleven = pynini.cross("11", "11th")
convert_twelve = pynini.cross("12", "12th")


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil
@ -8,17 +7,21 @@ class TelephoneFst(GraphFst):
"""
Finite state transducer for verbalizing telephone, e.g.
telephone { number_part: "123-123-5678" }
-> 123-123-5678
"""
def __init__(self):
super().__init__(name="telephone", kind="verbalize")
number_part = pynutil.delete("number_part: \"") + pynini.closure(DAMO_NOT_QUOTE, 1) + pynutil.delete("\"")
optional_country_code = pynini.closure(
pynutil.delete("country_code: \"")
number_part = (
pynutil.delete('number_part: "')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_country_code = pynini.closure(
pynutil.delete('country_code: "')
+ pynini.closure(DAMO_NOT_QUOTE, 1)
+ pynutil.delete('"')
+ pynini.accep(" "),
0,
1,


@ -1,4 +1,3 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_CHAR,
@ -20,29 +19,31 @@ class TimeFst(GraphFst):
def __init__(self):
super().__init__(name="time", kind="verbalize")
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (pynutil.insert("0") + DAMO_DIGIT)
add_leading_zero_to_double_digit = (DAMO_DIGIT + DAMO_DIGIT) | (
pynutil.insert("0") + DAMO_DIGIT
)
hour = (
pynutil.delete("hours:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_DIGIT, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
minute = (
pynutil.delete("minutes:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_DIGIT, 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
suffix = (
delete_space
+ insert_space
+ pynutil.delete("suffix:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_suffix = pynini.closure(suffix, 0, 1)
zone = (
@ -50,9 +51,9 @@ class TimeFst(GraphFst):
+ insert_space
+ pynutil.delete("zone:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
optional_zone = pynini.closure(zone, 0, 1)
graph = (


@ -1,4 +1,3 @@
from fun_text_processing.inverse_text_normalization.en.verbalizers.cardinal import CardinalFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.date import DateFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.decimal import DecimalFst
@ -15,7 +14,7 @@ from fun_text_processing.text_normalization.en.graph_utils import GraphFst
class VerbalizeFst(GraphFst):
"""
Composes other verbalizer grammars.
For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
More details on deployment can be found at NeMo/tools/text_processing_deployment.
"""


@ -1,14 +1,17 @@
import pynini
from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.en.verbalizers.word import WordFst
from fun_text_processing.text_normalization.en.graph_utils import GraphFst, delete_extra_space, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
GraphFst,
delete_extra_space,
delete_space,
)
from pynini.lib import pynutil
class VerbalizeFinalFst(GraphFst):
"""
Finite state transducer that verbalizes an entire sentence, e.g.
tokens { name: "its" } tokens { time { hours: "12" minutes: "30" } } tokens { name: "now" } -> its 12:30 now
"""


@ -1,7 +1,10 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, DAMO_SIGMA, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_CHAR,
DAMO_SIGMA,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -16,9 +19,9 @@ class WhiteListFst(GraphFst):
graph = (
pynutil.delete("name:")
+ delete_space
+ pynutil.delete("\"")
+ pynutil.delete('"')
+ pynini.closure(DAMO_CHAR - " ", 1)
+ pynutil.delete("\"")
+ pynutil.delete('"')
)
graph = graph @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", DAMO_SIGMA)
graph = graph @ pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", DAMO_SIGMA)
self.fst = graph.optimize()


@ -1,6 +1,10 @@
import pynini
from fun_text_processing.text_normalization.en.graph_utils import DAMO_CHAR, DAMO_SIGMA, GraphFst, delete_space
from fun_text_processing.text_normalization.en.graph_utils import (
DAMO_CHAR,
DAMO_SIGMA,
GraphFst,
delete_space,
)
from pynini.lib import pynutil
@ -13,7 +17,13 @@ class WordFst(GraphFst):
def __init__(self):
super().__init__(name="word", kind="verbalize")
chars = pynini.closure(DAMO_CHAR - " ", 1)
char = pynutil.delete("name:") + delete_space + pynutil.delete("\"") + chars + pynutil.delete("\"")
graph = char @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", DAMO_SIGMA)
char = (
pynutil.delete("name:")
+ delete_space
+ pynutil.delete('"')
+ chars
+ pynutil.delete('"')
)
graph = char @ pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", DAMO_SIGMA)
self.fst = graph.optimize()


@ -1,4 +1,7 @@
from fun_text_processing.inverse_text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
from fun_text_processing.inverse_text_normalization.es.taggers.tokenize_and_classify import (
ClassifyFst,
)
from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize import VerbalizeFst
from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize_final import (
VerbalizeFinalFst,
)

Some files were not shown because too many files have changed in this diff.