mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
ead9addaf7
commit
fa6f60fa76
@ -24,11 +24,11 @@ llm_conf:
|
||||
init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
|
||||
freeze: true
|
||||
|
||||
adaptor: linear
|
||||
adaptor: Linear
|
||||
adaptor_conf:
|
||||
downsample_rate: 1
|
||||
llm_dim: 4096
|
||||
encoder_dim: 2048
|
||||
encoder_dim: 512
|
||||
|
||||
# frontend related
|
||||
frontend: WavFrontend
|
||||
@ -38,54 +38,56 @@ frontend_conf:
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
dither: 0.0
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
lfr_m: 7
|
||||
lfr_n: 6
|
||||
cmvn_file: "/root/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
|
||||
|
||||
specaug: SpecAug
|
||||
specaug: SpecAugLFR
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
apply_time_warp: false
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
lfr_rate: 6
|
||||
num_freq_mask: 1
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
- 12
|
||||
num_time_mask: 1
|
||||
|
||||
train_conf:
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 150
|
||||
keep_nbest_models: 10
|
||||
log_interval: 50
|
||||
log_interval: 10
|
||||
|
||||
optim: adam
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
lr: 0.0001
|
||||
weight_decay: 0.000001
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 35000
|
||||
warmup_steps: 1500
|
||||
|
||||
dataset: AudioLLMDataset
|
||||
dataset_conf:
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: RankFullLocalShuffleBatchSampler
|
||||
batch_type: example # example or length
|
||||
batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
batch_size: 8 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
|
||||
buffer_size: 500
|
||||
shuffle: True
|
||||
num_workers: 4
|
||||
preprocessor_text: TextPreprocessRemovePunctuation
|
||||
|
||||
tokenizer: HuggingfaceTokenizer
|
||||
tokenizer_conf:
|
||||
unk_symbol: <unk>
|
||||
init_param_path: null
|
||||
init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
|
||||
|
||||
|
||||
@ -157,8 +157,10 @@ class AutoModel:
|
||||
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
|
||||
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
kwargs["token_list"] = tokenizer.token_list
|
||||
vocab_size = len(tokenizer.token_list)
|
||||
|
||||
kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
|
||||
kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
|
||||
vocab_size = len(kwargs["token_list"])
|
||||
else:
|
||||
vocab_size = -1
|
||||
|
||||
|
||||
@ -85,7 +85,9 @@ def main(**kwargs):
|
||||
|
||||
# build model
|
||||
model_class = tables.model_classes.get(kwargs["model"])
|
||||
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
|
||||
vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
|
||||
vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
|
||||
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -24,12 +24,12 @@ class AudioLLMDataset(torch.utils.data.Dataset):
|
||||
preprocessor_speech = kwargs.get("preprocessor_speech", None)
|
||||
if preprocessor_speech:
|
||||
preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
|
||||
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
|
||||
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
|
||||
self.preprocessor_speech = preprocessor_speech
|
||||
preprocessor_text = kwargs.get("preprocessor_text", None)
|
||||
if preprocessor_text:
|
||||
preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
|
||||
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
|
||||
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
|
||||
self.preprocessor_text = preprocessor_text
|
||||
|
||||
self.frontend = frontend
|
||||
@ -43,6 +43,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
|
||||
self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
|
||||
self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
|
||||
self.prompt_af = ""
|
||||
self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
|
||||
|
||||
def get_source_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
@ -64,7 +65,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
|
||||
if self.preprocessor_speech:
|
||||
data_src = self.preprocessor_speech(data_src, fs=self.fs)
|
||||
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
|
||||
speech = speech.sequeeze(0)
|
||||
speech = speech.squeeze(0)
|
||||
|
||||
target = item["target"]
|
||||
if self.preprocessor_text:
|
||||
@ -91,10 +92,10 @@ class AudioLLMDataset(torch.utils.data.Dataset):
|
||||
label_mask = labels_ids.ge(0) # [False,False,True,True]
|
||||
labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,input,eos]
|
||||
|
||||
audio_mask = [0] * prompt_pre_length + [1] * audio_length
|
||||
torch.tensor(audio_mask, dtype=torch.float32)
|
||||
audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
|
||||
audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
|
||||
|
||||
ids = self.tokenizer.encode(target)
|
||||
ids = self.tokenizer.encode(target) # token ids is different from labels_ids
|
||||
text = torch.tensor(ids, dtype=torch.int64)
|
||||
text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
|
||||
|
||||
|
||||
@ -11,41 +11,27 @@ import torchaudio
|
||||
from torch import nn
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
from funasr.tokenizer.cleaner import TextCleaner
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Speech preprocessor that randomly changes the playback speed of a waveform.

    One factor is drawn from ``speed_perturb`` on every call. A factor other
    than 1.0 is applied through the sox ``speed`` effect followed by a
    ``rate`` effect set to the input sampling rate ``fs``.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        """Store the candidate speed factors.

        Args:
            speed_perturb: Speed factors to choose from; ``None`` turns the
                module into a pass-through.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Return ``waveform`` after applying one randomly chosen speed factor.

        Args:
            waveform: Audio samples; anything that is not a ``torch.Tensor``
                is converted with ``torch.tensor`` before sox is applied.
            fs: Sampling rate handed to sox and used as the target rate.
            **kwargs: Accepted and ignored.

        Returns:
            The input unchanged when perturbation is disabled or the chosen
            factor is 1.0, otherwise the perturbed audio flattened to 1-D.
        """
        candidates = self.speed_perturb
        if candidates is None:
            return waveform

        factor = random.choice(candidates)
        if not (factor != 1.0):
            # A factor of exactly 1.0 leaves the audio untouched.
            return waveform

        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)

        sox_effects = [['speed', str(factor)], ['rate', str(fs)]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.view(1, -1), fs, sox_effects)
        return perturbed.view(-1)
