This commit is contained in:
游雁 2024-07-01 17:38:55 +08:00
parent 392a93d919
commit 7190d50b27

View File

@ -656,24 +656,27 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
continue
# targets
target_out = f"{target_out}<|im_end|>"
# target_out = f"{target_out}<|im_end|>"
splits = self.pattern.split(target_out)
codec_i = []
sub_token = []
for k, sub_str in enumerate(splits):
if len(sub_str) < 1:
continue
if not sub_str.startswith("<|startofspeech|>"):
sub_str = f"{sub_str}<|im_end|>"
sub_token = self.tokenizer.encode(sub_str)
else:
sub_str = sub_str.replace("<|startofspeech|>", "").replace(
"<|endofspeech|>", ""
)
if not sub_str.startswith("!"):
sub_token_codec = []
for x in sub_str.split("|"):
if x.startswith("c"):
sub_token_codec = int(x[1:])
codec_i = torch.tensor(sub_token_codec, dtype=torch.int64)
codec_i_len = len(sub_token_codec)
sub_token_codec.append(int(x[1:]))
codec_i = torch.tensor(sub_token_codec, dtype=torch.int64)
codec_i_len = len(sub_token_codec)
target_ids = sub_token
if len(codec_i) > 0:
@ -713,9 +716,9 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
"labels_ids": labels,
}
if len(codec) > 0:
codec_i_len = torch.tensor(codec_i_len, dtype=torch.int64)
codec_len = torch.tensor(codec_len, dtype=torch.int32)
output["codec"] = codec
output["codec_i_len"] = codec_i_len
output["codec_len"] = codec_len
break
return output