mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

funasr2

This commit is contained in:
parent 00ea1186f9
commit c8bae0ec85
8 examples/industrial_data_pretraining/fsmn-vad/infer.sh Normal file
@@ -0,0 +1,8 @@
cmd="funasr/bin/inference.py"

python $cmd \
+model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
+input="/Users/zhifu/Downloads/asr_example.wav" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_vad" \
+device="cpu" \
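The "+key=value" arguments above read as hydra-style command-line overrides (this commit also adds "hydra-core" to setup.py, see below). As a hedged sketch only, not the actual entry point: funasr/bin/inference.py could collect them into a plain dict like this, where main_hydra is a hypothetical name:

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None)
def main_hydra(cfg: DictConfig):
    # "+model=..." and "+device=cpu" overrides land in cfg as ordinary keys
    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs.get("model"), kwargs.get("device"))

if __name__ == "__main__":
    main_hydra()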
@@ -101,6 +101,7 @@ class AutoModel:
        tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer
+       kwargs["token_list"] = tokenizer.token_list

        # build frontend
        frontend = kwargs.get("frontend", None)
@@ -112,12 +113,10 @@ class AutoModel:

        # build model
        model_class = registry_tables.model_classes.get(kwargs["model"].lower())
-       model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+       model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
        model.eval()
        model.to(device)

-       kwargs["token_list"] = tokenizer.token_list
-
        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
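Both hunks lean on the name-to-class registry this refactor standardizes on. funasr/utils/register.py is not part of this diff, so the following is only a sketch of the assumed contract: @register_class files a class under a lowercased name in a named table, and registry_tables.<table>.get(name.lower()) returns it at build time.

class _RegistryTables:
    def __init__(self):
        self.tables = {}

    def __getattr__(self, table_name):
        # lazily create one dict per table (model_classes, tokenizer_classes, ...)
        return self.tables.setdefault(table_name, {})

registry_tables = _RegistryTables()

def register_class(table_name, class_name):
    def decorator(cls):
        getattr(registry_tables, table_name)[class_name.lower()] = cls
        return cls
    return decorator

@register_class("model_classes", "DemoModel")  # hypothetical class for illustration
class DemoModel:
    pass

assert registry_tables.model_classes.get("demomodel") is DemoModel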
@@ -145,7 +145,8 @@ def main(**kwargs):
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
-   batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+   if batch_sampler is not None:
+       batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
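A note on the new guard: torch.utils.data.DataLoader accepts batch_sampler=None and falls back to its default batching, so a config may now omit the custom sampler entirely. A self-contained sketch of that fallback (toy dataset; names are illustrative):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return torch.tensor([i])

    def collator(self, batch):
        return torch.stack(batch)

dataset_tr = ToyDataset()
batch_sampler = None  # e.g. no batch_sampler class was found in the registry
loader = DataLoader(dataset_tr, collate_fn=dataset_tr.collator, batch_sampler=batch_sampler)
for batch in loader:
    print(batch.shape)  # torch.Size([1, 1]): default batches of one sample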
@@ -153,7 +154,6 @@ def main(**kwargs):
                                                pin_memory=True)

-
    trainer = Trainer(
        model=model,
        optim=optim,
@@ -24,6 +24,17 @@ class AudioDataset(torch.utils.data.Dataset):
        super().__init__()
        index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
        self.index_ds = index_ds_class(path)
+       preprocessor_speech = kwargs.get("preprocessor_speech", None)
+       if preprocessor_speech:
+           preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+           preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+       self.preprocessor_speech = preprocessor_speech
+       preprocessor_text = kwargs.get("preprocessor_text", None)
+       if preprocessor_text:
+           preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
+           preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+       self.preprocessor_text = preprocessor_text
+
        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
@@ -49,8 +60,13 @@ class AudioDataset(torch.utils.data.Dataset):
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
+       if self.preprocessor_speech:
+           data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend)  # speech: [b, T, d]

        target = item["target"]
+       if self.preprocessor_text:
+           target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
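The two hooks above let a config inject waveform and text preprocessors by registered name, applied before fbank extraction and tokenization respectively. A hypothetical example of a class that would plug into the speech hook (class and table names follow the registry pattern assumed earlier; nothing here is from the actual repo):

import torch
from funasr.utils.register import register_class

@register_class("preprocessor_speech_classes", "VolumePerturb")
class VolumePerturb:
    def __init__(self, low: float = 0.5, high: float = 1.5):
        self.low, self.high = low, high

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # scale the raw waveform by a random gain in [low, high)
        gain = torch.empty(1).uniform_(self.low, self.high).item()
        return waveform * gain

With preprocessor_speech: "VolumePerturb" and preprocessor_speech_conf: {low: 0.8, high: 1.2} in the dataset config, __getitem__ would apply it to data_src before extract_fbank.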
212 funasr/models/ct_transformer/model.py Normal file
@@ -0,0 +1,212 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.utils.register import register_class, registry_tables
# NOTE: nll() and forward() below also use make_pad_mask and force_gatherable;
# their import lines are not visible in this diff.


@register_class("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: dict = None,
        vocab_size: int = -1,
        punc_list: list = None,
        punc_weight: list = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        dropout_rate: float = 0.5,
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
        **kwargs,
    ):
        super().__init__()

        punc_size = len(punc_list)
        if punc_weight is None:
            punc_weight = [1] * punc_size

        self.embed = nn.Embedding(vocab_size, embed_unit)
        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
        encoder = encoder_class(**encoder_conf)

        self.decoder = nn.Linear(att_unit, punc_size)
        self.encoder = encoder
        self.punc_list = punc_list
        # keep the class weights as a tensor so nll() can move them to the target device
        self.punc_weight = torch.tensor(punc_weight, dtype=torch.float32)
        self.ignore_id = ignore_id
        self.sos = sos
        self.eos = eos

    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation scores for padded input sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Input lengths. (batch,)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return False

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood (nll).

        Normally, this function is called in batchify_nll.
        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            max_length: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, :text_lengths.max()]
            punc = punc[:, :text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]

        if self.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
        else:
            # Should be TargetDelayTransformer
            y, _ = self.punc_forward(text, text_lengths)

        # Calc negative log likelihood
        # nll: (BxL,)
        if not self.training:
            # at eval time, report micro F1 instead of a likelihood
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            from sklearn.metrics import f1_score
            f1 = f1_score(punc.view(-1).detach().cpu().numpy(),
                          indices.squeeze(-1).detach().cpu().numpy(),
                          average='micro')
            nll = torch.Tensor([f1]).repeat(text_lengths.sum())
            return nll, text_lengths
        else:
            self.punc_weight = self.punc_weight.to(punc.device)
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
                                  ignore_index=self.ignore_id)
            # nll: (BxL,)
            if max_length is None:
                nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
            else:
                nll.masked_fill_(
                    make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                    0.0,
                )
            # nll: (BxL,) -> (B, L)
            nll = nll.view(batch_size, -1)
            return nll, text_lengths

    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def generate(self,
                 text: torch.Tensor,
                 text_lengths: torch.Tensor,
                 vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
        if self.with_vad():
            assert vad_indexes is not None
            return self.punc_forward(text, text_lengths, vad_indexes)
        else:
            return self.punc_forward(text, text_lengths)
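nll() above zeroes the loss at padded positions with make_pad_mask before averaging. The helper's definition is not part of this commit, so here is a standalone sketch of the semantics the call sites rely on (True marks padding; the maxlen argument mirrors the maxlen=max_length + 1 usage):

import torch

def make_pad_mask(lengths: torch.Tensor, maxlen: int = None) -> torch.Tensor:
    maxlen = maxlen if maxlen is not None else int(lengths.max())
    seq = torch.arange(maxlen, device=lengths.device)
    # (B, maxlen); True where the position index exceeds the sequence length
    return seq.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_pad_mask(torch.tensor([2, 4])))
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])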
@@ -1,130 +0,0 @@
from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.sanm.encoder import SANMEncoder as Encoder


class TargetDelayTransformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            # pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    # def _target_mask(self, ys_in_pad):
    #     ys_mask = ys_in_pad != 0
    #     m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
    #     return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation scores for padded input sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Input lengths. (batch,)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return False

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
@@ -6,6 +6,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

+from funasr.utils.register import register_class, registry_tables
+

class LinearTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
@@ -156,7 +158,7 @@ num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''

+@register_class("encoder_classes", "FSMN")
class FSMN(nn.Module):
    def __init__(
        self,
@@ -227,7 +229,7 @@ lstride: left stride
rstride: right stride
'''

+@register_class("encoder_classes", "DFSMN")
class DFSMN(nn.Module):

    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
@@ -1,33 +1,244 @@
from enum import Enum
from typing import List, Tuple, Dict, Any, Optional

import json
import logging
import math
import os
import time

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.base_model import FunASRModel
from funasr.models.model_class_factory import *
from funasr.utils.register import register_class, registry_tables
from funasr.utils.datadir_writer import DatadirWriter
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank


class VadStateMachine(Enum):
    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3


class FrameState(Enum):
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0


# final voice/unvoice state per frame
class AudioChangeState(Enum):
    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5


class VadDetectMode(Enum):
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
        snr_mode: int = 0,
        max_end_silence_time: int = 800,
        max_start_silence_time: int = 3000,
        do_start_point_detection: bool = True,
        do_end_point_detection: bool = True,
        window_size_ms: int = 200,
        sil_to_speech_time_thres: int = 150,
        speech_to_sil_time_thres: int = 150,
        speech_2_noise_ratio: float = 1.0,
        do_extend: int = 1,
        lookback_time_start_point: int = 200,
        lookahead_time_end_point: int = 100,
        max_single_segment_time: int = 60000,
        nn_eval_block_size: int = 8,
        dcd_block_size: int = 4,
        snr_thres: float = -100.0,
        noise_frame_num_used_for_snr: int = 100,
        decibel_thres: float = -100.0,
        speech_noise_thres: float = 0.6,
        fe_prior_thres: float = 1e-4,
        silence_pdf_num: int = 1,
        sil_pdf_ids: List[int] = [0],
        speech_noise_thresh_low: float = -0.1,
        speech_noise_thresh_high: float = 0.3,
        output_frame_probs: bool = False,
        frame_in_ms: int = 10,
        frame_length_ms: int = 25,
        **kwargs,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        self.sil_pdf_ids = sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms


class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0

    def Reset(self):
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0


class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
        self.score = 0.0
        self.frame_id = 0
        self.frm_state = 0


class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms

        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame  # initialize the window

        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)

        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def Reset(self) -> None:
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
        cur_frame_state = FrameState.kFrameStateSil
        if frameState == FrameState.kFrameStateSpeech:
            cur_frame_state = 1
        elif frameState == FrameState.kFrameStateSil:
            cur_frame_state = 0
        else:
            return AudioChangeState.kChangeStateInvalid
        # ring buffer: replace the oldest per-frame decision with the current one
        self.win_sum -= self.win_state[self.cur_win_pos]
        self.win_sum += cur_frame_state
        self.win_state[self.cur_win_pos] = cur_frame_state
        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame

        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSpeech
            return AudioChangeState.kChangeStateSil2Speech

        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSil
            return AudioChangeState.kChangeStateSpeech2Sil

        if self.pre_frame_state == FrameState.kFrameStateSil:
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        return int(self.frame_size_ms)


@register_class("model_classes", "FsmnVAD")
class FsmnVAD(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

-   def __init__(self, encoder: str = None,
+   def __init__(self,
+                encoder: str = None,
                 encoder_conf: Optional[Dict] = None,
-                vad_post_args: Dict[str, Any] = None,
-                frontend=None):
+                **kwargs,
+                ):
        super().__init__()
-       self.vad_opts = VADXOptions(**vad_post_args)
+       self.vad_opts = VADXOptions(**kwargs)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                               self.vad_opts.sil_to_speech_time_thres,
                                               self.vad_opts.speech_to_sil_time_thres,
                                               self.vad_opts.frame_in_ms)

-       encoder_class = encoder_classes.get_class(encoder)
+       encoder_class = registry_tables.encoder_classes.get(encoder.lower())
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
        # init variables
@@ -57,7 +268,6 @@ class FsmnVAD(nn.Module):
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
-       self.frontend = frontend
        self.last_drop_frames = 0

    def AllResetDetection(self):
@@ -239,7 +449,7 @@ class FsmnVAD(nn.Module):
        vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency

-   def GetFrameState(self, t: int) -> FrameState:
+   def GetFrameState(self, t: int):
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = self.decibel[t]
        cur_snr = cur_decibel - self.noise_average_decibel
@@ -285,7 +495,7 @@ class FsmnVAD(nn.Module):

    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
-               ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+               ):
        if not in_cache:
            self.AllResetDetection()
        self.waveform = waveform  # compute decibel for each frame
@@ -313,6 +523,87 @@ class FsmnVAD(nn.Module):
            self.AllResetDetection()
        return segments, in_cache

    def generate(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
                 tokenizer=None,
                 frontend=None,
                 **kwargs,
                 ):

        meta_data = {}
        audio_sample_list = [data_in]
        if isinstance(data_in, torch.Tensor):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        # .to() is not in-place; keep the moved tensors (the original discarded the results)
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # b. Forward Encoder streaming
        feats = speech
        feats_len = speech_lengths.max().item()
        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"])  # data: [batch, N]
        in_cache = kwargs.get("in_cache", {})
        batch_size = kwargs.get("batch_size", 1)
        step = min(feats_len, 6000)
        # independent lists per batch item ([[]] * batch_size would alias one list)
        segments = [[] for _ in range(batch_size)]

        for t_offset in range(0, feats_len, step):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
                is_final = True
            else:
                is_final = False
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                # 10 ms frame shift = 160 samples at 16 kHz; 400 samples = 25 ms frame length
                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
                "is_final": is_final,
                "in_cache": in_cache,
            }

            segments_part, in_cache = self.forward(**batch)
            if segments_part:
                for batch_num in range(0, batch_size):
                    segments[batch_num] += segments_part[batch_num]

        ibest_writer = None
        if ibest_writer is None and kwargs.get("output_dir") is not None:
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]

        results = []
        for i in range(batch_size):
            # build and append the result before it can be indexed below
            result_i = {"key": key[i], "value": segments[i]}
            results.append(result_i)

            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
                results[i] = json.dumps(results[i])

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = segments[i]

        return results, meta_data

    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                       is_final: bool = False, max_end_sil: int = 800
                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
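The sliding-window smoothing above is easiest to see with a toy run: feed per-frame speech/silence decisions into WindowDetector and watch for the state-change event (parameter values are illustrative and match the VADXOptions defaults):

detector = WindowDetector(window_size_ms=200,
                          sil_to_speech_time=150,
                          speech_to_sil_time=150,
                          frame_size_ms=10)
frames = [FrameState.kFrameStateSil] * 5 + [FrameState.kFrameStateSpeech] * 30
for t, state in enumerate(frames):
    change = detector.DetectOneFrame(state, t)
    if change == AudioChangeState.kChangeStateSil2Speech:
        # fires once 15 of the last 20 frames (150 ms / 10 ms) are speech
        print(f"speech starts around frame {t}")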
@@ -483,207 +774,3 @@ class FsmnVAD(nn.Module):

class VadStateMachine(Enum):
    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3


class FrameState(Enum):
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0


# final voice/unvoice state per frame
class AudioChangeState(Enum):
    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5


class VadDetectMode(Enum):
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
        snr_mode: int = 0,
        max_end_silence_time: int = 800,
        max_start_silence_time: int = 3000,
        do_start_point_detection: bool = True,
        do_end_point_detection: bool = True,
        window_size_ms: int = 200,
        sil_to_speech_time_thres: int = 150,
        speech_to_sil_time_thres: int = 150,
        speech_2_noise_ratio: float = 1.0,
        do_extend: int = 1,
        lookback_time_start_point: int = 200,
        lookahead_time_end_point: int = 100,
        max_single_segment_time: int = 60000,
        nn_eval_block_size: int = 8,
        dcd_block_size: int = 4,
        snr_thres: float = -100.0,
        noise_frame_num_used_for_snr: int = 100,
        decibel_thres: float = -100.0,
        speech_noise_thres: float = 0.6,
        fe_prior_thres: float = 1e-4,
        silence_pdf_num: int = 1,
        sil_pdf_ids: List[int] = [0],
        speech_noise_thresh_low: float = -0.1,
        speech_noise_thresh_high: float = 0.3,
        output_frame_probs: bool = False,
        frame_in_ms: int = 10,
        frame_length_ms: int = 25,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        self.sil_pdf_ids = sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms


class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0

    def Reset(self):
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0


class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
        self.score = 0.0
        self.frame_id = 0
        self.frm_state = 0


class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms

        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame  # initialize the window

        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)

        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def Reset(self) -> None:
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
        cur_frame_state = FrameState.kFrameStateSil
        if frameState == FrameState.kFrameStateSpeech:
            cur_frame_state = 1
        elif frameState == FrameState.kFrameStateSil:
            cur_frame_state = 0
        else:
            return AudioChangeState.kChangeStateInvalid
        self.win_sum -= self.win_state[self.cur_win_pos]
        self.win_sum += cur_frame_state
        self.win_state[self.cur_win_pos] = cur_frame_state
        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame

        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSpeech
            return AudioChangeState.kChangeStateSil2Speech

        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSil
            return AudioChangeState.kChangeStateSpeech2Sil

        if self.pre_frame_state == FrameState.kFrameStateSil:
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        return int(self.frame_size_ms)
@@ -42,8 +42,9 @@ class BaseTokenizer(ABC):
            self.token_list_repr = str(token_list)
            self.token_list: List[str] = []

-           with open('data.json', 'r', encoding='utf-8') as f:
-               self.token_list = json.loads(f.read())
+           with open(token_list, 'r', encoding='utf-8') as f:
+               self.token_list = json.load(f)

        else:
            self.token_list: List[str] = list(token_list)
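The fix above reads the vocabulary from the configured token_list path instead of a hardcoded data.json, and json.load(f) reads the stream directly rather than round-tripping through f.read(). For example (hypothetical file name and contents):

import json

with open("tokens.json", "w", encoding="utf-8") as f:
    json.dump(["<blank>", "<s>", "</s>", "hello"], f)

with open("tokens.json", "r", encoding="utf-8") as f:
    token_list = json.load(f)
print(token_list[:2])  # ['<blank>', '<s>']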
@@ -120,6 +120,7 @@ def load_pretrained_model(
    if ignore_init_mismatch:
        src_state = filter_state_dict(dst_state, src_state)

    # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
+   logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)
5 setup.py
@@ -10,14 +10,11 @@ from setuptools import setup

requirements = {
    "install": [
        # "setuptools>=38.5.1",
        "humanfriendly",
        "scipy>=1.4.1",
        "librosa",
        "jamo",  # For kss
        "PyYAML>=5.1.2",
        # "soundfile>=0.12.1",
        # "h5py>=3.1.0",
        "kaldiio>=2.17.0",
        "torch_complex",
        # "nltk>=3.4.5",
@@ -32,7 +29,6 @@ requirements = {
        # ENH
        "pytorch_wpe",
        "editdistance>=0.5.2",
        "tensorboard",
        # "g2p",
        # "nara_wpe",
        # PAI
@@ -44,6 +40,7 @@ requirements = {
        "hdbscan",
        "umap",
        "jaconv",
+       "hydra-core",
    ],
    # train: The modules invoked when training only.
    "train": [