This commit is contained in:
游雁 2024-09-03 11:23:34 +08:00
parent 8fb3ce8796
commit 29717f4361
4 changed files with 33 additions and 13 deletions

View File

@ -973,9 +973,13 @@ class LLMASR4(nn.Module):
lora_init_param_path = lora_conf.get("init_param_path", None)
if lora_init_param_path is not None:
model = PeftModel.from_pretrained(model, lora_init_param_path)
for name, param in model.named_parameters():
if "lora_A" in name or "lora_B" in name:
param.requires_grad = True
else:
peft_config = LoraConfig(**lora_conf)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
if llm_conf.get("activation_checkpoint", False):

View File

@ -225,6 +225,14 @@ class Trainer:
ckpt_name = f"ds-model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
if self.use_lora:
lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
os.makedirs(lora_outdir, exist_ok=True)
if hasattr(model, "module"):
model.module.llm.save_pretrained(lora_outdir)
else:
model.llm.save_pretrained(lora_outdir)
# torch.save(state, filename)
with torch.no_grad():
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)

View File

@ -20,7 +20,7 @@ parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
@ -42,7 +42,7 @@ parser.add_argument(
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--ssl", type=int, default=0, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
@ -255,8 +255,8 @@ async def message(id):
ibest_writer = None
try:
timestamp = int(time.time())
file_name = f'tts_client_{timestamp}.pcm'
file = open(file_name, 'wb')
file_name = f"tts_client_{timestamp}.pcm"
file = open(file_name, "wb")
while True:
meg = await websocket.recv()
if isinstance(meg, bytes):
@ -316,6 +316,8 @@ async def message(id):
print("Exception:", e)
# traceback.print_exc()
# await websocket.close()
async def ws_client(id, chunk_begin, chunk_size):
if args.audio_in is None:
chunk_begin = 0

View File

@ -49,6 +49,7 @@ class NlsTtsSynthesizer:
def on_data(self, data, *args):
self.count += len(data)
print(f"cout: {self.count}")
self.tts_fifo.append(data)
# with open('tts_server.pcm', 'ab') as file:
# file.write(data)
@ -78,6 +79,7 @@ class NlsTtsSynthesizer:
self.started = True
def send_text(self, text):
print(f"text: {text}")
self.sdk.sendStreamInputTts(text)
async def stop(self):
@ -194,7 +196,7 @@ if "appkey" in os.environ:
from modelscope.hub.snapshot_download import snapshot_download
os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
os.environ["MODELSCOPE_CACHE"] = "/mnt/workspace"
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
@ -221,7 +223,7 @@ model_llm = AutoModel(
bf16=False,
llm_dtype="bf16",
max_length=1024,
# llm_kwargs=llm_kwargs,
llm_kwargs=llm_kwargs,
llm_conf={"init_param_path": llm_dir},
tokenizer_conf={"init_param_path": llm_dir},
audio_encoder=audio_encoder_dir,
@ -321,8 +323,10 @@ async def model_inference(
f"generated new text {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
)
if len(new_text) > 0:
new_text = new_text.replace("<|im_end|>", "")
res += new_text
synthesizer.send_text(new_text)
res += new_text.replace("<|im_end|>", "")
contents_i[-1]["content"] = res
websocket.llm_state["contents_i"] = contents_i
# history[-1][1] = res
@ -333,7 +337,7 @@ async def model_inference(
"mode": mode,
"text": new_text,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_final": False,
}
)
# print(f"online: {message}")
@ -342,6 +346,7 @@ async def model_inference(
while len(fifo_queue) > 0:
await websocket.send(fifo_queue.popleft())
# synthesizer.send_text(res)
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
synthesizer.stop()
await tts_to_client_task
@ -351,7 +356,7 @@ async def model_inference(
"mode": mode,
"text": res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_final": True,
}
)
# print(f"offline: {message}")
@ -452,7 +457,8 @@ async def ws_serve(websocket, path):
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
if speech_end_i != -1 or not websocket.is_speaking:
# if speech_end_i != -1 or not websocket.is_speaking:
if not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)