commit cdf117b974 (parent 3a19144712)
Author: aky15 <ankeyu.aky@11.17.44.249>, committed by GitHub
Date: 2023-06-27 09:59:50 +08:00
Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
4 changed files with 132 additions and 18 deletions

View File

@@ -6,7 +6,7 @@ encoder_conf:
     unified_model_training: true
     default_chunk_size: 16
     jitter_range: 4
-    left_chunk_size: 0
+    left_chunk_size: 1
     embed_vgg_like: false
     subsampling_factor: 4
     linear_units: 2048
@@ -51,7 +51,7 @@ use_amp: true
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 200
+max_epoch: 120
 val_scheduler_criterion:
     - valid
     - loss
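Note on the config change above: left_chunk_size: 1 lets each attention chunk see one preceding chunk instead of none, and max_epoch drops from 200 to 120. As a rough illustration of what a left-context chunk mask looks like (a minimal sketch only; FunASR's actual make_chunk_mask signature and details may differ):

import torch

def chunk_attention_mask(seq_len: int, chunk_size: int, left_chunks: int) -> torch.Tensor:
    # True marks key positions each query frame may attend to.
    chunk_id = torch.arange(seq_len) // chunk_size
    q, k = chunk_id[:, None], chunk_id[None, :]
    # current chunk plus up to `left_chunks` chunks of left context
    return (k <= q) & (k >= q - left_chunks)

# left_chunk_size: 0 gives a block-diagonal mask; 1 adds the previous chunk
print(chunk_attention_mask(8, 2, 0).int())
print(chunk_attention_mask(8, 2, 1).int())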

View File

@@ -1,6 +1,6 @@
 import collections.abc
 from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
 import random

 import numpy as np
@@ -13,6 +13,74 @@ import torchaudio

 from funasr.fileio.read_text import read_2column_text

+def soundfile_read(
+    wavs: Union[str, List[str]],
+    dtype=None,
+    always_2d: bool = False,
+    concat_axis: int = 1,
+    start: int = 0,
+    end: int = None,
+    return_subtype: bool = False,
+) -> Tuple[np.array, int]:
+    if isinstance(wavs, str):
+        wavs = [wavs]
+
+    arrays = []
+    subtypes = []
+    prev_rate = None
+    prev_wav = None
+    for wav in wavs:
+        with soundfile.SoundFile(wav) as f:
+            f.seek(start)
+            if end is not None:
+                frames = end - start
+            else:
+                frames = -1
+            if dtype == "float16":
+                array = f.read(
+                    frames,
+                    dtype="float32",
+                    always_2d=always_2d,
+                ).astype(dtype)
+            else:
+                array = f.read(frames, dtype=dtype, always_2d=always_2d)
+            rate = f.samplerate
+            subtype = f.subtype
+            subtypes.append(subtype)
+
+        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+            # array: (Time, Channel)
+            array = array[:, None]
+
+        if prev_wav is not None:
+            if prev_rate != rate:
+                raise RuntimeError(
+                    f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+                    f"{prev_rate} != {rate}"
+                )
+
+            dim1 = arrays[0].shape[1 - concat_axis]
+            dim2 = array.shape[1 - concat_axis]
+            if dim1 != dim2:
+                raise RuntimeError(
+                    "Shapes must match with "
+                    f"{1 - concat_axis} axis, but got {dim1} and {dim2}"
+                )
+
+        prev_rate = rate
+        prev_wav = wav
+        arrays.append(array)
+
+    if len(arrays) == 1:
+        array = arrays[0]
+    else:
+        array = np.concatenate(arrays, axis=concat_axis)
+
+    if return_subtype:
+        return array, rate, subtypes
+    else:
+        return array, rate
+
+
 class SoundScpReader(collections.abc.Mapping):
     """Reader class for 'wav.scp'.

View File

@@ -1081,7 +1081,10 @@ class ConformerChunkEncoder(AbsEncoder):
         mask = make_source_mask(x_len).to(x.device)

         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
-            else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
+            else:
+                chunk_size = self.default_chunk_size

             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
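Both hunks apply the same pattern: draw a random chunk size only when self.training is set (toggled by train()/eval()), and fall back to the fixed default_chunk_size otherwise, so validation and inference see deterministic chunking. A self-contained illustration of that flag with a hypothetical toy module (not FunASR code):

import torch

class Toy(torch.nn.Module):
    def __init__(self, default_chunk_size: int = 16, jitter_range: int = 4):
        super().__init__()
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

    def pick_chunk_size(self) -> int:
        if self.training:  # True after .train(), False after .eval()
            return self.default_chunk_size + torch.randint(
                -self.jitter_range, self.jitter_range + 1, (1,)
            ).item()
        return self.default_chunk_size

m = Toy()
m.train()
print(m.pick_chunk_size())  # varies in [12, 20]
m.eval()
print(m.pick_chunk_size())  # always 16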
@@ -1147,6 +1153,45 @@
         return x, olens, None

+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode input sequences.
+
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+
+        Returns:
+            x_utt: Encoder outputs. (B, T_out, D_enc)
+
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:, ::self.time_reduction_factor, :]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,
View File

@@ -427,6 +427,7 @@ class StreamingConvInput(torch.nn.Module):
         conv_size: Union[int, Tuple],
         subsampling_factor: int = 4,
         vgg_like: bool = True,
+        conv_kernel_size: int = 3,
         output_size: Optional[int] = None,
     ) -> None:
         """Construct a ConvInput object."""
@@ -436,14 +437,14 @@ class StreamingConvInput(torch.nn.Module):
             conv_size1, conv_size2 = conv_size

             self.conv = torch.nn.Sequential(
-                torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((1, 2)),
-                torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((1, 2)),
             )
@@ -462,14 +463,14 @@ class StreamingConvInput(torch.nn.Module):
             kernel_1 = int(subsampling_factor / 2)

             self.conv = torch.nn.Sequential(
-                torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((kernel_1, 2)),
-                torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((2, 2)),
             )
@@ -487,14 +488,14 @@ class StreamingConvInput(torch.nn.Module):

             self.conv = torch.nn.Sequential(
                 torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                 torch.nn.ReLU(),
             )
             output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)

         self.subsampling_factor = subsampling_factor
-        self.kernel_2 = 3
+        self.kernel_2 = conv_kernel_size
         self.stride_2 = 1

         self.create_new_mask = self.create_new_conv2d_mask
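The new padding (conv_kernel_size-1)//2 is standard "same" padding for odd kernels at stride 1, so these convolutions preserve the time and frequency extents exactly as the old hard-coded kernel 3 / padding 1 did. A quick sanity check (illustrative only):

import torch

for k in (3, 5, 7):
    conv = torch.nn.Conv2d(1, 4, k, stride=1, padding=(k - 1) // 2)
    y = conv(torch.randn(1, 1, 100, 80))
    print(k, tuple(y.shape))  # (1, 4, 100, 80): spatial dims preserved for odd k

One caveat visible in the last hunk: the frequency-downsampling branch still passes padding [1,0], which lines up with the configurable kernel only when conv_kernel_size is 3.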