This commit is contained in:
木守 2024-09-14 20:22:37 +08:00
parent 1b9300a4c3
commit 8df4b1001e

View File

@ -28,6 +28,7 @@ parser.add_argument(
help="model from modelscope",
)
parser.add_argument("--vad_model_revision", type=str, default="master", help="")
parser.add_argument("--model_path", type=str, default=None, help="model path (vad/sensevoice/qwen/gummy)")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
@ -51,10 +52,15 @@ args = parser.parse_args()
websocket_users = set()
if args.model_path is None:
vad_model_path = args.vad_model
else:
vad_model_path = os.path.join(args.model_path, "vad_model")
print("model loading")
# vad
model_vad = AutoModel(
model=args.vad_model,
model=vad_model_path,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
@ -67,22 +73,25 @@ model_vad = AutoModel(
# chunk_size=60,
)
api = HubApi()
key = "ed70b703-9ec7-44b8-b5ce-5f4527719810"
api.login(key)
if "key" in os.environ:
key = os.environ["key"]
api.login(key)
if args.model_path is None:
api = HubApi()
key = "ed70b703-9ec7-44b8-b5ce-5f4527719810"
api.login(key)
if "key" in os.environ:
key = os.environ["key"]
api.login(key)
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
if args.model_path is None:
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
else:
llm_dir = os.path.join(args.model_path, "llm_model")
audio_encoder_dir = os.path.join(args.model_path, "audio_model")
device = "cuda:0"
all_file_paths = [
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240912_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240912_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240902",
@ -101,7 +110,10 @@ DO_ASR_FRAME_INTERVAL = 12
ckpt_dir = all_file_paths[0]
ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
if args.model_path is None:
ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
else:
ckpt_dir = os.path.join(args.model_path, "gummy_model")
model_llm = AutoModel(
model=ckpt_dir,