Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
update
This commit is contained in:
parent: 63800cb852
commit: 256defef10
@@ -225,7 +225,7 @@ class AutoModel:
|
||||
init_param = kwargs.get("init_param", None)
|
||||
if init_param is not None:
|
||||
if isinstance(init_param, str):
|
||||
init_param = [init_param]
|
||||
init_param = init_param.split(",")
|
||||
for i, init_param_i in enumerate(init_param):
|
||||
if os.path.exists(init_param_i):
|
||||
logging.info(f"Loading pretrained params from ckpt-{i}: {init_param_i}")
|
||||
|
||||
@@ -59,10 +59,19 @@ def download_from_ms(**kwargs):
|
||||
elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
|
||||
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
|
||||
kwargs = OmegaConf.merge(config, kwargs)
|
||||
init_param = os.path.join(model_or_path, "model.pt")
|
||||
if "init_param" not in kwargs or not os.path.exists(kwargs["init_param"]):
|
||||
kwargs["init_param"] = init_param
|
||||
assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
|
||||
|
||||
init_param = kwargs.get("init_param", "")
|
||||
if not os.path.exists(init_param):
|
||||
init_param_new = init_param
|
||||
if isinstance(init_param, str):
|
||||
init_param = init_param.split(",")
|
||||
for init_param_i in init_param:
|
||||
if not os.path.exists(init_param_i):
|
||||
print(f"init_param: {init_param_i}, does not exist")
|
||||
init_param_i = os.path.join(model_or_path, "model.pt")
|
||||
init_param_new = f"{init_param_new},{init_param_i}"
|
||||
kwargs["init_param"] = init_param_new
|
||||
# assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
||||
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.json")):
|
||||
|
||||
@@ -2564,8 +2564,16 @@ class LLMASR5(nn.Module):
|
||||
fbank_beg += [fbank_beg_i + len(input_ids)]
|
||||
fake_token_len += [fake_token_len_i]
|
||||
source_mask = [-100] * len(source_ids)
|
||||
target_out = f"{target_out}<|im_end|>"
|
||||
target_ids = tokenizer.encode(target_out)
|
||||
splits = pattern.split(target_out)
|
||||
for k, sub_str in enumerate(splits):
|
||||
if len(sub_str) < 1:
|
||||
continue
|
||||
if not sub_str.startswith("<|startofspeech|>"):
|
||||
sub_str = f"{sub_str}<|im_end|>"
|
||||
sub_token = tokenizer.encode(sub_str)
|
||||
target_ids = sub_token
|
||||
# target_out = f"{target_out}<|im_end|>"
|
||||
# target_ids = tokenizer.encode(target_out)
|
||||
input_source_ids = input_ids + source_ids
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
@@ -2740,9 +2748,9 @@ class LLMASR5(nn.Module):
|
||||
for i in range(token_num):
|
||||
hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32)
|
||||
|
||||
speech_tokens = self.audio_decode(
|
||||
hidden_states_out, hidden_states_out_len
|
||||
) # 1xl: 2,10,1023
|
||||
speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[
|
||||
:, :, 0
|
||||
] # 1xlx1: 2,10,1023
|
||||
|
||||
sequences = generated_ids["sequences"]
|
||||
# generated_ids = [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user