FunASR/funasr/bin/compute_audio_cmvn.py

import os
import json
import numpy as np
import torch
import hydra
import logging
from omegaconf import DictConfig, OmegaConf

from funasr.register import tables
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)

    # set random seed
    # tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    tokenizer = kwargs.get("tokenizer", None)

    # build the frontend if one is configured
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=None,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    # dataloader: force example-level batching with batch_size 1 so each batch is one utterance
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    dataset_conf = kwargs.get("dataset_conf")
    dataset_conf["batch_type"] = "example"
    dataset_conf["batch_size"] = 1
    dataset_conf["num_workers"] = os.cpu_count() or 32
    batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train
    )
    # accumulate per-dimension sums and squared sums over (optionally a subset of) the data
    total_frames = 0
    iter_stop = int(kwargs.get("scale", -1.0) * len(dataloader_train))
    log_step = max(abs(iter_stop) // 100, 1)  # guard against modulo-by-zero on small datasets
    for batch_idx, batch in enumerate(dataloader_train):
        if batch_idx % log_step == 0:
            logging.info(f"processed: {batch_idx}/{iter_stop}")
        if batch_idx >= iter_stop and iter_stop > 0:
            logging.info(f"processed: {iter_stop}/{iter_stop}")
            break

        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
            mean_stats = np.sum(fbank, axis=0)
            var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
        total_frames += fbank.shape[0]
    cmvn_info = {
        "mean_stats": mean_stats.tolist(),
        "var_stats": var_stats.tolist(),
        "total_frames": total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    with open(cmvn_file, "w") as fout:
        fout.write(json.dumps(cmvn_info))

    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
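    # Convention for the am.mvn written below (Kaldi-style nnet text format): the AddShift
    # row stores the negated per-dimension mean and the Rescale row stores the inverse
    # standard deviation, so applying shift then scale yields (x - mean) / std.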
    dims = mean.shape[0]
    # write am.mvn next to cmvn_file (avoids writing to "/am.mvn" when cmvn_file has no directory)
    am_mvn = os.path.join(os.path.dirname(cmvn_file) or ".", "am.mvn")
    with open(am_mvn, "w") as fout:
        fout.write(
            "<Nnet>" + "\n"
            + "<Splice> " + str(dims) + " " + str(dims) + "\n"
            + "[ 0 ]" + "\n"
            + "<AddShift> " + str(dims) + " " + str(dims) + "\n"
        )
        mean_str = str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
        fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
        var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        fout.write("<LearnRateCoef> 0 " + var_str + "\n")
        fout.write("</Nnet>" + "\n")
"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""
if __name__ == "__main__":
    main_hydra()