diff --git a/examples/aishell/llm_asr_nar/conf/template.yaml b/examples/aishell/llm_asr_nar/conf/template.yaml
index 0b26969be..d52963575 100644
--- a/examples/aishell/llm_asr_nar/conf/template.yaml
+++ b/examples/aishell/llm_asr_nar/conf/template.yaml
@@ -24,11 +24,11 @@ llm_conf:
     init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
     freeze: true
 
-adaptor: linear
+adaptor: Linear
 adaptor_conf:
     downsample_rate: 1
     llm_dim: 4096
-    encoder_dim: 2048
+    encoder_dim: 512
 
 # frontend related
 frontend: WavFrontend
@@ -38,54 +38,56 @@ frontend_conf:
     n_mels: 80
     frame_length: 25
     frame_shift: 10
-    dither: 0.0
-    lfr_m: 1
-    lfr_n: 1
+    lfr_m: 7
+    lfr_n: 6
+    cmvn_file: "/root/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
 
-specaug: SpecAug
+specaug: SpecAugLFR
 specaug_conf:
-    apply_time_warp: true
+    apply_time_warp: false
     time_warp_window: 5
     time_warp_mode: bicubic
     apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
-    num_freq_mask: 2
+    lfr_rate: 6
+    num_freq_mask: 1
     apply_time_mask: true
     time_mask_width_range:
     - 0
-    - 40
-    num_time_mask: 2
+    - 12
+    num_time_mask: 1
 
 train_conf:
     accum_grad: 1
     grad_clip: 5
     max_epoch: 150
     keep_nbest_models: 10
-    log_interval: 50
+    log_interval: 10
 
-optim: adam
+optim: adamw
 optim_conf:
-    lr: 0.001
+    lr: 0.0001
     weight_decay: 0.000001
 scheduler: warmuplr
 scheduler_conf:
-    warmup_steps: 35000
+    warmup_steps: 1500
 
 dataset: AudioLLMDataset
 dataset_conf:
     index_ds: IndexDSJsonl
     batch_sampler: RankFullLocalShuffleBatchSampler
     batch_type: example # example or length
-    batch_size: 4 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len;
+    batch_size: 8 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
     buffer_size: 500
     shuffle: True
     num_workers: 4
+    preprocessor_text: TextPreprocessRemovePunctuation
 
 tokenizer: HuggingfaceTokenizer
 tokenizer_conf:
     unk_symbol:
-    init_param_path: null
+    init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e5faa2aaa..3b70ad6d4 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -157,8 +157,10 @@ class AutoModel:
             tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
             kwargs["tokenizer"] = tokenizer
-            kwargs["token_list"] = tokenizer.token_list
-            vocab_size = len(tokenizer.token_list)
+
+            kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+            kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
+            vocab_size = len(kwargs["token_list"])
         else:
             vocab_size = -1
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 26b0f4a3c..44d84e7a3 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -85,7 +85,9 @@ def main(**kwargs):
 
     # build model
     model_class = tables.model_classes.get(kwargs["model"])
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
+    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
+    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
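The vocab-size fallback added in `auto_model.py` and `train.py` above implements the same lookup order twice; a minimal standalone sketch of that order (the `DummyTokenizer` here is purely illustrative, not a FunASR class):

```python
# Prefer get_vocab() (HuggingFace-style tokenizers), else fall back to
# token_list (FunASR's own tokenizers), mirroring the two hunks above.
class DummyTokenizer:
    token_list = ["<blank>", "<s>", "</s>", "hello", "world"]

def resolve_vocab(tokenizer):
    token_list = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
    if hasattr(tokenizer, "get_vocab"):
        token_list = tokenizer.get_vocab()  # HF tokenizers: dict of token -> id
    return token_list, len(token_list) if token_list is not None else -1

token_list, vocab_size = resolve_vocab(DummyTokenizer())
print(vocab_size)  # 5
```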
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 20eb8aa7c..ab0e48a45 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -24,12 +24,12 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
             preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
-            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
             preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
-            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
         self.preprocessor_text = preprocessor_text
 
         self.frontend = frontend
@@ -43,6 +43,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
             self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
         self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
 
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -64,7 +65,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src, fs=self.fs)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True)  # speech: [b, T, d]
-        speech = speech.sequeeze(0)
+        speech = speech.squeeze(0)
 
         target = item["target"]
         if self.preprocessor_text:
@@ -91,10 +92,10 @@ class AudioLLMDataset(torch.utils.data.Dataset):
         label_mask = labels_ids.ge(0)  # [False,False,True,True]
         labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
 
-        audio_mask = [0] * prompt_pre_length + [1] * audio_length
-        torch.tensor(audio_mask, dtype=torch.float32)
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
 
-        ids = self.tokenizer.encode(target)
+        ids = self.tokenizer.encode(target)  # token ids differ from labels_ids
         text = torch.tensor(ids, dtype=torch.int64)
         text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
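A toy walk-through of the two dataset fixes above: the old code built `audio_mask` but discarded the tensor (`torch.tensor(...)` without assignment), while prompt placeholders in `labels_ids` become `IGNORE_INDEX` so the LLM loss skips them. All lengths below are invented:

```python
import torch

IGNORE_INDEX = -100      # the dataset default picked up from kwargs
prompt_pre_length = 3    # toy prompt length
audio_length = 2         # toy number of audio frames

# prompt positions carry -1 placeholders; real target tokens are >= 0
labels_ids = torch.tensor([-1, -1, -1, 52, 53, 2])
label_mask = labels_ids.ge(0)           # [F, F, F, T, T, T]
labels_ids[~label_mask] = IGNORE_INDEX  # -> [-100, -100, -100, 52, 53, 2]

# fixed: assign the tensor back, with an explicit trailing 0
audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
print(labels_ids.tolist())  # [-100, -100, -100, 52, 53, 2]
print(audio_mask.tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 0.0]
```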
"TextPreprocessSegDict") +@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation") class TextPreprocessSegDict(nn.Module): - def __init__(self, seg_dict: str = None, - text_cleaner: Collection[str] = None, - split_with_space: bool = False, + def __init__(self, **kwargs): super().__init__() - self.text_cleaner = TextCleaner(text_cleaner) def forward(self, text, **kwargs): - text = self.text_cleaner(text) - - return text + # 定义英文标点符号 + en_punct = string.punctuation + # 定义中文标点符号(部分常用的) + cn_punct = '。?!,、;:“”‘’()《》【】…—~·' + # 合并英文和中文标点符号 + all_punct = en_punct + cn_punct + # 创建正则表达式模式,匹配任何在all_punct中的字符 + punct_pattern = re.compile('[{}]'.format(re.escape(all_punct))) + # 使用正则表达式的sub方法替换掉这些字符 + return punct_pattern.sub('', text) diff --git a/funasr/datasets/llm_datasets/scp2jsonl.py b/funasr/datasets/llm_datasets/scp2jsonl.py deleted file mode 100644 index e09a84a61..000000000 --- a/funasr/datasets/llm_datasets/scp2jsonl.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import json -import torch -import logging -import hydra -from omegaconf import DictConfig, OmegaConf -import concurrent.futures -import librosa -import torch.distributed as dist - - - -def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs): - try: - rank = dist.get_rank() - world_size = dist.get_world_size() - except: - rank = 0 - world_size = 1 - - cpu_cores = os.cpu_count() or 1 - print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}") - if rank == 0: - json_dict = {} - for data_type, data_file in zip(data_type_list, path): - json_dict[data_type] = {} - with open(data_file, "r") as f: - - data_file_lists = f.readlines() - lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1 - task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1 - with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor: - - futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)] - - for future in concurrent.futures.as_completed(futures): - - json_dict[data_type].update(future.result()) - # print(json_dict) - - with open(jsonl_file_out, "w") as f: - for key in json_dict[data_type_list[0]].keys(): - jsonl_line = {"key": key} - for data_file in data_type_list: - jsonl_line.update(json_dict[data_file][key]) - jsonl_line = json.dumps(jsonl_line, ensure_ascii=False) - f.write(jsonl_line+"\n") - f.flush() - - else: - pass - - if world_size > 1: - dist.barrier() - - -def parse_context_length(data_list: list, data_type: str): - - res = {} - for i, line in enumerate(data_list): - key, line = line.strip().split(maxsplit=1) - line = line.strip() - if os.path.exists(line): - waveform, _ = librosa.load(line, sr=16000) - sample_num = len(waveform) - context_len = int(sample_num//16000*1000/10) - else: - context_len = len(line.split()) if " " in line else len(line) - res[key] = {data_type: line, f"{data_type}_len": context_len} - return res - - -@hydra.main(config_name=None, version_base=None) -def main_hydra(cfg: DictConfig): - - kwargs = OmegaConf.to_container(cfg, resolve=True) - - scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt")) - if isinstance(scp_file_list, str): - scp_file_list = eval(scp_file_list) - data_type_list = kwargs.get("data_type_list", ("source", "target")) - jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl") - 
diff --git a/funasr/datasets/llm_datasets/scp2jsonl.py b/funasr/datasets/llm_datasets/scp2jsonl.py
deleted file mode 100644
index e09a84a61..000000000
--- a/funasr/datasets/llm_datasets/scp2jsonl.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import os
-import json
-import torch
-import logging
-import hydra
-from omegaconf import DictConfig, OmegaConf
-import concurrent.futures
-import librosa
-import torch.distributed as dist
-
-
-
-def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
-    try:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    except:
-        rank = 0
-        world_size = 1
-
-    cpu_cores = os.cpu_count() or 1
-    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
-    if rank == 0:
-        json_dict = {}
-        for data_type, data_file in zip(data_type_list, path):
-            json_dict[data_type] = {}
-            with open(data_file, "r") as f:
-
-                data_file_lists = f.readlines()
-                lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
-                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
-                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
-
-                    futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
-
-                    for future in concurrent.futures.as_completed(futures):
-
-                        json_dict[data_type].update(future.result())
-        # print(json_dict)
-
-        with open(jsonl_file_out, "w") as f:
-            for key in json_dict[data_type_list[0]].keys():
-                jsonl_line = {"key": key}
-                for data_file in data_type_list:
-                    jsonl_line.update(json_dict[data_file][key])
-                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
-                f.write(jsonl_line+"\n")
-                f.flush()
-
-    else:
-        pass
-
-    if world_size > 1:
-        dist.barrier()
-
-
-def parse_context_length(data_list: list, data_type: str):
-
-    res = {}
-    for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
-        line = line.strip()
-        if os.path.exists(line):
-            waveform, _ = librosa.load(line, sr=16000)
-            sample_num = len(waveform)
-            context_len = int(sample_num//16000*1000/10)
-        else:
-            context_len = len(line.split()) if " " in line else len(line)
-        res[key] = {data_type: line, f"{data_type}_len": context_len}
-    return res
-
-
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(cfg: DictConfig):
-
-    kwargs = OmegaConf.to_container(cfg, resolve=True)
-
-    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
-    if isinstance(scp_file_list, str):
-        scp_file_list = eval(scp_file_list)
-    data_type_list = kwargs.get("data_type_list", ("source", "target"))
-    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
-    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
-
-
-"""
-python -m funasr.datasets.audio_datasets.scp2jsonl \
-++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
-++data_type_list='["source", "target"]' \
-++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
-"""
-
-if __name__ == "__main__":
-    main_hydra()
-
-
\ No newline at end of file
diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py
index 73545c0ee..ec8067f3c 100644
--- a/funasr/metrics/compute_acc.py
+++ b/funasr/metrics/compute_acc.py
@@ -35,8 +35,6 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     """
     mask = pad_targets != ignore_label
-    numerator = torch.sum(
-        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
-    )
+    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
     denominator = torch.sum(mask)
     return numerator.float() / denominator.float()  #(FIX:MZY):return torch.Tensor type
\ No newline at end of file
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index a90326224..06323c637 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -73,7 +73,7 @@ class LLMASRNAR(nn.Module):
         hub = encoder_conf.get("hub", None)
         if hub == "funasr":
             from funasr import AutoModel
-            init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
             model = AutoModel(model=init_param_path, model_revision="v2.0.4")
             # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
@@ -179,6 +179,7 @@ class LLMASRNAR(nn.Module):
 
         if input_ids is not None:
             input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
             if hasattr(self.llm.model, "embed_tokens"):
                 inputs_embeds = self.llm.model.embed_tokens(input_ids)
             elif hasattr(self.llm.model.model, "embed_tokens"):
@@ -190,7 +191,7 @@ class LLMASRNAR(nn.Module):
         batch_size, token_num, dims = inputs_embeds.shape
         _, l, _ = encoder_out.shape
         encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
-        inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+        inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
         inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
@@ -198,11 +199,10 @@ class LLMASRNAR(nn.Module):
 
         stats = {}
-        if self.metric:
-            with torch.no_grad():
-                preds = torch.argmax(model_outputs.logits, -1)
-                acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
-                stats["acc"] = acc_att
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
 
         stats["loss"] = torch.clone(loss.detach())
@@ -221,11 +221,12 @@ class LLMASRNAR(nn.Module):
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
         enc, enc_lens = self.audio_encoder.encode(**batch)
-        enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
-                                                                                   mask=enc_mask,
-                                                                                   target_label_length=audio_token_lengths,
-                                                                                   )
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+                                                                                       mask=enc_mask,
+                                                                                       target_label_length=audio_token_lengths,
+                                                                                       )
 
         return pre_acoustic_embeds, pre_token_length
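Why `(1.0 - audio_mask[:, :, None])` replaces `~audio_mask[:, :, None]` in the hunk above: the dataset now materializes `audio_mask` as float32, and bitwise `~` is only defined for bool/integer tensors, so the old expression raises at runtime. A toy version of the splice, with invented sizes chosen so the padded encoder frames line up with the mask's 1-region:

```python
import torch
import torch.nn.functional as F

batch, token_num, dims, l = 1, 6, 4, 3   # toy shapes
inputs_embeds = torch.randn(batch, token_num, dims)
encoder_out = torch.randn(batch, l, dims)
# 1s exactly where the padded encoder frames will sit
audio_mask = torch.tensor([[0., 0., 1., 1., 1., 0.]])

# pad encoder outputs: (token_num - l - 1) zeros in front, one zero at the end
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - l - 1, 1, 0, 0), value=0.0)
# float arithmetic works; ~audio_mask would raise on a float tensor
mixed = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
print(mixed.shape)  # torch.Size([1, 6, 4])
```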
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 60ddc24e0..4d9f5d8f3 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -10,7 +10,7 @@ import numpy as np
 from funasr.register import tables
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
+from torch.cuda.amp import autocast
 
 @tables.register("predictor_classes", "CifPredictor")
 class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@ class CifPredictor(torch.nn.Module):
     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                 target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        memory = self.cif_conv1d(queries)
-        output = memory + context
-        output = self.dropout(output)
-        output = output.transpose(1, 2)
-        output = torch.relu(output)
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
-
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            memory = self.cif_conv1d(queries)
+            output = memory + context
+            output = self.dropout(output)
+            output = output.transpose(1, 2)
+            output = torch.relu(output)
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length
+            elif target_label is not None:
+                target_length = (target_label != ignore_id).float().sum(-1)
+            else:
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+            elif self.tail_threshold > 0.0:
+                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+            if target_length is None and self.tail_threshold > 0.0:
+                token_num_int = torch.max(token_num).type(torch.int32).item()
+                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+
         return acoustic_embeds, token_num, alphas, cif_peak
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
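The predictor call site in `model.py` and the CIF forward passes (the `CifPredictorV2` hunk follows below) are all wrapped in `autocast(False)`, presumably to keep the sigmoid/alpha accumulation in fp32 when the surrounding step runs under mixed precision. A minimal sketch of the fp32-island pattern (`TinyPredictor` is a stand-in, not the real `CifPredictor`):

```python
import torch
from torch.cuda.amp import autocast

class TinyPredictor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 1)

    def forward(self, hidden):
        # fp32 island: even if the caller runs under autocast(True),
        # everything inside this block computes in full precision
        with autocast(False):
            hidden = hidden.float()
            alphas = torch.sigmoid(self.proj(hidden)).squeeze(-1)
            return alphas.cumsum(-1)  # cumulative firing weight, as in CIF

x = torch.randn(2, 16, 8)
print(TinyPredictor()(x).shape)  # torch.Size([2, 16])
```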
@@ -169,41 +171,43 @@ class CifPredictorV2(torch.nn.Module):
     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                 target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        output = torch.relu(self.cif_conv1d(queries))
-        output = output.transpose(1, 2)
-
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length.squeeze(-1)
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            if self.tail_mask:
-                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            output = torch.relu(self.cif_conv1d(queries))
+            output = output.transpose(1, 2)
+
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length.squeeze(-1)
+            elif target_label is not None:
+                target_length = (target_label != ignore_id).float().sum(-1)
             else:
-                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
-
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+            elif self.tail_threshold > 0.0:
+                if self.tail_mask:
+                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+                else:
+                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
+
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+            if target_length is None and self.tail_threshold > 0.0:
+                token_num_int = torch.max(token_num).type(torch.int32).item()
+                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
 
         return acoustic_embeds, token_num, alphas, cif_peak
@@ -371,62 +375,6 @@ class CifPredictorV2(torch.nn.Module):
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
 
-    def gen_tf2torch_map_dict(self):
-
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## predictor
-            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": (2, 1, 0),
-                 },  # (256,256,3),(3,256,256)
-            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.cif_output.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1,256),(1,256,1)
-            "{}.cif_output.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1,),(1,)
-        }
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-        map_dict = self.gen_tf2torch_map_dict()
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                name_tf = map_dict[name]["name"]
-                data_tf = var_dict_tf[name_tf]
-                if map_dict[name]["squeeze"] is not None:
-                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
-                if map_dict[name]["transpose"] is not None:
-                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
-                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                var_dict_torch[name].size(),
-                                                                                                data_tf.size())
-                var_dict_torch_update[name] = data_tf
-                logging.info(
-                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                  var_dict_tf[name_tf].shape))
-
-        return var_dict_torch_update
-
 
 class mae_loss(torch.nn.Module):
diff --git a/setup.py b/setup.py
index f703bb494..4e76c80e1 100644
--- a/setup.py
+++ b/setup.py
@@ -40,11 +40,11 @@ requirements = {
         "umap_learn",
         "jaconv",
         "hydra-core>=1.3.2",
+        "tensorboardX",
     ],
     # train: The modules invoked when training only.
     "train": [
         "editdistance",
-        "tensorboardX",
     ],
     # all: The modules should be optionally installled due to some reason.
     # Please consider moving them to "install" occasionally
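`tensorboardX` moves from the `train` extra into the base install list, so a plain `pip install funasr` now pulls it in. Typical usage, for reference (the log directory is hypothetical):

```python
from tensorboardX import SummaryWriter

writer = SummaryWriter("exp/llm_asr_nar/tensorboard")  # hypothetical logdir
for step in range(3):
    writer.add_scalar("train/loss", 1.0 / (step + 1), step)
writer.close()
```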