mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
onnx libtorch python runtime optim
This commit is contained in:
parent
48e6117a3a
commit
60d38fa9ca
@ -23,9 +23,11 @@ class TokenIDConverter():
|
||||
):
|
||||
check_argument_types()
|
||||
|
||||
# self.token_list = self.load_token(token_path)
|
||||
self.token_list = token_list
|
||||
self.unk_symbol = token_list[-1]
|
||||
self.token2id = {v: i for i, v in enumerate(self.token_list)}
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
|
||||
return len(self.token_list)
|
||||
@ -38,13 +40,8 @@ class TokenIDConverter():
|
||||
return [self.token_list[i] for i in integers]
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
||||
token2id = {v: i for i, v in enumerate(self.token_list)}
|
||||
if self.unk_symbol not in token2id:
|
||||
raise TokenIDConverterError(
|
||||
f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
unk_id = token2id[self.unk_symbol]
|
||||
return [token2id.get(i, unk_id) for i in tokens]
|
||||
|
||||
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
||||
|
||||
|
||||
class CharTokenizer():
|
||||
|
||||
@ -24,21 +24,11 @@ class TokenIDConverter():
|
||||
):
|
||||
check_argument_types()
|
||||
|
||||
# self.token_list = self.load_token(token_path)
|
||||
self.token_list = token_list
|
||||
self.unk_symbol = token_list[-1]
|
||||
self.token2id = {v: i for i, v in enumerate(self.token_list)}
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
# @staticmethod
|
||||
# def load_token(file_path: Union[Path, str]) -> List:
|
||||
# if not Path(file_path).exists():
|
||||
# raise TokenIDConverterError(f'The {file_path} does not exist.')
|
||||
#
|
||||
# with open(str(file_path), 'rb') as f:
|
||||
# token_list = pickle.load(f)
|
||||
#
|
||||
# if len(token_list) != len(set(token_list)):
|
||||
# raise TokenIDConverterError('The Token exists duplicated symbol.')
|
||||
# return token_list
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
|
||||
return len(self.token_list)
|
||||
@ -51,13 +41,8 @@ class TokenIDConverter():
|
||||
return [self.token_list[i] for i in integers]
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
||||
token2id = {v: i for i, v in enumerate(self.token_list)}
|
||||
if self.unk_symbol not in token2id:
|
||||
raise TokenIDConverterError(
|
||||
f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
unk_id = token2id[self.unk_symbol]
|
||||
return [token2id.get(i, unk_id) for i in tokens]
|
||||
|
||||
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
||||
|
||||
|
||||
class CharTokenizer():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user