* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx

* export onnx

* export onnx

* dingding

* dingding

* llm

* doc

* onnx

* onnx

* onnx

* onnx

* onnx

* onnx

* v1.0.15

* qwenaudio

* qwenaudio

* issue doc

* update

* update

* bugfix
This commit is contained in:
zhifu gao 2024-03-12 17:27:02 +08:00 committed by GitHub
parent 68f0603b10
commit c3192dffdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 7 additions and 5 deletions

View File

@@ -162,7 +162,8 @@ class AutoModel:
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+tokenizer_conf = kwargs.get("tokenizer_conf", {})
+tokenizer = tokenizer_class(**tokenizer_conf)
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None

View File

@@ -39,8 +39,7 @@ class AudioLLMDataset(torch.utils.data.Dataset):
self.float_pad_value = float_pad_value
self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
-self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
-self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
self.prompt_af = ""
self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
self.int_pad_value = self.IGNORE_INDEX

View File

@@ -401,4 +401,6 @@ class Trainer:
epoch * len(self.dataloader_val) + batch_idx)
for key, var in speed_stats.items():
self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
-epoch * len(self.dataloader_val) + batch_idx)
+epoch * len(self.dataloader_val) + batch_idx)
+self.model.train()

View File

@@ -58,7 +58,7 @@ class CT_Transformer():
model = AutoModel(model=model_dir)
model_dir = model.export(quantize=quantize)
-config_file = os.path.join(model_dir, 'confi.yaml')
+config_file = os.path.join(model_dir, 'config.yaml')
config = read_yaml(config_file)
token_list = os.path.join(model_dir, 'tokens.json')
with open(token_list, 'r', encoding='utf-8') as f: