游雁 2023-12-07 00:27:00 +08:00
parent 99340740f5
commit fc246ab820
2 changed files with 32 additions and 31 deletions


@@ -21,42 +21,43 @@ class AbsTokenizer(ABC):
 class BaseTokenizer(ABC):
-    def __init__(self, token_list: Union[Path, str, Iterable[str]],
+    def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
                  unk_symbol: str = "<unk>",
                  **kwargs,
                  ):
-        if isinstance(token_list, (Path, str)):
-            token_list = Path(token_list)
-            self.token_list_repr = str(token_list)
-            self.token_list: List[str] = []
+        if token_list is not None:
+            if isinstance(token_list, (Path, str)):
+                token_list = Path(token_list)
+                self.token_list_repr = str(token_list)
+                self.token_list: List[str] = []
 
-            with token_list.open("r", encoding="utf-8") as f:
-                for idx, line in enumerate(f):
-                    line = line.rstrip()
-                    self.token_list.append(line)
+                with token_list.open("r", encoding="utf-8") as f:
+                    for idx, line in enumerate(f):
+                        line = line.rstrip()
+                        self.token_list.append(line)
 
-        else:
-            self.token_list: List[str] = list(token_list)
-            self.token_list_repr = ""
-            for i, t in enumerate(self.token_list):
-                if i == 3:
-                    break
-                self.token_list_repr += f"{t}, "
-            self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-        self.token2id: Dict[str, int] = {}
-        for i, t in enumerate(self.token_list):
-            if t in self.token2id:
-                raise RuntimeError(f'Symbol "{t}" is duplicated')
-            self.token2id[t] = i
-        self.unk_symbol = unk_symbol
-        if self.unk_symbol not in self.token2id:
-            raise RuntimeError(
-                f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
-            )
-        self.unk_id = self.token2id[self.unk_symbol]
+            else:
+                self.token_list: List[str] = list(token_list)
+                self.token_list_repr = ""
+                for i, t in enumerate(self.token_list):
+                    if i == 3:
+                        break
+                    self.token_list_repr += f"{t}, "
+                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+            self.token2id: Dict[str, int] = {}
+            for i, t in enumerate(self.token_list):
+                if t in self.token2id:
+                    raise RuntimeError(f'Symbol "{t}" is duplicated')
+                self.token2id[t] = i
+            self.unk_symbol = unk_symbol
+            if self.unk_symbol not in self.token2id:
+                raise RuntimeError(
+                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+                )
+            self.unk_id = self.token2id[self.unk_symbol]
 
     def encode(self, text):
         tokens = self.text2tokens(text)
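
In effect, the hunk above makes token_list optional: when it is omitted (None), the entire vocabulary setup (token_list, token2id, unk_symbol, unk_id) is skipped rather than failing, so tokenizers that carry their own vocabulary can be constructed without a token file. A minimal sketch of both construction paths, assuming text2tokens/tokens2text are the abstract hooks of BaseTokenizer; the DummyTokenizer subclass and its toy method bodies are illustrative, not part of this commit:

    from typing import Iterable, List

    # Hypothetical subclass, only to exercise BaseTokenizer.__init__;
    # assumes BaseTokenizer from the patched module is in scope.
    class DummyTokenizer(BaseTokenizer):
        def text2tokens(self, line: str) -> List[str]:
            return line.split()            # toy whitespace tokenizer

        def tokens2text(self, tokens: Iterable[str]) -> str:
            return " ".join(tokens)

    # Unchanged path: vocabulary built from an in-memory token list.
    tok = DummyTokenizer(token_list=["<unk>", "hello", "world"])
    assert tok.unk_id == 0 and tok.token2id["world"] == 2

    # Path enabled by this commit: no token_list at all, so __init__
    # skips the vocabulary setup instead of raising.
    tok = DummyTokenizer()
    assert not hasattr(tok, "token2id")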


@@ -29,7 +29,7 @@ def build_tokenizer(
     delimiter: str = None,
     g2p_type: str = None,
     **kwargs,
-) -> AbsTokenizer:
+):
     """A helper function to instantiate Tokenizer"""
     if token_type == "bpe":
         if bpemodel is None:
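
The second hunk drops the -> AbsTokenizer return annotation from build_tokenizer, presumably loosening the factory's declared contract to match the more permissive constructor above. A hedged usage sketch; only token_type, bpemodel, delimiter, and g2p_type are visible in this hunk, and the model path below is a placeholder:

    # Build a BPE tokenizer through the factory shown above.
    tokenizer = build_tokenizer(
        token_type="bpe",
        bpemodel="path/to/bpe.model",  # placeholder SentencePiece model path
    )

    # Per the visible branch, token_type="bpe" with no bpemodel hits the
    # `if bpemodel is None:` guard (presumably an error path):
    # build_tokenizer(token_type="bpe")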