|
||||
|
||||
|
||||
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
|
||||
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
|
||||
class TextPreprocessSegDict(nn.Module):
|
||||
def __init__(self, seg_dict: str = None,
|
||||
text_cleaner: Collection[str] = None,
|
||||
split_with_space: bool = False,
|
||||
def __init__(self,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.text_cleaner = TextCleaner(text_cleaner)
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
text = self.text_cleaner(text)
|
||||
|
||||
return text
|
||||
# 定义英文标点符号
|
||||
en_punct = string.punctuation
|
||||
# 定义中文标点符号(部分常用的)
|
||||
cn_punct = '。?!,、;:“”‘’()《》【】…—~·'
|
||||
# 合并英文和中文标点符号
|
||||
all_punct = en_punct + cn_punct
|
||||
# 创建正则表达式模式,匹配任何在all_punct中的字符
|
||||
punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
|
||||
# 使用正则表达式的sub方法替换掉这些字符
|
||||
return punct_pattern.sub('', text)
|
||||
|
||||
@ -1,96 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
import concurrent.futures
|
||||
import librosa
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
|
||||
def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs):
    """Merge Kaldi-style ``key content`` list files into one jsonl file.

    Each file in ``path`` is paired positionally with a name in
    ``data_type_list``. Every file is parsed by ``parse_context_length`` in a
    thread pool, and one JSON object per key of the first data type is written
    to ``jsonl_file_out``, combining the entries of all data types.

    Only rank 0 does the work; when more than one process takes part, all
    ranks meet at a barrier afterwards.

    Args:
        path: Sequence of input file paths, one per entry of ``data_type_list``.
        data_type_list: Names used as JSON field prefixes for each input file.
        jsonl_file_out: Destination path of the jsonl file.
        **kwargs: Accepted and ignored.

    Raises:
        KeyError: If a key of the first data type is missing from another one.
    """
    # Query the process group state explicitly instead of swallowing every
    # exception that get_rank() might raise in a non-distributed run.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r", encoding="utf-8") as f:
                data_file_lists = f.readlines()

            # Ceil-divide the lines over the cores and submit *every* chunk.
            # The previous code submitted a single 1-line chunk whenever the
            # file had no more lines than cores, silently dropping the rest.
            lines_for_each_th = max(1, (len(data_file_lists) - 1) // cpu_cores + 1)
            chunks = [
                data_file_lists[start:start + lines_for_each_th]
                for start in range(0, len(data_file_lists), lines_for_each_th)
            ]
            with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                futures = [executor.submit(parse_context_length, chunk, data_type) for chunk in chunks]
                # Collect in submission order so the output keeps the input
                # file's order regardless of which thread finishes first.
                for future in futures:
                    json_dict[data_type].update(future.result())

        with open(jsonl_file_out, "w", encoding="utf-8") as f:
            for key in json_dict[data_type_list[0]]:
                jsonl_line = {"key": key}
                for data_type in data_type_list:
                    jsonl_line.update(json_dict[data_type][key])
                f.write(json.dumps(jsonl_line, ensure_ascii=False) + "\n")

    if world_size > 1:
        dist.barrier()
|
||||
|
||||
|
||||
def parse_context_length(data_list: list, data_type: str):
    """Parse ``key content`` lines and attach a length to every entry.

    If ``content`` is an existing path, it is loaded with librosa at 16 kHz
    and the length is the number of 10 ms frames in the audio. Otherwise
    ``content`` is treated as text: the length is the number of
    whitespace-separated tokens when it contains a space, else the number of
    characters.

    Args:
        data_list: Lines of the form ``"<key> <content>"``.
        data_type: Field name used in the result (e.g. ``"source"``).

    Returns:
        Dict mapping each key to
        ``{data_type: content, f"{data_type}_len": length}``.

    Raises:
        ValueError: If a non-blank line holds a key but no content.
    """
    sample_rate = 16000
    frame_shift_ms = 10

    res = {}
    for line in data_list:
        fields = line.strip().split(maxsplit=1)
        if not fields:
            # Blank lines (e.g. a trailing newline at end of file) carry no entry.
            continue
        if len(fields) < 2:
            raise ValueError(f"line has a key but no content: {line!r}")
        key, content = fields

        if os.path.exists(content):
            waveform, _ = librosa.load(content, sr=sample_rate)
            sample_num = len(waveform)
            # Multiply before dividing: flooring to whole seconds first gave
            # length 0 for any clip shorter than one second.
            context_len = int(sample_num * 1000 // (sample_rate * frame_shift_ms))
        else:
            context_len = len(content.split()) if " " in content else len(content)
        res[key] = {data_type: content, f"{data_type}_len": context_len}
    return res
|
||||
|
||||
|
||||
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Command-line entry point: read hydra overrides and build the jsonl file.

    Recognised overrides are ``scp_file_list``, ``data_type_list`` and
    ``jsonl_file_out``; each falls back to a developer-local default path.
    """
    kwargs = OmegaConf.to_container(cfg, resolve=True)

    default_scp_files = (
        "/Users/zhifu/funasr1.0/test_local/wav.scp",
        "/Users/zhifu/funasr1.0/test_local/text.txt",
    )
    scp_file_list = kwargs.get("scp_file_list", default_scp_files)
    if isinstance(scp_file_list, str):
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line; ast.literal_eval would parse a list literal without that risk.
        scp_file_list = eval(scp_file_list)

    gen_jsonl_from_wav_text_list(
        scp_file_list,
        data_type_list=kwargs.get("data_type_list", ("source", "target")),
        jsonl_file_out=kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"),
    )
|
||||
|
||||
|
||||
"""
|
||||
python -m funasr.datasets.audio_datasets.scp2jsonl \
|
||||
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
|
||||
++data_type_list='["source", "target"]' \
|
||||
++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_hydra()
|
||||
|
||||
|
||||
@ -35,8 +35,6 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
|
||||
"""
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(
|
||||
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
|
||||
)
|
||||
numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
|
||||
@ -73,7 +73,7 @@ class LLMASRNAR(nn.Module):
|
||||
hub = encoder_conf.get("hub", None)
|
||||
if hub == "funasr":
|
||||
from funasr import AutoModel
|
||||
init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
|
||||
init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
|
||||
model = AutoModel(model=init_param_path, model_revision="v2.0.4")
|
||||
# frontend = model.kwargs.get("frontend")
|
||||
model.model.decoder = None
|
||||
@ -179,6 +179,7 @@ class LLMASRNAR(nn.Module):
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids[input_ids == -1] = 0
|
||||
input_ids[input_ids == -100] = 0
|
||||
if hasattr(self.llm.model, "embed_tokens"):
|
||||
inputs_embeds = self.llm.model.embed_tokens(input_ids)
|
||||
elif hasattr(self.llm.model.model, "embed_tokens"):
|
||||
@ -190,7 +191,7 @@ class LLMASRNAR(nn.Module):
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
_, l, _ = encoder_out.shape
|
||||
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
|
||||
inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
|
||||
inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
|
||||
inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
|
||||
|
||||
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
|
||||
@ -198,11 +199,10 @@ class LLMASRNAR(nn.Module):
|
||||
|
||||
|
||||
stats = {}
|
||||
if self.metric:
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
|
||||
stats["acc"] = acc_att
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
|
||||
stats["acc"] = acc_att
|
||||
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
|
||||
@ -221,11 +221,12 @@ class LLMASRNAR(nn.Module):
|
||||
|
||||
batch = {"speech": speech, "speech_lengths": speech_lengths}
|
||||
enc, enc_lens = self.audio_encoder.encode(**batch)
|
||||
enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
|
||||
pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
|
||||
mask=enc_mask,
|
||||
target_label_length=audio_token_lengths,
|
||||
)
|
||||
with autocast(False):
|
||||
enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
|
||||
pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
|
||||
mask=enc_mask,
|
||||
target_label_length=audio_token_lengths,
|
||||
)
|
||||
|
||||
return pre_acoustic_embeds, pre_token_length
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import numpy as np
|
||||
from funasr.register import tables
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
@tables.register("predictor_classes", "CifPredictor")
|
||||
class CifPredictor(torch.nn.Module):
|
||||
@ -28,42 +28,44 @@ class CifPredictor(torch.nn.Module):
|
||||
|
||||
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
|
||||
target_label_length=None):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
memory = self.cif_conv1d(queries)
|
||||
output = memory + context
|
||||
output = self.dropout(output)
|
||||
output = output.transpose(1, 2)
|
||||
output = torch.relu(output)
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
||||
if mask is not None:
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
alphas = alphas * mask
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
alphas = alphas.squeeze(-1)
|
||||
mask = mask.squeeze(-1)
|
||||
if target_label_length is not None:
|
||||
target_length = target_label_length
|
||||
elif target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
else:
|
||||
target_length = None
|
||||
token_num = alphas.sum(-1)
|
||||
if target_length is not None:
|
||||
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
|
||||
elif self.tail_threshold > 0.0:
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
|
||||
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||
|
||||
with autocast(False):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
memory = self.cif_conv1d(queries)
|
||||
output = memory + context
|
||||
output = self.dropout(output)
|
||||
output = output.transpose(1, 2)
|
||||
output = torch.relu(output)
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
||||
if mask is not None:
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
alphas = alphas * mask
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
alphas = alphas.squeeze(-1)
|
||||
mask = mask.squeeze(-1)
|
||||
if target_label_length is not None:
|
||||
target_length = target_label_length
|
||||
elif target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
else:
|
||||
target_length = None
|
||||
token_num = alphas.sum(-1)
|
||||
if target_length is not None:
|
||||
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
|
||||
elif self.tail_threshold > 0.0:
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
|
||||
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
|
||||
@ -169,41 +171,43 @@ class CifPredictorV2(torch.nn.Module):
|
||||
|
||||
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
|
||||
target_label_length=None):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
output = torch.relu(self.cif_conv1d(queries))
|
||||
output = output.transpose(1, 2)
|
||||
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
||||
if mask is not None:
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
alphas = alphas * mask
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
alphas = alphas.squeeze(-1)
|
||||
mask = mask.squeeze(-1)
|
||||
if target_label_length is not None:
|
||||
target_length = target_label_length.squeeze(-1)
|
||||
elif target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
else:
|
||||
target_length = None
|
||||
token_num = alphas.sum(-1)
|
||||
if target_length is not None:
|
||||
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
|
||||
elif self.tail_threshold > 0.0:
|
||||
if self.tail_mask:
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
|
||||
|
||||
with autocast(False):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
output = torch.relu(self.cif_conv1d(queries))
|
||||
output = output.transpose(1, 2)
|
||||
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
||||
if mask is not None:
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
alphas = alphas * mask
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
alphas = alphas.squeeze(-1)
|
||||
mask = mask.squeeze(-1)
|
||||
if target_label_length is not None:
|
||||
target_length = target_label_length.squeeze(-1)
|
||||
elif target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
else:
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
|
||||
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||
target_length = None
|
||||
token_num = alphas.sum(-1)
|
||||
if target_length is not None:
|
||||
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
|
||||
elif self.tail_threshold > 0.0:
|
||||
if self.tail_mask:
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
|
||||
else:
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
|
||||
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
@ -371,62 +375,6 @@ class CifPredictorV2(torch.nn.Module):
|
||||
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
|
||||
return predictor_alignments.detach(), predictor_alignments_length.detach()
|
||||
|
||||
def gen_tf2torch_map_dict(self):
    """Build the torch-parameter -> TensorFlow-variable mapping for the predictor.

    Returns:
        Dict keyed by torch parameter name (prefixed with
        ``self.tf2torch_tensor_name_prefix_torch``). Each value holds the
        TensorFlow variable ``name`` (prefixed with
        ``self.tf2torch_tensor_name_prefix_tf``), the axis to ``squeeze``
        (or ``None``) and the ``transpose`` permutation (or ``None``).
    """
    prefix_torch = self.tf2torch_tensor_name_prefix_torch
    prefix_tf = self.tf2torch_tensor_name_prefix_tf

    # (torch name template, tf name template, squeeze axis, transpose order)
    predictor_layout = (
        ("{}.cif_conv1d.weight", "{}/conv1d/kernel", None, (2, 1, 0)),  # (256,256,3),(3,256,256)
        ("{}.cif_conv1d.bias", "{}/conv1d/bias", None, None),  # (256,),(256,)
        ("{}.cif_output.weight", "{}/conv1d_1/kernel", 0, (1, 0)),  # (1,256),(1,256,1)
        ("{}.cif_output.bias", "{}/conv1d_1/bias", None, None),  # (1,),(1,)
    )

    map_dict_local = {}
    for torch_template, tf_template, squeeze_axis, permutation in predictor_layout:
        map_dict_local[torch_template.format(prefix_torch)] = {
            "name": tf_template.format(prefix_tf),
            "squeeze": squeeze_axis,
            "transpose": permutation,
        }
    return map_dict_local
|
||||
|
||||
def convert_tf2torch(self,
                     var_dict_tf,
                     var_dict_torch,
                     ):
    """Convert TensorFlow variables into tensors for this module's parameters.

    Only torch parameter names whose first dotted component equals
    ``self.tf2torch_tensor_name_prefix_torch`` are handled; every other name
    is ignored. Each matching variable is optionally squeezed and transposed
    as described by ``gen_tf2torch_map_dict`` and cast to float32 on CPU.

    Args:
        var_dict_tf: Mapping from TensorFlow variable name to numpy array.
        var_dict_torch: Mapping from torch parameter name to tensor; used
            for name selection and for the shape check.

    Returns:
        Dict from torch parameter name to the converted tensor.

    Raises:
        AssertionError: If a converted tensor's size differs from the
            corresponding torch parameter's size.
    """
    mapping = self.gen_tf2torch_map_dict()
    converted = {}
    for name in sorted(var_dict_torch):
        if name.split('.')[0] != self.tf2torch_tensor_name_prefix_torch:
            continue

        spec = mapping[name]
        name_tf = spec["name"]
        data_tf = var_dict_tf[name_tf]

        squeeze_axis = spec["squeeze"]
        if squeeze_axis is not None:
            data_tf = np.squeeze(data_tf, axis=squeeze_axis)

        permutation = spec["transpose"]
        if permutation is not None:
            data_tf = np.transpose(data_tf, permutation)

        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
            name, name_tf, var_dict_torch[name].size(), data_tf.size())

        converted[name] = data_tf
        logging.info(
            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

    return converted
|
||||
|
||||
|
||||
class mae_loss(torch.nn.Module):
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@ -40,11 +40,11 @@ requirements = {
|
||||
"umap_learn",
|
||||
"jaconv",
|
||||
"hydra-core>=1.3.2",
|
||||
"tensorboardX",
|
||||
],
|
||||
# train: The modules invoked when training only.
|
||||
"train": [
|
||||
"editdistance",
|
||||
"tensorboardX",
|
||||
],
|
||||
# all: The modules should be optionally installed due to some reason.
|
||||
# Please consider moving them to "install" occasionally
|
||||
|
||||
Loading…
Reference in New Issue
Block a user