diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py index b62cc2949..d4dd34ed3 100644 --- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py +++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py @@ -22,10 +22,8 @@ print(res) import soundfile import os -speech, sample_rate = soundfile.read(os.path.expanduser('~')+ - "/.cache/modelscope/hub/damo/"+ - "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/"+ - "example/asr_example.wav") +wav_file = os.path.join(model.model_path, "example/asr_example.wav") +speech, sample_rate = soundfile.read(wav_file) chunk_stride = chunk_size[1] * 960 # 600ms、480ms diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py index e239747ab..7d9c1b964 100644 --- a/funasr/bin/inference.py +++ b/funasr/bin/inference.py @@ -83,7 +83,7 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): return key_list, data_list -@hydra.main(config_name=None) +@hydra.main(config_name=None, version_base=None) def main_hydra(cfg: DictConfig): def to_plain_list(cfg_item): if isinstance(cfg_item, ListConfig): @@ -150,6 +150,7 @@ class AutoModel: self.punc_kwargs = punc_kwargs self.spk_model = spk_model self.spk_kwargs = spk_kwargs + self.model_path = kwargs["model_path"] def build_model(self, **kwargs): diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 1f896b7a3..af3e8afab 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -23,7 +23,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from funasr.download.download_from_hub import download_model from funasr.register import tables -@hydra.main(config_name=None) +@hydra.main(config_name=None, version_base=None) def main_hydra(kwargs: DictConfig): if kwargs.get("debug", False): import pdb; pdb.set_trace() diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py index 73578f25c..946572fce 100644 --- a/funasr/download/download_from_hub.py +++ b/funasr/download/download_from_hub.py @@ -18,6 +18,7 @@ def download_from_ms(**kwargs): model_revision = kwargs.get("model_revision") if not os.path.exists(model_or_path): model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True)) + kwargs["model_path"] = model_or_path config = os.path.join(model_or_path, "config.yaml") if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):