This commit is contained in:
游雁 2024-07-04 13:04:45 +08:00
parent 63800cb852
commit 256defef10
3 changed files with 27 additions and 10 deletions

View File

@@ -225,7 +225,7 @@ class AutoModel:
init_param = kwargs.get("init_param", None)
if init_param is not None:
if isinstance(init_param, str):
init_param = [init_param]
init_param = init_param.split(",")
for i, init_param_i in enumerate(init_param):
if os.path.exists(init_param_i):
logging.info(f"Loading pretrained params from ckpt-{i}: {init_param_i}")

View File

@@ -59,10 +59,19 @@ def download_from_ms(**kwargs):
elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pt")
if "init_param" not in kwargs or not os.path.exists(kwargs["init_param"]):
kwargs["init_param"] = init_param
assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
init_param = kwargs.get("init_param", "")
if not os.path.exists(init_param):
init_param_new = init_param
if isinstance(init_param, str):
init_param = init_param.split(",")
for init_param_i in init_param:
if not os.path.exists(init_param_i):
print(f"init_param: {init_param_i}, does not exist")
init_param_i = os.path.join(model_or_path, "model.pt")
init_param_new = f"{init_param_new},{init_param_i}"
kwargs["init_param"] = init_param_new
# assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
if os.path.exists(os.path.join(model_or_path, "tokens.json")):

View File

@@ -2564,8 +2564,16 @@ class LLMASR5(nn.Module):
fbank_beg += [fbank_beg_i + len(input_ids)]
fake_token_len += [fake_token_len_i]
source_mask = [-100] * len(source_ids)
target_out = f"{target_out}<|im_end|>"
target_ids = tokenizer.encode(target_out)
splits = pattern.split(target_out)
for k, sub_str in enumerate(splits):
if len(sub_str) < 1:
continue
if not sub_str.startswith("<|startofspeech|>"):
sub_str = f"{sub_str}<|im_end|>"
sub_token = tokenizer.encode(sub_str)
target_ids = sub_token
# target_out = f"{target_out}<|im_end|>"
# target_ids = tokenizer.encode(target_out)
input_source_ids = input_ids + source_ids
input_ids += source_ids + target_ids
labels += source_mask + target_ids
@@ -2740,9 +2748,9 @@ class LLMASR5(nn.Module):
for i in range(token_num):
hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32)
speech_tokens = self.audio_decode(
hidden_states_out, hidden_states_out_len
) # 1xl: 2,10,1023
speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[
:, :, 0
] # 1xlx1: 2,10,1023
sequences = generated_ids["sequences"]
# generated_ids = [