This commit is contained in:
游雁 2024-04-29 19:14:13 +08:00
parent dd927baf28
commit 03040b04e2
2 changed files with 63 additions and 31 deletions

View File

@ -29,7 +29,6 @@ def gen_jsonl_from_wav_text_list(
with open(data_file, "r") as f: with open(data_file, "r") as f:
data_file_lists = f.readlines() data_file_lists = f.readlines()
print("")
lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1 lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1 task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
# import pdb;pdb.set_trace() # import pdb;pdb.set_trace()

View File

@ -7,31 +7,68 @@ from omegaconf import DictConfig, OmegaConf
import concurrent.futures import concurrent.futures
import librosa import librosa
import torch.distributed as dist import torch.distributed as dist
import threading
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
    """Rewrite a JSONL file after passing every record through ``update_data``.

    All lines of ``jsonl_file`` are read into memory, ``update_data`` is
    called once per line index (it replaces the list entry in place), and the
    resulting lines are written to ``jsonl_file_out``.

    Args:
        jsonl_file: Path of the input JSONL file, read as UTF-8.
        jsonl_file_out: Path of the JSONL file to write, as UTF-8.
        ncpu: Number of worker threads. With a value of 1 or less the records
            are processed sequentially in the calling thread.
    """
    with open(jsonl_file, encoding="utf-8") as fin:
        lines = fin.readlines()

    num_total = len(lines)
    if ncpu > 1:
        # Use a ThreadPoolExecutor to cap the number of concurrent threads.
        with ThreadPoolExecutor(max_workers=ncpu) as executor:
            # Submit one task per line; remember the index for error reports.
            futures = {executor.submit(update_data, lines, i): i for i in range(num_total)}

            # Block until every submitted task has finished. The progress bar
            # advances on completion, not on submission.
            for future in tqdm(concurrent.futures.as_completed(futures), total=num_total):
                error = future.exception()
                if error is not None:
                    # Report the failure instead of dropping it silently; the
                    # affected record keeps its original content.
                    print(f"failed to update line {futures[future]} of {jsonl_file}: {error!r}")
    else:
        for i in range(num_total):
            update_data(lines, i)
    print("All audio durations have been processed.")

    # Open the output only once processing is done: the file is not truncated
    # before the input has been read, and the handle is closed even on error.
    with open(jsonl_file_out, "w", encoding="utf-8") as jsonl_file_out_f:
        # Entries may or may not carry a trailing newline depending on whether
        # they were rewritten; normalize so each record sits on its own line.
        jsonl_file_out_f.writelines(line.rstrip("\n") + "\n" for line in lines)
def update_data(lines, i):
    """Recompute ``source_len`` for the JSONL record stored at ``lines[i]``.

    The record's ``source`` path is loaded with librosa at 16 kHz and
    ``source_len`` is set to the audio duration expressed in 10 ms units
    (samples / 16000 * 1000 / 10, truncated to an int). The re-serialized
    record replaces ``lines[i]`` in place.

    Args:
        lines: Mutable list of JSONL lines; entry ``i`` is overwritten.
        i: Index of the line to update.
    """
    data = json.loads(lines[i].strip())

    # The audio is read from a remapped mount point; the "source" field that is
    # written back keeps its original value.
    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
    waveform, _ = librosa.load(wav_path, sr=16000)
    sample_num = len(waveform)
    source_len = int(sample_num / 16000 * 1000 / 10)

    # .get() lets records without a previous "source_len" be filled in rather
    # than failing with KeyError.
    source_len_old = data.get("source_len")
    if source_len_old != source_len:
        print(f"wav: {wav_path}, old: {source_len_old}, new: {source_len}")
    data["source_len"] = source_len

    # json.dumps() emits no line terminator; add one so that writing the list
    # back out keeps one record per line.
    lines[i] = json.dumps(data, ensure_ascii=False) + "\n"
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
    """Run ``gen_scp_from_jsonl`` on every JSONL file named in a list file.

    Each output file is written to ``jsonl_file_out_dir`` under the base name
    of its input, so inputs that share a base name write to the same output.

    Args:
        jsonl_file_list_in: Text file with one JSONL path per line. Blank
            lines are ignored.
        jsonl_file_out_dir: Directory for the rewritten files; created if it
            does not exist.
        ncpu: Forwarded to ``gen_scp_from_jsonl``.
    """
    os.makedirs(jsonl_file_out_dir, exist_ok=True)
    with open(jsonl_file_list_in, "r") as f:
        # A blank line would give an empty path, making the output path the
        # directory itself, so drop such lines up front.
        jsonl_files = [line.strip() for line in f if line.strip()]

    for i, jsonl in enumerate(jsonl_files):
        filename_with_extension = os.path.basename(jsonl)
        jsonl_file_out = os.path.join(jsonl_file_out_dir, filename_with_extension)
        # The path is printed stripped, so the message stays on one line.
        print(f"{i}/{len(jsonl_files)}, jsonl: {jsonl}, {jsonl_file_out}")
        gen_scp_from_jsonl(jsonl, jsonl_file_out, ncpu)
@hydra.main(config_name=None, version_base=None) @hydra.main(config_name=None, version_base=None)
@ -40,17 +77,13 @@ def main_hydra(cfg: DictConfig):
kwargs = OmegaConf.to_container(cfg, resolve=True) kwargs = OmegaConf.to_container(cfg, resolve=True)
print(kwargs) print(kwargs)
scp_file_list = kwargs.get( jsonl_file_list_in = kwargs.get(
"scp_file_list", "jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
) )
if isinstance(scp_file_list, str): jsonl_file_out_dir = kwargs.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
scp_file_list = eval(scp_file_list) ncpu = kwargs.get("ncpu", 1)
data_type_list = kwargs.get("data_type_list", ("source", "target")) update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)
jsonl_file = kwargs.get( # gen_scp_from_jsonl(jsonl_file_list_in, jsonl_file_out_dir)
"jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
)
gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
""" """