diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
index 50a6f35a1..e3da494e3 100644
--- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py
+++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
@@ -28,6 +28,7 @@ parser.add_argument(
     help="model from modelscope",
 )
 parser.add_argument("--vad_model_revision", type=str, default="master", help="")
+parser.add_argument("--model_path", type=str, default=None, help="model path (vad/sensevoice/qwen/gummy)")
 parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
 parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
 parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
@@ -51,10 +52,15 @@ args = parser.parse_args()
 
 websocket_users = set()
 
+if args.model_path is None:
+    vad_model_path = args.vad_model
+else:
+    vad_model_path = os.path.join(args.model_path, "vad_model")
+
 print("model loading")
 # vad
 model_vad = AutoModel(
-    model=args.vad_model,
+    model=vad_model_path,
     model_revision=args.vad_model_revision,
     ngpu=args.ngpu,
     ncpu=args.ncpu,
@@ -67,22 +73,25 @@ model_vad = AutoModel(
     # chunk_size=60,
 )
 
-api = HubApi()
-key = "ed70b703-9ec7-44b8-b5ce-5f4527719810"
-api.login(key)
-if "key" in os.environ:
-    key = os.environ["key"]
-api.login(key)
+if args.model_path is None:
+    api = HubApi()
+    key = "ed70b703-9ec7-44b8-b5ce-5f4527719810"
+    api.login(key)
+    if "key" in os.environ:
+        key = os.environ["key"]
+    api.login(key)
 
 # os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
-llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
-audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
-# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
-# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
+if args.model_path is None:
+    llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
+    audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
+else:
+    llm_dir = os.path.join(args.model_path, "llm_model")
+    audio_encoder_dir = os.path.join(args.model_path, "audio_model")
+
 device = "cuda:0"
 
 all_file_paths = [
-    # "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240912_streaming",
     "FunAudioLLM/qwen2_7b_mmt_v15_20240912_streaming",
     "FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
     "FunAudioLLM/qwen2_7b_mmt_v15_20240902",
@@ -101,7 +110,10 @@ DO_ASR_FRAME_INTERVAL = 12
 
 ckpt_dir = all_file_paths[0]
 
-ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
+if args.model_path is None:
+    ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
+else:
+    ckpt_dir = os.path.join(args.model_path, "gummy_model")
 
 model_llm = AutoModel(
     model=ckpt_dir,
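
Usage sketch for the new --model_path flag: the subdirectory names vad_model, audio_model, llm_model, and gummy_model come from the os.path.join calls in the diff above; the root directory /data/models and the flag values below are hypothetical placeholders.

    /data/models/
        vad_model/       # local VAD model, used instead of downloading args.vad_model
        audio_model/     # local SenseVoice audio encoder
        llm_model/       # local Qwen2-7B-Instruct
        gummy_model/     # local streaming checkpoint (replaces the snapshot_download of all_file_paths[0])

    python funasr_wss_server_streaming_llm.py --model_path /data/models --ngpu 1 --device cuda

When --model_path is omitted, the script keeps the previous behavior and downloads the models from ModelScope (including the HubApi login step).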