speech_asr 2023-03-21 16:28:22 +08:00
parent 837c5001d4
commit 8314c5f17e
3 changed files with 47 additions and 6 deletions

View File

@@ -1,10 +1,16 @@
 import logging
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Union

-import yaml
+import sentencepiece as spm
 from torch.utils.data import DataLoader
+from typeguard import check_argument_types

 from funasr.datasets.large_datasets.dataset import Dataset
 from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.text.abs_tokenizer import AbsTokenizer


 def read_symbol_table(symbol_table_file):
@@ -21,6 +27,7 @@ def read_symbol_table(symbol_table_file):
             symbol_table[char] = i
     return symbol_table


 def load_seg_dict(seg_dict_file):
     seg_dict = {}
     assert isinstance(seg_dict_file, str)
@@ -33,8 +40,33 @@ def load_seg_dict(seg_dict_file):
             seg_dict[key] = " ".join(value)
     return seg_dict


+class SentencepiecesTokenizer(AbsTokenizer):
+    def __init__(self, model: Union[Path, str]):
+        assert check_argument_types()
+        self.model = str(model)
+        self.sp = None
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(model="{self.model}")'
+
+    def _build_sentence_piece_processor(self):
+        if self.sp is None:
+            self.sp = spm.SentencePieceProcessor()
+            self.sp.load(self.model)
+
+    def text2tokens(self, line: str) -> List[str]:
+        self._build_sentence_piece_processor()
+        return self.sp.EncodeAsPieces(line)
+
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        self._build_sentence_piece_processor()
+        return self.sp.DecodePieces(list(tokens))
+
+
 class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None, mode="train"):
+    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
+                 bpemodel_file=None, mode="train"):
         symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
@@ -48,7 +80,11 @@ class ArkDataLoader(AbsIterFactory):
         self.frontend_conf = frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
+        if bpemodel_file is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
+        else:
+            bpe_tokenizer = None
+        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                                self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)

     def build_iter(self, epoch, shuffle=True):
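
The new SentencepiecesTokenizer defers building its SentencePiece processor until the first call to text2tokens/tokens2text, then simply wraps EncodeAsPieces and DecodePieces. A minimal, hedged sketch of that round trip, assuming a trained model file is available ("bpe.model" and the sample sentence are placeholders, not part of this commit):

    # Hedged sketch: "bpe.model" is a placeholder for the recipe's SentencePiece model.
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("bpe.model")                        # what _build_sentence_piece_processor does on first use
    pieces = sp.EncodeAsPieces("hello world")   # text2tokens: text -> subword pieces, e.g. ['▁hello', '▁world']
    text = sp.DecodePieces(pieces)              # tokens2text: pieces -> "hello world"
    print(pieces, text)

A common reason for the lazy build is to keep the tokenizer object lightweight and picklable until it is actually used inside a data-loading worker.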

View File

@@ -158,6 +158,7 @@ def Dataset(data_list_file,
             dict,
             seg_dict,
             punc_dict,
+            bpe_tokenizer,
             conf,
             frontend_conf,
             mode="train",
@@ -173,7 +174,7 @@ def Dataset(data_list_file,
         dataset = FilterIterDataPipe(dataset, fn=filter_fn)

     if "text" in data_names:
-        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
+        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
         tokenize_fn = partial(tokenize, **vocab)
         dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
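
The tokenizer is threaded through the pipeline the same way as the existing dictionaries: it is packed into the vocab mapping and pre-bound onto tokenize with functools.partial, so every sample mapped through the pipe sees the same configuration. An illustrative sketch of that binding pattern (the stub tokenize below only mirrors the patched signature; it is not the FunASR implementation):

    from functools import partial

    # Stub that mirrors the patched signature of tokenize(); illustrative only.
    def tokenize(data, vocab=None, seg_dict=None, punc_dict=None, bpe_tokenizer=None):
        text = data["text"]
        if bpe_tokenizer is not None:              # branch added by this commit
            text = bpe_tokenizer.text2tokens(text)
        return text

    vocab = {"vocab": {}, "seg_dict": None, "punc_dict": None, "bpe_tokenizer": None}
    tokenize_fn = partial(tokenize, **vocab)       # keyword arguments are bound once
    print(tokenize_fn({"text": "hello world"}))    # -> "hello world" (no BPE tokenizer configured)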

View File

@@ -28,13 +28,17 @@ def seg_tokenize(txt, seg_dict):
 def tokenize(data,
              vocab=None,
              seg_dict=None,
-             punc_dict=None):
+             punc_dict=None,
+             bpe_tokenizer=None):
     assert "text" in data
     assert isinstance(vocab, dict)
     text = data["text"]
     token = []
     vad = -2

+    if bpe_tokenizer is not None:
+        text = bpe_tokenizer.text2tokens(text)
+
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         txt = forward_segment("".join(text).lower(), seg_dict)
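
The change to tokenize() itself is the ordering: when a BPE tokenizer is configured, the raw transcript is split into subword pieces before the existing seg_dict handling, which then joins and lowercases whatever it receives. A hedged sketch of that ordering with a stand-in tokenizer (FakeBpe is illustrative only, not part of FunASR):

    # FakeBpe stands in for SentencepiecesTokenizer; a real model would return subword pieces.
    class FakeBpe:
        def text2tokens(self, line):
            return line.split()

    text = "HELLO WORLD"
    bpe_tokenizer = FakeBpe()
    if bpe_tokenizer is not None:
        text = bpe_tokenizer.text2tokens(text)   # text is now a list of pieces
    joined = "".join(text).lower()               # the seg_dict branch joins and lowercases next
    print(text, joined)                          # ['HELLO', 'WORLD'] helloworld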