import os
import json
import logging
import random

import librosa
import torch
import torch.distributed as dist

from funasr.register import tables


@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
class OpenAIIndexDSJsonl(torch.utils.data.Dataset):
    """Index dataset for OpenAI-style chat jsonl files (one {"messages": [...]} object per line)."""

    def __init__(self, path: str, **kwargs):
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.max_token_length = kwargs.get("max_token_length", 2200)
        is_training = kwargs.get("is_training", True)

        if not (path.endswith(".jsonl") or path.endswith(".json")):
            # `path` is a plain-text list of jsonl files: split it across data workers.
            data_split_num = kwargs.get("data_split_num", 1)
            data_split_i = kwargs.get("data_split_i", 0)
            if not is_training:
                data_split_num = 1
                data_split_i = 0
            with open(path, encoding="utf-8") as fin:
                file_list_all = fin.readlines()
            # Ceiling division so every listed file is assigned to exactly one slice.
            num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
            file_list = file_list_all[
                data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
            ]
            logging.info(
                f"is_training: {is_training}, data_split_num: {data_split_num}, "
                f"data_split_i: {data_split_i}, \nfile_list: {file_list}, "
                f"\nfile_list_all: {file_list_all}"
            )
        else:
            # `path` is itself a single jsonl/json file.
            file_list = [path]

        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    data = json.loads(line.strip())["messages"]
                    # Group the message contents by role.
                    system, user, assistant = [], [], []
                    for item in data:
                        role = item["role"]
                        content = item["content"]
                        if role == "system":
                            system.append(content)
                        elif role == "user":
                            user.append(content)
                        elif role == "assistant":
                            assistant.append(content)
                    # Repeat the system prompt so it pairs with every user turn.
                    system = system * len(user)
                    contents_i = {"system": system, "user": user, "assistant": assistant}
                    contents.append(contents_i)

        self.contents = contents
        logging.info("total_num of samples: {}, {}".format(len(self.contents), path))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        return self.contents[index]

    def get_source_len(self, data_dict):
        return len(data_dict["system"]) + len(data_dict["user"])

    def get_target_len(self, data_dict):
        return len(data_dict["assistant"])


if __name__ == "__main__":
    index_ds = OpenAIIndexDSJsonl(
        path="/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
    )
    print(index_ds.contents)
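
# ---------------------------------------------------------------------------
# A minimal sketch of the input this loader assumes: each line of the jsonl
# file is one OpenAI-style chat sample. The file name and message contents
# below are hypothetical examples chosen for illustration.
#
#   {"messages": [
#       {"role": "system", "content": "You are a speech transcription assistant."},
#       {"role": "user", "content": "<audio path or prompt for this sample>"},
#       {"role": "assistant", "content": "<target text for this sample>"}
#   ]}
#
# Given such a file, e.g. "demo.jsonl", the dataset yields one dict per line:
#
#   index_ds = OpenAIIndexDSJsonl(path="demo.jsonl")
#   sample = index_ds[0]
#   # sample == {"system": [...], "user": [...], "assistant": [...]}
# ---------------------------------------------------------------------------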