mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Commit 668b830cb2 (parent d342c642fa): update cam++ for embed extract
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from funasr import AutoModel

model = AutoModel(model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common")

res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
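For reference, a minimal sketch of how two embeddings extracted this way could be compared. The model id and wav paths below are placeholders, and indexing the result with the "spk_embedding" key assumes the call returns the generate() results list shown further down in this commit.

import torch.nn.functional as F

from funasr import AutoModel

model = AutoModel(model="speech_campplus_sv_zh-cn_16k-common")  # placeholder model path/id

# one embedding per utterance (wav paths are placeholders)
emb_a = model(input="speaker_a.wav")[0]["spk_embedding"]
emb_b = model(input="speaker_b.wav")[0]["spk_embedding"]

# cosine similarity as a simple same-speaker score
score = F.cosine_similarity(emb_a.view(1, -1), emb_b.view(1, -1)).item()
print(f"cosine similarity: {score:.3f}")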
@@ -159,6 +159,9 @@ class AutoModel:
            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
            kwargs["tokenizer"] = tokenizer
            kwargs["token_list"] = tokenizer.token_list
            vocab_size = len(tokenizer.token_list)
        else:
            vocab_size = -1

        # build frontend
        frontend = kwargs.get("frontend", None)
@@ -170,8 +173,7 @@ class AutoModel:

        # build model
        model_class = tables.model_classes.get(kwargs["model"].lower())
        model = model_class(**kwargs, **kwargs["model_conf"],
                            vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
        model.eval()
        model.to(device)

@@ -1 +0,0 @@
from .campplus import CAMPPlus
@@ -7,6 +7,82 @@ import torch.utils.checkpoint as cp
from torch import nn


class BasicResBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
                               kernel_size=3,
                               stride=(stride, 1),
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=(stride, 1),
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class FCM(nn.Module):
    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out


def get_nonlinear(config_str, channels):
    nonlinear = nn.Sequential()
    for name in config_str.split('-'):
@@ -216,39 +292,3 @@ class DenseLayer(nn.Module):
        return x


class BasicResBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
                               kernel_size=3,
                               stride=(stride, 1),
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=(stride, 1),
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
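A quick shape check for the FCM head added above, assuming the frequency-first (batch, feat_dim, frames) layout that the out_channels = m_channels * (feat_dim // 8) formula implies:

import torch

from funasr.models.campplus.components import FCM

head = FCM(feat_dim=80)
feats = torch.randn(2, 80, 200)       # (batch, feat_dim, frames), layout assumed
out = head(feats)
# frequency axis is downsampled 8x and folded into channels; time is kept
print(out.shape, head.out_channels)   # torch.Size([2, 320, 200]) 320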
@@ -1,54 +1,24 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import os
import time
import torch
import logging
import numpy as np
import torch.nn as nn
from collections import OrderedDict
from typing import Union, Dict, List, Tuple, Optional

import torch.nn.functional as F
from torch import nn


from funasr.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
    BasicResBlock, get_nonlinear


class FCM(nn.Module):
    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
    BasicResBlock, get_nonlinear, FCM
from funasr.models.campplus.utils import extract_feature


@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(nn.Module):
    def __init__(self,
                 feat_dim=80,
@@ -58,8 +28,9 @@ class CAMPPlus(nn.Module):
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment'):
        super(CAMPPlus, self).__init__()
                 output_level='segment',
                 **kwargs,):
        super().__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
@@ -123,3 +94,28 @@ class CAMPPlus(nn.Module):
        if self.output_level == 'frame':
            x = x.transpose(1, 2)
        return x

    def generate(self,
                 data_in,
                 data_lengths=None,
                 key: list=None,
                 tokenizer=None,
                 frontend=None,
                 **kwargs,
                 ):
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_feature(audio_sample_list)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
        # import pdb; pdb.set_trace()
        results = []
        embeddings = self.forward(speech)
        for embedding in embeddings:
            results.append({"spk_embedding": embedding})
        return results, meta_data
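A minimal sketch of exercising the new generate() entry point directly. Here "model" is assumed to be a weight-loaded CAMPPlus instance, the wav path is a placeholder, and the input is passed as a single-element list on the assumption that load_audio_text_image_video then returns a list for the per-utterance loop above.

# sketch only; "model" is an assumed weight-loaded CAMPPlus instance
results, meta_data = model.generate(["speaker_a.wav"], fs=16000)
print(results[0]["spk_embedding"].shape)   # one embedding per input utterance
print(meta_data)                           # load/extract timings and total audio duration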
funasr/models/campplus/template.yaml (new file, 23 lines)
@@ -0,0 +1,23 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.

# to print the register_table:
# from funasr.register import tables
# tables.print()

# network architecture
model: CAMPPlus
model_conf:
    feat_dim: 80
    embedding_size: 192
    growth_rate: 32
    bn_size: 4
    init_channels: 128
    config_str: 'batchnorm-relu'
    memory_efficient: True
    output_level: 'segment'

# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
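A small sketch of how this template might be consumed when building the model, mirroring the "build model" step in the AutoModel hunk above. The yaml path is a placeholder, and lower-casing the registry key follows the .lower() call in AutoModel; whether the registry itself stores lower-cased names is an assumption here.

import yaml

from funasr.register import tables

with open("funasr/models/campplus/template.yaml") as f:   # path assumed
    cfg = yaml.safe_load(f)

model_class = tables.model_classes.get(cfg["model"].lower())   # "CAMPPlus" -> "campplus"
# speaker models carry no tokenizer, so vocab_size=-1 is absorbed by **kwargs
model = model_class(**cfg["model_conf"], vocab_size=-1)
model.eval()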
funasr/models/campplus/utils.py (new file, 533 lines)
@@ -0,0 +1,533 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import io
from typing import Union

import librosa as sf
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from torch import nn

import contextlib
import os
import tempfile
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Generator, Union

import requests


def check_audio_list(audio: list):
    audio_dur = 0
    for i in range(len(audio)):
        seg = audio[i]
        assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
        assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
        assert int(seg[1] * 16000) - int(
            seg[0] * 16000
        ) == seg[2].shape[
            0], 'modelscope error: audio data in list is inconsistent with time length.'
        if i > 0:
            assert seg[0] >= audio[
                i - 1][1], 'modelscope error: Wrong time stamps.'
        audio_dur += seg[1] - seg[0]
    return audio_dur
    # assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'


def sv_preprocess(inputs: Union[np.ndarray, list]):
    output = []
    for i in range(len(inputs)):
        if isinstance(inputs[i], str):
            file_bytes = File.read(inputs[i])
            data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
            if len(data.shape) == 2:
                data = data[:, 0]
            data = torch.from_numpy(data).unsqueeze(0)
            data = data.squeeze(0)
        elif isinstance(inputs[i], np.ndarray):
            assert len(
                inputs[i].shape
            ) == 1, 'modelscope error: Input array should be [N, T]'
            data = inputs[i]
            if data.dtype in ['int16', 'int32', 'int64']:
                data = (data / (1 << 15)).astype('float32')
            else:
                data = data.astype('float32')
            data = torch.from_numpy(data)
        else:
            raise ValueError(
                'modelscope error: The input type is restricted to audio address and nump array.'
            )
        output.append(data)
    return output


def sv_chunk(vad_segments: list, fs = 16000) -> list:
    config = {
        'seg_dur': 1.5,
        'seg_shift': 0.75,
    }

    def seg_chunk(seg_data):
        seg_st = seg_data[0]
        data = seg_data[2]
        chunk_len = int(config['seg_dur'] * fs)
        chunk_shift = int(config['seg_shift'] * fs)
        last_chunk_ed = 0
        seg_res = []
        for chunk_st in range(0, data.shape[0], chunk_shift):
            chunk_ed = min(chunk_st + chunk_len, data.shape[0])
            if chunk_ed <= last_chunk_ed:
                break
            last_chunk_ed = chunk_ed
            chunk_st = max(0, chunk_ed - chunk_len)
            chunk_data = data[chunk_st:chunk_ed]
            if chunk_data.shape[0] < chunk_len:
                chunk_data = np.pad(chunk_data,
                                    (0, chunk_len - chunk_data.shape[0]),
                                    'constant')
            seg_res.append([
                chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
                chunk_data
            ])
        return seg_res

    segs = []
    for i, s in enumerate(vad_segments):
        segs.extend(seg_chunk(s))

    return segs
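A toy run of sv_chunk to make the windowing concrete: each VAD segment is cut into 1.5 s chunks with a 0.75 s shift, and the final chunk is right-aligned to the segment end. The segment values below are hypothetical.

import numpy as np

seg = [10.0, 13.2, np.zeros(int(3.2 * 16000), dtype='float32')]   # hypothetical VAD segment at 16 kHz
chunks = sv_chunk([seg])
print([(round(st, 2), round(ed, 2)) for st, ed, _ in chunks])
# [(10.0, 11.5), (10.75, 12.25), (11.5, 13.0), (11.7, 13.2)]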

def extract_feature(audio):
    features = []
    feature_lengths = []
    for au in audio:
        feature = Kaldi.fbank(
            au.unsqueeze(0), num_mel_bins=80)
        feature = feature - feature.mean(dim=0, keepdim=True)
        features.append(feature.unsqueeze(0))
        feature_lengths.append(au.shape[0])
    features = torch.cat(features)
    return features, feature_lengths
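A small sanity check for extract_feature, assuming equal-length 16 kHz inputs (torch.cat requires matching frame counts across utterances):

import torch

wavs = [torch.randn(16000), torch.randn(16000)]   # two 1 s waveforms of identical length
feats, feat_lens = extract_feature(wavs)
# mean-normalized 80-dim fbank per utterance; lengths are returned in samples, not frames
print(feats.shape, feat_lens)                      # roughly (2, 98, 80) [16000, 16000]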

def postprocess(segments: list, vad_segments: list,
                labels: np.ndarray, embeddings: np.ndarray) -> list:
    assert len(segments) == len(labels)
    labels = correct_labels(labels)
    distribute_res = []
    for i in range(len(segments)):
        distribute_res.append([segments[i][0], segments[i][1], labels[i]])
    # merge the same speakers chronologically
    distribute_res = merge_seque(distribute_res)

    # accquire speaker center
    spk_embs = []
    for i in range(labels.max() + 1):
        spk_emb = embeddings[labels == i].mean(0)
        spk_embs.append(spk_emb)
    spk_embs = np.stack(spk_embs)

    def is_overlapped(t1, t2):
        if t1 > t2 + 1e-4:
            return True
        return False

    # distribute the overlap region
    for i in range(1, len(distribute_res)):
        if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
            p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
            distribute_res[i][0] = p
            distribute_res[i - 1][1] = p

    # smooth the result
    distribute_res = smooth(distribute_res)

    return distribute_res


def correct_labels(labels):
    labels_id = 0
    id2id = {}
    new_labels = []
    for i in labels:
        if i not in id2id:
            id2id[i] = labels_id
            labels_id += 1
        new_labels.append(id2id[i])
    return np.array(new_labels)


def merge_seque(distribute_res):
    res = [distribute_res[0]]
    for i in range(1, len(distribute_res)):
        if distribute_res[i][2] != res[-1][2] or distribute_res[i][
                0] > res[-1][1]:
            res.append(distribute_res[i])
        else:
            res[-1][1] = distribute_res[i][1]
    return res


def smooth(res, mindur=1):
    # short segments are assigned to nearest speakers.
    for i in range(len(res)):
        res[i][0] = round(res[i][0], 2)
        res[i][1] = round(res[i][1], 2)
        if res[i][1] - res[i][0] < mindur:
            if i == 0:
                res[i][2] = res[i + 1][2]
            elif i == len(res) - 1:
                res[i][2] = res[i - 1][2]
            elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
                res[i][2] = res[i - 1][2]
            else:
                res[i][2] = res[i + 1][2]
    # merge the speakers
    res = merge_seque(res)

    return res


def distribute_spk(sentence_list, sd_time_list):
    sd_sentence_list = []
    for d in sentence_list:
        sentence_start = d['ts_list'][0][0]
        sentence_end = d['ts_list'][-1][1]
        sentence_spk = 0
        max_overlap = 0
        for sd_time in sd_time_list:
            spk_st, spk_ed, spk = sd_time
            spk_st = spk_st * 1000
            spk_ed = spk_ed * 1000
            overlap = max(
                min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
            if overlap > max_overlap:
                max_overlap = overlap
                sentence_spk = spk
        d['spk'] = sentence_spk
        sd_sentence_list.append(d)
    return sd_sentence_list

class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

    All backends need to implement two apis: ``read()`` and ``read_text()``.
    ``read()`` reads the file as a byte stream and ``read_text()`` reads
    the file as texts.
    """

    @abstractmethod
    def read(self, filepath: str):
        pass

    @abstractmethod
    def read_text(self, filepath: str):
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        pass

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        pass


class LocalStorage(Storage):
    """Local hard disk storage"""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as f:
            content = f.read()
        return content

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            value_buf = f.read()
        return value_buf

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        dirname = os.path.dirname(filepath)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)

        with open(filepath, 'wb') as f:
            f.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        dirname = os.path.dirname(filepath)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)

        with open(filepath, 'w', encoding=encoding) as f:
            f.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Only for unified API and do nothing."""
        yield filepath


class HTTPStorage(Storage):
    """HTTP and HTTPS storage."""

    def read(self, url):
        # TODO @wenmeng.zwm add progress bar if file is too large
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url):
        r = requests.get(url)
        r.raise_for_status()
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath``.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with ``with`` statement, and when exists from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After existing from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.get_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.read(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')


class OSSStorage(Storage):
    """OSS storage."""

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath``.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with ``with`` statement, and when exists from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> storage = OSSStorage()
            >>> # After existing from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.get_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.read(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')


G_STORAGES = {}


class File(object):
    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        assert isinstance(uri,
                          str), f'uri should be str type, but got {type(uri)}'

        if '://' not in uri:
            # local path
            storage_type = 'local'
        else:
            prefix, _ = uri.split('://')
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixs: '\
            f'{list(File._prefix_to_storage.keys())}'

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        storage = File._get_storage(uri)
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        return storage.write_text(obj, uri)

    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Only for unified API and do nothing."""
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
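The File facade above dispatches on the URI prefix, so a short sketch of the intended use (the URLs and paths are placeholders):

wav_bytes = File.read("https://example.com/test.wav")   # "https" prefix -> HTTPStorage
cfg_text = File.read_text("/path/to/config.yaml")       # no "://"      -> LocalStorage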
@@ -447,7 +447,6 @@ class Paraformer(nn.Module):
                 frontend=None,
                 **kwargs,
                 ):

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
@@ -475,7 +474,6 @@ class Paraformer(nn.Module):
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):