diff --git a/examples/industrial_data_pretraining/spk_verification/demo.py b/examples/industrial_data_pretraining/spk_verification/demo.py new file mode 100644 index 000000000..0b5588f01 --- /dev/null +++ b/examples/industrial_data_pretraining/spk_verification/demo.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel(model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common") + +res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav") +print(res) \ No newline at end of file diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py index c4ff69bda..2d94e70a8 100644 --- a/funasr/bin/inference.py +++ b/funasr/bin/inference.py @@ -159,6 +159,9 @@ class AutoModel: tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) kwargs["tokenizer"] = tokenizer kwargs["token_list"] = tokenizer.token_list + vocab_size = len(tokenizer.token_list) + else: + vocab_size = -1 # build frontend frontend = kwargs.get("frontend", None) @@ -170,8 +173,7 @@ class AutoModel: # build model model_class = tables.model_classes.get(kwargs["model"].lower()) - model = model_class(**kwargs, **kwargs["model_conf"], - vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1) + model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) model.eval() model.to(device) diff --git a/funasr/models/campplus/__init__.py b/funasr/models/campplus/__init__.py index ff44fed85..e69de29bb 100644 --- a/funasr/models/campplus/__init__.py +++ b/funasr/models/campplus/__init__.py @@ -1 +0,0 @@ -from .campplus import CAMPPlus diff --git a/funasr/models/campplus/layers.py b/funasr/models/campplus/components.py similarity index 86% rename from funasr/models/campplus/layers.py rename to funasr/models/campplus/components.py index 0475612a9..43d366eba 100644 --- a/funasr/models/campplus/layers.py +++ b/funasr/models/campplus/components.py @@ -7,6 +7,82 @@ import torch.utils.checkpoint as cp from torch import nn +class BasicResBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicResBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=(stride, 1), + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=(stride, 1), + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class FCM(nn.Module): + def __init__(self, + block=BasicResBlock, + num_blocks=[2, 2], + m_channels=32, + feat_dim=80): + super(FCM, self).__init__() + self.in_planes = m_channels + self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + + self.conv2 = 
nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(m_channels) + self.out_channels = m_channels * (feat_dim // 8) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.unsqueeze(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = F.relu(self.bn2(self.conv2(out))) + + shape = out.shape + out = out.reshape(shape[0], shape[1] * shape[2], shape[3]) + return out + + def get_nonlinear(config_str, channels): nonlinear = nn.Sequential() for name in config_str.split('-'): @@ -216,39 +292,3 @@ class DenseLayer(nn.Module): return x -class BasicResBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicResBlock, self).__init__() - self.conv1 = nn.Conv2d(in_planes, - planes, - kernel_size=3, - stride=(stride, 1), - padding=1, - bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, - planes, - kernel_size=3, - stride=1, - padding=1, - bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, - self.expansion * planes, - kernel_size=1, - stride=(stride, 1), - bias=False), - nn.BatchNorm2d(self.expansion * planes)) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out diff --git a/funasr/models/campplus/campplus.py b/funasr/models/campplus/model.py similarity index 64% rename from funasr/models/campplus/campplus.py rename to funasr/models/campplus/model.py index 88113ece0..84938cc35 100644 --- a/funasr/models/campplus/campplus.py +++ b/funasr/models/campplus/model.py @@ -1,54 +1,24 @@ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. 
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
+import os
+import time
+import torch
+import logging
+import numpy as np
+import torch.nn as nn
 from collections import OrderedDict
+from typing import Union, Dict, List, Tuple, Optional
 
-import torch.nn.functional as F
-from torch import nn
-
-
-from funasr.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
-    BasicResBlock, get_nonlinear
-
-
-class FCM(nn.Module):
-    def __init__(self,
-                 block=BasicResBlock,
-                 num_blocks=[2, 2],
-                 m_channels=32,
-                 feat_dim=80):
-        super(FCM, self).__init__()
-        self.in_planes = m_channels
-        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(m_channels)
-
-        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-        self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-
-        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(m_channels)
-        self.out_channels = m_channels * (feat_dim // 8)
-
-    def _make_layer(self, block, planes, num_blocks, stride):
-        strides = [stride] + [1] * (num_blocks - 1)
-        layers = []
-        for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = x.unsqueeze(1)
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = F.relu(self.bn2(self.conv2(out)))
-
-        shape = out.shape
-        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
-        return out
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
+    BasicResBlock, get_nonlinear, FCM
+from funasr.models.campplus.utils import extract_feature
 
 
+@tables.register("model_classes", "CAMPPlus")
 class CAMPPlus(nn.Module):
     def __init__(self,
                  feat_dim=80,
@@ -58,8 +28,9 @@ class CAMPPlus(nn.Module):
                  init_channels=128,
                  config_str='batchnorm-relu',
                  memory_efficient=True,
-                 output_level='segment'):
-        super(CAMPPlus, self).__init__()
+                 output_level='segment',
+                 **kwargs,):
+        super().__init__()
 
         self.head = FCM(feat_dim=feat_dim)
         channels = self.head.out_channels
@@ -123,3 +94,27 @@ class CAMPPlus(nn.Module):
         if self.output_level == 'frame':
             x = x.transpose(1, 2)
         return x
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        # extract fbank feats
+        meta_data = {}
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_feature(audio_sample_list)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
+        results = []
+        embeddings = self.forward(speech)
+        for embedding in embeddings:
+            results.append({"spk_embedding": embedding})
+        return results, meta_data
\ No newline at end of file
diff --git a/funasr/models/campplus/template.yaml b/funasr/models/campplus/template.yaml
new file mode 100644
index 000000000..38dcfde3e
--- /dev/null
+++ b/funasr/models/campplus/template.yaml
@@ -0,0 +1,23 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: CAMPPlus
+model_conf:
+    feat_dim: 80
+    embedding_size: 192
+    growth_rate: 32
+    bn_size: 4
+    init_channels: 128
+    config_str: 'batchnorm-relu'
+    memory_efficient: True
+    output_level: 'segment'
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py
new file mode 100644
index 000000000..c86a9f055
--- /dev/null
+++ b/funasr/models/campplus/utils.py
@@ -0,0 +1,533 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import io
+from typing import Union
+
+import librosa
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchaudio.compliance.kaldi as Kaldi
+from torch import nn
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+
+
+def check_audio_list(audio: list):
+    audio_dur = 0
+    for i in range(len(audio)):
+        seg = audio[i]
+        assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
+        assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
+        assert int(seg[1] * 16000) - int(
+            seg[0] * 16000
+        ) == seg[2].shape[
+            0], 'modelscope error: audio data in list is inconsistent with time length.'
+        if i > 0:
+            assert seg[0] >= audio[
+                i - 1][1], 'modelscope error: Wrong time stamps.'
+        audio_dur += seg[1] - seg[0]
+    return audio_dur
+    # assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
+
+
+def sv_preprocess(inputs: Union[np.ndarray, list]):
+    output = []
+    for i in range(len(inputs)):
+        if isinstance(inputs[i], str):
+            file_bytes = File.read(inputs[i])
+            data, fs = librosa.load(io.BytesIO(file_bytes), sr=16000, dtype='float32')  # force 16 kHz, as the downstream chunking and fbank code assumes
+            if len(data.shape) == 2:
+                data = data[:, 0]
+            data = torch.from_numpy(data).unsqueeze(0)
+            data = data.squeeze(0)
+        elif isinstance(inputs[i], np.ndarray):
+            assert len(
+                inputs[i].shape
+            ) == 1, 'modelscope error: Input array should be [T]'
+            data = inputs[i]
+            if data.dtype in ['int16', 'int32', 'int64']:
+                data = (data / (1 << 15)).astype('float32')
+            else:
+                data = data.astype('float32')
+            data = torch.from_numpy(data)
+        else:
+            raise ValueError(
+                'modelscope error: The input type is restricted to audio file paths and numpy arrays.'
+            )
+        output.append(data)
+    return output
+
+
+def sv_chunk(vad_segments: list, fs=16000) -> list:
+    config = {
+        'seg_dur': 1.5,
+        'seg_shift': 0.75,
+    }
+
+    def seg_chunk(seg_data):
+        seg_st = seg_data[0]
+        data = seg_data[2]
+        chunk_len = int(config['seg_dur'] * fs)
+        chunk_shift = int(config['seg_shift'] * fs)
+        last_chunk_ed = 0
+        seg_res = []
+        for chunk_st in range(0, data.shape[0], chunk_shift):
+            chunk_ed = min(chunk_st + chunk_len, data.shape[0])
+            if chunk_ed <= last_chunk_ed:
+                break
+            last_chunk_ed = chunk_ed
+            chunk_st = max(0, chunk_ed - chunk_len)
+            chunk_data = data[chunk_st:chunk_ed]
+            if chunk_data.shape[0] < chunk_len:
+                chunk_data = np.pad(chunk_data,
+                                    (0, chunk_len - chunk_data.shape[0]),
+                                    'constant')
+            seg_res.append([
+                chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
+                chunk_data
+            ])
+        return seg_res
+
+    segs = []
+    for i, s in enumerate(vad_segments):
+        segs.extend(seg_chunk(s))
+
+    return segs
+
+
+def extract_feature(audio):
+    features = []
+    feature_lengths = []
+    for au in audio:
+        feature = Kaldi.fbank(
+            au.unsqueeze(0), num_mel_bins=80)
+        feature = feature - feature.mean(dim=0, keepdim=True)
+        features.append(feature.unsqueeze(0))
+        feature_lengths.append(au.shape[0])
+    features = torch.cat(features)
+    return features, feature_lengths
+
+
+def postprocess(segments: list, vad_segments: list,
+                labels: np.ndarray, embeddings: np.ndarray) -> list:
+    assert len(segments) == len(labels)
+    labels = correct_labels(labels)
+    distribute_res = []
+    for i in range(len(segments)):
+        distribute_res.append([segments[i][0], segments[i][1], labels[i]])
+    # merge the same speakers chronologically
+    distribute_res = merge_seque(distribute_res)
+
+    # acquire speaker centers
+    spk_embs = []
+    for i in range(labels.max() + 1):
+        spk_emb = embeddings[labels == i].mean(0)
+        spk_embs.append(spk_emb)
+    spk_embs = np.stack(spk_embs)
+
+    def is_overlapped(t1, t2):
+        if t1 > t2 + 1e-4:
+            return True
+        return False
+
+    # distribute the overlap region
+    for i in range(1, len(distribute_res)):
+        if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
+            p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
+            distribute_res[i][0] = p
+            distribute_res[i - 1][1] = p
+
+    # smooth the result
+    distribute_res = smooth(distribute_res)
+
+    return distribute_res
+
+
+def correct_labels(labels):
+    labels_id = 0
+    id2id = {}
+    new_labels = []
+    for i in labels:
+        if i not in id2id:
+            id2id[i] = labels_id
+            labels_id += 1
+        new_labels.append(id2id[i])
+    return np.array(new_labels)
+
+def merge_seque(distribute_res):
+    res = [distribute_res[0]]
+    for i in range(1, len(distribute_res)):
+        if distribute_res[i][2] != res[-1][2] or distribute_res[i][
+                0] > res[-1][1]:
+            res.append(distribute_res[i])
+        else:
+            res[-1][1] = distribute_res[i][1]
+    return res
+
+def smooth(res, mindur=1):
+    # short segments are assigned to the nearest speakers.
+    for i in range(len(res)):
+        res[i][0] = round(res[i][0], 2)
+        res[i][1] = round(res[i][1], 2)
+        if res[i][1] - res[i][0] < mindur:
+            if i == 0:
+                res[i][2] = res[i + 1][2]
+            elif i == len(res) - 1:
+                res[i][2] = res[i - 1][2]
+            elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
+                res[i][2] = res[i - 1][2]
+            else:
+                res[i][2] = res[i + 1][2]
+    # merge the speakers
+    res = merge_seque(res)
+
+    return res
+
+
+def distribute_spk(sentence_list, sd_time_list):
+    sd_sentence_list = []
+    for d in sentence_list:
+        sentence_start = d['ts_list'][0][0]
+        sentence_end = d['ts_list'][-1][1]
+        sentence_spk = 0
+        max_overlap = 0
+        for sd_time in sd_time_list:
+            spk_st, spk_ed, spk = sd_time
+            spk_st = spk_st * 1000
+            spk_ed = spk_ed * 1000
+            overlap = max(
+                min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
+            if overlap > max_overlap:
+                max_overlap = overlap
+                sentence_spk = spk
+        d['spk'] = sentence_spk
+        sd_sentence_list.append(d)
+    return sd_sentence_list
+
+
+
+
+class Storage(metaclass=ABCMeta):
+    """Abstract class of storage.
+
+    All backends need to implement four APIs: ``read()``, ``read_text()``,
+    ``write()`` and ``write_text()``. ``read()`` reads the file as a byte
+    stream and ``read_text()`` reads the file as text.
+    """
+
+    @abstractmethod
+    def read(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def read_text(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        pass
+
+    @abstractmethod
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        pass
+
+
+class LocalStorage(Storage):
+    """Local hard disk storage"""
+
+    def read(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, 'rb') as f:
+            content = f.read()
+        return content
+
+    def read_text(self,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text read from ``filepath``.
+        """
+        with open(filepath, 'r', encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+ """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname, exist_ok=True) + + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + @contextlib.contextmanager + def as_local_path( + self, + filepath: Union[str, + Path]) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + yield filepath + + +class HTTPStorage(Storage): + """HTTP and HTTPS storage.""" + + def read(self, url): + # TODO @wenmeng.zwm add progress bar if file is too large + r = requests.get(url) + r.raise_for_status() + return r.content + + def read_text(self, url): + r = requests.get(url) + r.raise_for_status() + return r.text + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = HTTPStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, url: Union[str, Path]) -> None: + raise NotImplementedError('write is not supported by HTTP Storage') + + def write_text(self, + obj: str, + url: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'write_text is not supported by HTTP Storage') + + +class OSSStorage(Storage): + """OSS storage.""" + + def __init__(self, oss_config_file=None): + # read from config file or env var + raise NotImplementedError( + 'OSSStorage.__init__ to be implemented in the future') + + def read(self, filepath): + raise NotImplementedError( + 'OSSStorage.read to be implemented in the future') + + def read_text(self, filepath, encoding='utf-8'): + raise NotImplementedError( + 'OSSStorage.read_text to be implemented in the future') + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = OSSStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... 
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        raise NotImplementedError(
+            'OSSStorage.write to be implemented in the future')
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'OSSStorage.write_text to be implemented in the future')
+
+
+G_STORAGES = {}
+
+
+class File(object):
+    _prefix_to_storage: dict = {
+        'oss': OSSStorage,
+        'http': HTTPStorage,
+        'https': HTTPStorage,
+        'local': LocalStorage,
+    }
+
+    @staticmethod
+    def _get_storage(uri):
+        assert isinstance(uri,
+                          str), f'uri should be str type, but got {type(uri)}'
+
+        if '://' not in uri:
+            # local path
+            storage_type = 'local'
+        else:
+            prefix, _ = uri.split('://')
+            storage_type = prefix
+
+        assert storage_type in File._prefix_to_storage, \
+            f'Unsupported uri {uri}, valid prefixes: '\
+            f'{list(File._prefix_to_storage.keys())}'
+
+        if storage_type not in G_STORAGES:
+            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+        return G_STORAGES[storage_type]
+
+    @staticmethod
+    def read(uri: str) -> bytes:
+        """Read data from a given ``uri`` with 'rb' mode.
+
+        Args:
+            uri (str): Path to read data from.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        storage = File._get_storage(uri)
+        return storage.read(uri)
+
+    @staticmethod
+    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
+        """Read data from a given ``uri`` with 'r' mode.
+
+        Args:
+            uri (str or Path): Path to read data from.
+            encoding (str): The encoding format used to open the ``uri``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text read from ``uri``.
+        """
+        storage = File._get_storage(uri)
+        return storage.read_text(uri)
+
+    @staticmethod
+    def write(obj: bytes, uri: Union[str, Path]) -> None:
+        """Write data to a given ``uri`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``uri``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            uri (str or Path): Path to write data to.
+        """
+        storage = File._get_storage(uri)
+        return storage.write(obj, uri)
+
+    @staticmethod
+    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
+        """Write data to a given ``uri`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``uri`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            uri (str): Path to write data to.
+            encoding (str): The encoding format used to open the ``uri``.
+                Default: 'utf-8'.
+ """ + storage = File._get_storage(uri) + return storage.write_text(obj, uri) + + @contextlib.contextmanager + def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + storage = File._get_storage(uri) + with storage.as_local_path(uri) as local_path: + yield local_path diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index 9ee4dfc50..78a72ec66 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -447,7 +447,6 @@ class Paraformer(nn.Module): frontend=None, **kwargs, ): - # init beamsearch is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None @@ -475,7 +474,6 @@ class Paraformer(nn.Module): meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) - # Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if isinstance(encoder_out, tuple):