FunASR/funasr/datasets/audio_datasets/espnet_samplers.py
zhifu gao 32e7836645
update with main (#1786)
* add cmakelist

* add paraformer-torch

* add debug for funasr-onnx-offline

* fix redefinition of jieba StdExtension.hpp

* add loading torch models

* update funasr-onnx-offline

* add SwitchArg for wss-server

* add SwitchArg for funasr-onnx-offline

* update cmakelist

* update funasr-onnx-offline-rtf

* add define condition

* add gpu define for offline-stream

* update com define

* update offline-stream

* update cmakelist

* update func CompileHotwordEmbedding

* add timestamp for paraformer-torch

* add C10_USE_GLOG for paraformer-torch

* update paraformer-torch

* fix func FunASRWfstDecoderInit

* update model.h

* fix func FunASRWfstDecoderInit

* fix tpass_stream

* update paraformer-torch

* add bladedisc for funasr-onnx-offline

* update comdefine

* update funasr-wss-server

* add log for torch

* fix GetValue BLADEDISC

* fix log

* update cmakelist

* update warmup to 10

* update funasrruntime

* add batch_size for wss-server

* add batch for bins

* add batch for offline-stream

* add batch for paraformer

* add batch for offline-stream

* fix func SetBatchSize

* add SetBatchSize for model

* add SetBatchSize for model

* fix func Forward

* fix padding

* update funasrruntime

* add dec reset for batch

* set batch default value

* add argv for CutSplit

* sort frame_queue

* sorted msgs

* fix FunOfflineInfer

* add dynamic batch for fetch

* fix FetchDynamic

* update run_server.sh

* update run_server.sh

* cpp http post server support (#1739)

* add cpp http server

* add some comment

* remove some comments

* del debug infos

* restore run_server.sh

* adapt to new model struct

* 修复了onnxruntime在macos下编译失败的错误 (#1748)

* Add files via upload

增加macos的编译支持

* Add files via upload

增加macos支持

* Add files via upload

target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
添加 if(APPLE) 限制

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>

* Delete docs/images/wechat.png

* Add files via upload

* fixed the issues about seaco-onnx timestamp

* fix bug (#1764)

当语音识别结果包含 `http` 时,标点符号预测会把它会被当成 url

* fix empty asr result (#1765)

解码结果为空的语音片段,text 用空字符串

* docs

* docs

* docs

* docs

* docs

* keep empty speech result (#1772)

* docs

* docs

* update wechat QRcode

* Add python funasr api support for websocket srv (#1777)

* add python funasr_api support

* change little to README.md

* add core tools stream

* modified a little

* fix bug for timeout

* support for buffer decode

* add ffmpeg decode for buffer

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* Dev gzf exp (#1785)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* log step

* wav is not exist

* wav is not exist

* decoding

* decoding

* decoding

* wechat

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* dynamic batch

* start_data_split_i=0

* total_time/accum_grad

* total_time/accum_grad

* total_time/accum_grad

* update avg slice

* update avg slice

* sensevoice sanm

* sensevoice sanm

* sensevoice sanm

---------

Co-authored-by: 北念 <lzr265946@alibaba-inc.com>

* auto frontend

---------

Co-authored-by: 雾聪 <wucong.lyb@alibaba-inc.com>
Co-authored-by: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Co-authored-by: szsteven008 <97944818+szsteven008@users.noreply.github.com>
Co-authored-by: Ephemeroptera <605686962@qq.com>
Co-authored-by: 彭震东 <zhendong.peng@qq.com>
Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
Co-authored-by: 北念 <lzr265946@alibaba-inc.com>
2024-06-06 09:54:35 +08:00

166 lines
6.0 KiB
Python

import torch
import numpy as np
import logging
import math
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from torch.utils.data import BatchSampler, Sampler
import torch.distributed as dist
import random
from funasr.register import tables
@tables.register("batch_sampler_classes", "EspnetStyleBatchSampler")
def EspnetStyleBatchSampler_fn(dataset, **kwargs):
    """Create an ``EspnetStyleBatchSampler`` and wrap it in loader arguments.

    Args:
        dataset: Dataset handed to ``EspnetStyleBatchSampler``.
        **kwargs: Forwarded unchanged to ``EspnetStyleBatchSampler``.
            ``num_workers`` (default 4) and ``pin_memory`` (default True)
            are additionally copied into the returned dict.

    Returns:
        dict with the keys ``batch_sampler``, ``num_workers`` and
        ``pin_memory`` (presumably consumed as DataLoader keyword
        arguments; verify against the caller).
    """
    sampler = EspnetStyleBatchSampler(dataset, **kwargs)
    return {
        "batch_sampler": sampler,
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
import torch
from torch.utils.data import Dataset, DistributedSampler
import math
import random
class EspnetStyleBatchSampler(DistributedSampler):
    """Batch sampler that groups length-sorted samples and shards batches by rank.

    Each call to ``__iter__`` sorts all dataset indices by
    ``dataset.get_source_len``, greedily packs them into batches whose padded
    size (``max sample length in batch * number of samples``) stays within
    ``batch_size``, optionally shuffles the list of batches, pads that list so
    it divides evenly across ranks, and yields the contiguous slice that
    belongs to the current rank.

    The dataset must provide ``__len__`` and ``get_source_len(idx)``; with
    ``batch_type="token"`` it must also provide ``get_target_len(idx)``.

    Note:
        ``DistributedSampler.__init__`` is not invoked; every attribute used
        by this class is assigned directly in ``__init__``.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        batch_type="token",
        rank=None,
        num_replicas=None,
        rank_split=False,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):
        """Configure the sampler.

        Args:
            dataset: Source of indices and per-sample lengths.
            batch_size: Upper bound on ``max_len_in_batch * len(batch)``.
            batch_type: ``"example"`` counts every sample as length 1,
                ``"token"`` uses ``source_len + int(target_len * 1.2)``,
                any other value uses ``source_len``.
            rank: Ignored; the value is always taken from
                ``torch.distributed`` (or 0 when it is not initialized).
            num_replicas: Ignored; the value is always taken from
                ``torch.distributed`` (or 1 when it is not initialized).
            rank_split: Accepted for interface compatibility; not used here.
            shuffle: Shuffle indices and batches (only when ``is_training``).
            drop_last: Drop the trailing batch unless it exactly fills
                ``batch_size``.
            is_training: Disables shuffling when False.
            sort_size: Stored as ``sort_size * num_replicas``; not read by
                the methods of this class.
            start_step: Number of leading batches of this rank's slice to
                skip in ``__iter__``.
            **kwargs: ``max_token_length`` (default 2048),
                ``min_token_length`` (default 0) and ``length_scale_source``
                (default 1.0) are stored as attributes; they are not applied
                by the methods of this class. Other keys are ignored.
        """
        # Query the process group explicitly instead of relying on a bare
        # ``except`` (which would also swallow KeyboardInterrupt/SystemExit).
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        else:
            rank = 0
            num_replicas = 1

        self.rank = rank
        self.num_replicas = num_replicas
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_type = batch_type
        self.is_training = is_training
        self.shuffle = shuffle and is_training
        self.drop_last = drop_last
        self.total_size = len(self.dataset)
        self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
        self.epoch = 0
        self.sort_size = sort_size * num_replicas
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.min_token_length = kwargs.get("min_token_length", 0)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        self.start_step = start_step
        # Placeholder until ``__iter__`` computes the real per-rank count.
        self.batch_num = 1
        if self.start_step > 0:
            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")

    def __iter__(self):
        """Build this epoch's batches and return an iterator over this rank's share.

        Side effects: seeds the global ``random`` module with ``self.epoch``
        when shuffling, and updates ``self.batch_num``.
        """
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            random.seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Sort indices by source length (stable, so the shuffle only decides
        # the order among samples of equal length).
        sorted_indices = sorted(indices, key=self.dataset.get_source_len)

        buffer_batches = []
        batch = []
        max_len_in_batch = 0  # Longest sample length in the current batch

        for idx in sorted_indices:
            # Length contributed by this sample depends on the batch type.
            if self.batch_type == "example":
                sample_length = 1
            elif self.batch_type == "token":
                sample_length = self.dataset.get_source_len(idx) + int(
                    self.dataset.get_target_len(idx) * 1.2
                )
            else:
                sample_length = self.dataset.get_source_len(idx)

            # Padded size of the batch if this sample were added.
            potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)

            if potential_batch_length <= self.batch_size:
                batch.append(idx)
                max_len_in_batch = max(max_len_in_batch, sample_length)
            else:
                # Close the current batch and start a new one. The guard
                # prevents emitting an empty batch when a single sample is
                # already larger than ``batch_size`` and nothing has been
                # collected yet.
                if batch:
                    buffer_batches.append(batch)
                batch = [idx]
                max_len_in_batch = sample_length

        # Keep the trailing batch unless it must be dropped.
        if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
            buffer_batches.append(batch)

        if self.shuffle:
            random.seed(self.epoch)
            random.shuffle(buffer_batches)

        # Pad with randomly re-drawn batches so every rank gets the same count.
        batches_per_rank = int(math.ceil(len(buffer_batches) / self.num_replicas))
        total_batches_needed = batches_per_rank * self.num_replicas
        extra_batches = total_batches_needed - len(buffer_batches)
        buffer_batches += random.choices(buffer_batches, k=extra_batches)

        # Contiguous slice for this rank, skipping ``start_step`` batches.
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
        self.batch_num = len(rank_batches)

        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
        )
        return iter(rank_batches)

    def __len__(self):
        """Return the batch count from the latest ``__iter__`` call (1 before any)."""
        return self.batch_num

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling in ``__iter__``."""
        self.epoch = epoch