* funasr1.0.5

* funasr1.0.5 audio samples input

* batch_type token

* batch_type token
This commit is contained in:
zhifu gao 2024-02-01 17:29:28 +08:00 committed by GitHub
parent 0e294ee52f
commit 2ddfc27d5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 7 deletions

View File

@ -25,6 +25,7 @@ print(res)
# example2
import torchaudio
import os
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
input_tensor, sample_rate = torchaudio.load(wav_file)
input_tensor = input_tensor.mean(0)
@ -33,7 +34,7 @@ res = model.generate(input=[input_tensor], batch_size_s=300, is_final=True)
# example3
import soundfile
import os
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)

View File

@ -154,7 +154,7 @@ def main(**kwargs):
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,

View File

@ -26,6 +26,8 @@ class BatchSampler(torch.utils.data.BatchSampler):
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
def __len__(self):
return (self.total_samples-1) // self.batch_size + 1
@ -53,8 +55,10 @@ class BatchSampler(torch.utils.data.BatchSampler):
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
sample_len_cur = self.dataset.get_source_len(idx_map) + \
self.dataset.get_target_len(idx_map)
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
sample_len_cur = source_len + target_len
datalen_with_index.append([idx, sample_len_cur])
@ -66,7 +70,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
if self.batch_type == 'length':
if self.batch_type != 'example':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)

View File

@ -10,6 +10,8 @@ from torch import nn
from funasr.models.whisper.utils.decoding import detect_language as detect_language_function, decode as decode_function
from funasr.register import tables
@dataclass
class ModelDimensions:
@ -128,6 +130,8 @@ class ResidualAttentionBlock(nn.Module):
return x
@tables.register("encoder_classes", "WhisperEncoder")
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
@ -158,7 +162,7 @@ class AudioEncoder(nn.Module):
x = self.ln_post(x)
return x
@tables.register("decoder_classes", "WhisperDecoder")
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
@ -193,7 +197,7 @@ class TextDecoder(nn.Module):
return logits
@tables.register("model_classes", "Whisper")
class Whisper(nn.Module):
def __init__(self, dims: dict):
super().__init__()