mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr2
This commit is contained in:
parent
99340740f5
commit
fc246ab820
@ -21,42 +21,43 @@ class AbsTokenizer(ABC):
|
||||
|
||||
|
||||
class BaseTokenizer(ABC):
|
||||
def __init__(self, token_list: Union[Path, str, Iterable[str]],
|
||||
def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
|
||||
unk_symbol: str = "<unk>",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if isinstance(token_list, (Path, str)):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
if token_list is not None:
|
||||
if isinstance(token_list, (Path, str)):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.text2tokens(text)
|
||||
|
||||
@ -29,7 +29,7 @@ def build_tokenizer(
|
||||
delimiter: str = None,
|
||||
g2p_type: str = None,
|
||||
**kwargs,
|
||||
) -> AbsTokenizer:
|
||||
):
|
||||
"""A helper function to instantiate Tokenizer"""
|
||||
if token_type == "bpe":
|
||||
if bpemodel is None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user