onnx libtorch python runtime optim

游雁 2023-04-12 10:25:29 +08:00
parent 48e6117a3a
commit 60d38fa9ca
2 changed files with 9 additions and 27 deletions

File 1 of 2 (path not shown in this view):

@@ -23,9 +23,11 @@ class TokenIDConverter():
     ):
         check_argument_types()
         # self.token_list = self.load_token(token_path)
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
+        self.token2id = {v: i for i, v in enumerate(self.token_list)}
+        self.unk_id = self.token2id[self.unk_symbol]

     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -38,13 +40,8 @@ class TokenIDConverter():
         return [self.token_list[i] for i in integers]

     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        token2id = {v: i for i, v in enumerate(self.token_list)}
-        if self.unk_symbol not in token2id:
-            raise TokenIDConverterError(
-                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
-            )
-        unk_id = token2id[self.unk_symbol]
-        return [token2id.get(i, unk_id) for i in tokens]
+        return [self.token2id.get(i, self.unk_id) for i in tokens]

 class CharTokenizer():
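
The change above moves the token-to-id map and the unknown-token id out of tokens2ids() and into the constructor, so the dictionary is built once per converter instead of once per call. Below is a minimal, self-contained sketch of the resulting class; check_argument_types() and the custom TokenIDConverterError from the original file are omitted so the snippet runs on its own, and everything else is an assumption based only on what the diff shows.

from typing import Iterable, List

class TokenIDConverter:
    """Sketch of the refactored converter: the token->id map and the
    unknown-token id are computed once in __init__ rather than on every
    tokens2ids() call."""

    def __init__(self, token_list: List[str]):
        self.token_list = token_list
        # As in the original code, the last entry of token_list is the unknown symbol.
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)

    def ids2tokens(self, integers: Iterable[int]) -> List[str]:
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        # Reuse the cached mapping; out-of-vocabulary tokens fall back to unk_id.
        return [self.token2id.get(t, self.unk_id) for t in tokens]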

File 2 of 2 (path not shown in this view):

@@ -24,21 +24,11 @@ class TokenIDConverter():
     ):
         check_argument_types()
         # self.token_list = self.load_token(token_path)
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
+        self.token2id = {v: i for i, v in enumerate(self.token_list)}
+        self.unk_id = self.token2id[self.unk_symbol]
-    # @staticmethod
-    # def load_token(file_path: Union[Path, str]) -> List:
-    #     if not Path(file_path).exists():
-    #         raise TokenIDConverterError(f'The {file_path} does not exist.')
-    #
-    #     with open(str(file_path), 'rb') as f:
-    #         token_list = pickle.load(f)
-    #
-    #     if len(token_list) != len(set(token_list)):
-    #         raise TokenIDConverterError('The Token exists duplicated symbol.')
-    #     return token_list

     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -51,13 +41,8 @@ class TokenIDConverter():
         return [self.token_list[i] for i in integers]

     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        token2id = {v: i for i, v in enumerate(self.token_list)}
-        if self.unk_symbol not in token2id:
-            raise TokenIDConverterError(
-                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
-            )
-        unk_id = token2id[self.unk_symbol]
-        return [token2id.get(i, unk_id) for i in tokens]
+        return [self.token2id.get(i, self.unk_id) for i in tokens]

 class CharTokenizer():
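
The second file receives the same optimization and additionally drops the commented-out load_token() helper. For reference, a quick check of the behaviour both versions preserve, using the sketch above and a hypothetical four-token vocabulary:

converter = TokenIDConverter(["a", "b", "c", "<unk>"])
print(converter.get_num_vocabulary_size())      # 4
print(converter.tokens2ids(["a", "c", "zzz"]))  # [0, 2, 3] -- "zzz" falls back to unk_id
print(converter.ids2tokens([0, 2, 3]))          # ['a', 'c', '<unk>']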