mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
lora
This commit is contained in:
parent
8fb3ce8796
commit
29717f4361
@ -973,9 +973,13 @@ class LLMASR4(nn.Module):
|
||||
lora_init_param_path = lora_conf.get("init_param_path", None)
|
||||
if lora_init_param_path is not None:
|
||||
model = PeftModel.from_pretrained(model, lora_init_param_path)
|
||||
for name, param in model.named_parameters():
|
||||
if "lora_A" in name or "lora_B" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
peft_config = LoraConfig(**lora_conf)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
if llm_conf.get("activation_checkpoint", False):
|
||||
|
||||
@ -225,6 +225,14 @@ class Trainer:
|
||||
ckpt_name = f"ds-model.pt.ep{epoch}.{step}"
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
|
||||
if self.use_lora:
|
||||
lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
|
||||
os.makedirs(lora_outdir, exist_ok=True)
|
||||
if hasattr(model, "module"):
|
||||
model.module.llm.save_pretrained(lora_outdir)
|
||||
else:
|
||||
model.llm.save_pretrained(lora_outdir)
|
||||
|
||||
# torch.save(state, filename)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)
|
||||
|
||||
@ -20,7 +20,7 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
|
||||
parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
|
||||
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
|
||||
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
|
||||
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
|
||||
@ -42,7 +42,7 @@ parser.add_argument(
|
||||
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
|
||||
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
|
||||
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
|
||||
parser.add_argument("--ssl", type=int, default=0, help="1 for ssl connect, 0 for no ssl")
|
||||
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
|
||||
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
|
||||
|
||||
@ -255,8 +255,8 @@ async def message(id):
|
||||
ibest_writer = None
|
||||
try:
|
||||
timestamp = int(time.time())
|
||||
file_name = f'tts_client_{timestamp}.pcm'
|
||||
file = open(file_name, 'wb')
|
||||
file_name = f"tts_client_{timestamp}.pcm"
|
||||
file = open(file_name, "wb")
|
||||
while True:
|
||||
meg = await websocket.recv()
|
||||
if isinstance(meg, bytes):
|
||||
@ -316,6 +316,8 @@ async def message(id):
|
||||
print("Exception:", e)
|
||||
# traceback.print_exc()
|
||||
# await websocket.close()
|
||||
|
||||
|
||||
async def ws_client(id, chunk_begin, chunk_size):
|
||||
if args.audio_in is None:
|
||||
chunk_begin = 0
|
||||
|
||||
@ -49,6 +49,7 @@ class NlsTtsSynthesizer:
|
||||
|
||||
def on_data(self, data, *args):
|
||||
self.count += len(data)
|
||||
print(f"cout: {self.count}")
|
||||
self.tts_fifo.append(data)
|
||||
# with open('tts_server.pcm', 'ab') as file:
|
||||
# file.write(data)
|
||||
@ -78,6 +79,7 @@ class NlsTtsSynthesizer:
|
||||
self.started = True
|
||||
|
||||
def send_text(self, text):
|
||||
print(f"text: {text}")
|
||||
self.sdk.sendStreamInputTts(text)
|
||||
|
||||
async def stop(self):
|
||||
@ -194,7 +196,7 @@ if "appkey" in os.environ:
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
|
||||
os.environ["MODELSCOPE_CACHE"] = "/mnt/workspace"
|
||||
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
|
||||
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
|
||||
|
||||
@ -221,7 +223,7 @@ model_llm = AutoModel(
|
||||
bf16=False,
|
||||
llm_dtype="bf16",
|
||||
max_length=1024,
|
||||
# llm_kwargs=llm_kwargs,
|
||||
llm_kwargs=llm_kwargs,
|
||||
llm_conf={"init_param_path": llm_dir},
|
||||
tokenizer_conf={"init_param_path": llm_dir},
|
||||
audio_encoder=audio_encoder_dir,
|
||||
@ -321,8 +323,10 @@ async def model_inference(
|
||||
f"generated new text: {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
|
||||
)
|
||||
if len(new_text) > 0:
|
||||
new_text = new_text.replace("<|im_end|>", "")
|
||||
res += new_text
|
||||
synthesizer.send_text(new_text)
|
||||
res += new_text.replace("<|im_end|>", "")
|
||||
|
||||
contents_i[-1]["content"] = res
|
||||
websocket.llm_state["contents_i"] = contents_i
|
||||
# history[-1][1] = res
|
||||
@ -333,7 +337,7 @@ async def model_inference(
|
||||
"mode": mode,
|
||||
"text": new_text,
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
"is_final": False,
|
||||
}
|
||||
)
|
||||
# print(f"online: {message}")
|
||||
@ -342,6 +346,7 @@ async def model_inference(
|
||||
while len(fifo_queue) > 0:
|
||||
await websocket.send(fifo_queue.popleft())
|
||||
|
||||
# synthesizer.send_text(res)
|
||||
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
|
||||
synthesizer.stop()
|
||||
await tts_to_client_task
|
||||
@ -351,7 +356,7 @@ async def model_inference(
|
||||
"mode": mode,
|
||||
"text": res,
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
# print(f"offline: {message}")
|
||||
@ -452,7 +457,8 @@ async def ws_serve(websocket, path):
|
||||
frames_asr = []
|
||||
frames_asr.extend(frames_pre)
|
||||
# asr punc offline
|
||||
if speech_end_i != -1 or not websocket.is_speaking:
|
||||
# if speech_end_i != -1 or not websocket.is_speaking:
|
||||
if not websocket.is_speaking:
|
||||
# print("vad end point")
|
||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
||||
audio_in = b"".join(frames_asr)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user