diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
index cafc43bf9..86e78bc4a 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -23,9 +23,11 @@ class TokenIDConverter():
     ):
         check_argument_types()
 
-        # self.token_list = self.load_token(token_path)
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
+        self.token2id = {v: i for i, v in enumerate(self.token_list)}
+        self.unk_id = self.token2id[self.unk_symbol]
+
 
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -38,13 +40,8 @@ class TokenIDConverter():
         return [self.token_list[i] for i in integers]
 
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        token2id = {v: i for i, v in enumerate(self.token_list)}
-        if self.unk_symbol not in token2id:
-            raise TokenIDConverterError(
-                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
-            )
-        unk_id = token2id[self.unk_symbol]
-        return [token2id.get(i, unk_id) for i in tokens]
+
+        return [self.token2id.get(i, self.unk_id) for i in tokens]
 
 
 class CharTokenizer():
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index 0df954ed7..78c3f0d98 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -24,21 +24,11 @@ class TokenIDConverter():
     ):
         check_argument_types()
 
-        # self.token_list = self.load_token(token_path)
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
+        self.token2id = {v: i for i, v in enumerate(self.token_list)}
+        self.unk_id = self.token2id[self.unk_symbol]
 
-        # @staticmethod
-        # def load_token(file_path: Union[Path, str]) -> List:
-        #     if not Path(file_path).exists():
-        #         raise TokenIDConverterError(f'The {file_path} does not exist.')
-        #
-        #     with open(str(file_path), 'rb') as f:
-        #         token_list = pickle.load(f)
-        #
-        #     if len(token_list) != len(set(token_list)):
-        #         raise TokenIDConverterError('The Token exists duplicated symbol.')
-        #     return token_list
 
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -51,13 +41,8 @@ class TokenIDConverter():
         return [self.token_list[i] for i in integers]
 
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        token2id = {v: i for i, v in enumerate(self.token_list)}
-        if self.unk_symbol not in token2id:
-            raise TokenIDConverterError(
-                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
-            )
-        unk_id = token2id[self.unk_symbol]
-        return [token2id.get(i, unk_id) for i in tokens]
+
+        return [self.token2id.get(i, self.unk_id) for i in tokens]
 
 
 class CharTokenizer():