diff --git a/funasr/models/sense_voice/quantizer/__init__.py b/funasr/models/sense_voice/quantizer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models/sense_voice/quantizer/costume_quantizer.py b/funasr/models/sense_voice/quantizer/costume_quantizer.py deleted file mode 100644 index a8aade32f..000000000 --- a/funasr/models/sense_voice/quantizer/costume_quantizer.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch -from funasr.models.sense_voice.quantizer.quantization.vq import ResidualVectorQuantizer -import typing as tp - - -class CostumeQuantizer(torch.nn.Module): - def __init__( - self, - input_size: int = 512, - codebook_size: int = 1024, - num_quantizers: int = 8, - ema_decay: float = 0.95, - kmeans_init: tp.Union[bool, str] = False, - sampling_rate: int = 16_000, - quantize_dropout: bool = False, - rand_num_quant: tp.Optional[tp.List] = None, - encoder_hop_length: int = 320, - use_ddp: bool = True, - q0_ds_ratio: int = 1, - codec_dim: int = None, - codec_range: float = None, - threshold_ema_dead_code=2, - **kwargs, - ): - super().__init__() - if codec_dim is None: - codec_dim = input_size - - self.input_proj, self.output_proj = None, None - if codec_dim != input_size: - self.input_proj = torch.nn.Linear(input_size, codec_dim) - self.output_proj = torch.nn.Linear(codec_dim, input_size) - - self.input_act, self.codec_range = None, None - if codec_range is not None: - self.input_act = torch.nn.Tanh() - self.codec_range = codec_range - - self.rq = ResidualVectorQuantizer( - dimension=codec_dim, - n_q=num_quantizers, - bins=codebook_size, - decay=ema_decay, - kmeans_init=kmeans_init, - quantize_dropout=quantize_dropout, - rand_num_quant=rand_num_quant, - encoder_hop_length=encoder_hop_length, - use_ddp=use_ddp, - q0_ds_ratio=q0_ds_ratio, - threshold_ema_dead_code=threshold_ema_dead_code, - **kwargs, - ) - self.code_dim = codec_dim - self.sampling_rate = sampling_rate - self.bandwidth: tp.Optional[float] = None - self.encoder_hop_length = encoder_hop_length - self.codebook_size = codebook_size - - def forward( - self, - x, - bandwidth: int = None, - ): - # x: input tensor in the shape of (B, T, C) - # rq requires inputs in (B, C, T) - - if self.input_proj is not None: - x = self.input_proj(x) - if self.input_act is not None: - x = self.input_act(x) * self.codec_range - - qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth) - x, indices, commit_loss, sub_quants = qv.quantized, qv.codes, qv.penalty, qv.sub_quants - - x = x.permute(0, 2, 1) - if self.output_proj is not None: - x = self.output_proj(x) - - return x, indices, commit_loss, sub_quants - - def inference( - self, - x, - bandwidth: int = None, - ): - # x: input tensor in the shape of (B, T, C) - # rq requires inputs in (B, C, T) - if self.input_proj is not None: - x = self.input_proj(x) - if self.input_act is not None: - x = self.input_act(x) * self.codec_range - - qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth) - x, indices, sub_quants = qv.quantized, qv.codes, qv.sub_quants - - x = x.permute(0, 2, 1) - if self.output_proj is not None: - x = self.output_proj(x) - - return x, indices, sub_quants - - def encode( - self, - x, - bandwidth: int = None, - ): - # x: input tensor in the shape of (B, T, C) - # rq requires inputs in (B, C, T) - if self.input_proj is not None: - x = self.input_proj(x) - if self.input_act is not None: - x = self.input_act(x) * self.codec_range - - indices = self.rq.encode(x.permute(0, 2, 1), self.sampling_rate, bandwidth) - # return value in n_q x B x T - return indices - - def decode(self, indices): - quantized_out = self.rq.decode(indices) - # quantized_out in B x D x T - if self.output_proj is not None: - quantized_out = self.output_proj(quantized_out.transpose(1, 2)).transpose(1, 2) - return quantized_out - - def output_size(self): - return self.code_dim diff --git a/funasr/models/sense_voice/quantizer/finite_scalar_quantizer.py b/funasr/models/sense_voice/quantizer/finite_scalar_quantizer.py deleted file mode 100644 index 9a30deb46..000000000 --- a/funasr/models/sense_voice/quantizer/finite_scalar_quantizer.py +++ /dev/null @@ -1,299 +0,0 @@ -""" -Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 -Code adapted from Jax version in Appendix A.1 -""" -import random -from typing import List, Optional - -import torch -import torch.nn as nn -from torch.nn import Module -from torch import Tensor, int32 - -from einops import rearrange, pack, unpack - - -# helper functions - -def exists(v): - return v is not None - - -def default(*args): - for arg in args: - if exists(arg): - return arg - return None - - -def pack_one(t, pattern): - return pack([t], pattern) - - -def unpack_one(t, ps, pattern): - return unpack(t, ps, pattern)[0] - - -# tensor helpers - -def round_ste(z: Tensor) -> Tensor: - """Round with straight through gradients.""" - zhat = z.round() - return z + (zhat - z).detach() - - -# main class - -class FSQ(Module): - def __init__( - self, - levels: List[int], - input_size: Optional[int] = None, - num_codebooks=1, - keep_num_codebooks_dim: Optional[bool] = None, - scale: Optional[float] = None - ): - super().__init__() - _levels = torch.tensor(levels, dtype=int32) - self.register_buffer("_levels", _levels, persistent=False) - - _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) - self.register_buffer("_basis", _basis, persistent=False) - - self.scale = scale - - codebook_dim = len(levels) - self.codebook_dim = codebook_dim - - effective_codebook_dim = codebook_dim * num_codebooks - self.num_codebooks = num_codebooks - self.effective_codebook_dim = effective_codebook_dim - - keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) - assert not (num_codebooks > 1 and not keep_num_codebooks_dim) - self.keep_num_codebooks_dim = keep_num_codebooks_dim - - self.dim = default(input_size, len(_levels) * num_codebooks) - - has_projections = self.dim != effective_codebook_dim - self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() - self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() - self.has_projections = has_projections - - self.codebook_size = self._levels.prod().item() - - implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) - self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) - - def output_size(self): - return self.dim - - def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor: - """Bound `z`, an array of shape (..., d).""" - half_l = (self._levels - 1) * (1 - eps) / 2 - offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) - shift = (offset / half_l).atanh() - return (z + shift).tanh() * half_l - offset - - def quantize(self, z: Tensor) -> Tensor: - """Quantizes z, returns quantized zhat, same shape as z.""" - quantized = round_ste(self.bound(z)) - half_width = self._levels // 2 # Renormalize to [-1, 1]. - return quantized / half_width - - def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor: - half_width = self._levels // 2 - return (zhat_normalized * half_width) + half_width - - def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor: - half_width = self._levels // 2 - return (zhat - half_width) / half_width - - def codes_to_indices(self, zhat: Tensor) -> Tensor: - """Converts a `code` to an index in the codebook.""" - assert zhat.shape[-1] == self.codebook_dim - zhat = self._scale_and_shift(zhat) - return (zhat * self._basis).sum(dim=-1).to(int32) - - def indices_to_codes( - self, - indices: Tensor, - project_out=True - ) -> Tensor: - """Inverse of `codes_to_indices`.""" - - is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) - - indices = rearrange(indices, '... -> ... 1') - codes_non_centered = (indices // self._basis) % self._levels - codes = self._scale_and_shift_inverse(codes_non_centered) - - if self.keep_num_codebooks_dim: - codes = rearrange(codes, '... c d -> ... (c d)') - - if project_out: - codes = self.project_out(codes) - - if is_img_or_video: - codes = rearrange(codes, 'b ... d -> b d ...') - - return codes - - def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - """ - einstein notation - b - batch - n - sequence (or flattened spatial dimensions) - d - feature dimension, which is also log2(codebook size) - c - number of codebook dim - """ - - is_img_or_video = z.ndim >= 4 - - # standardize image or video into (batch, seq, dimension) - - if is_img_or_video: - z = rearrange(z, 'b d ... -> b ... d') - z, ps = pack_one(z, 'b * d') - - assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' - - z = self.project_in(z) - - z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks) - - codes = self.quantize(z) - indices = self.codes_to_indices(codes) - - codes = rearrange(codes, 'b n c d -> b n (c d)') - - out = self.project_out(codes) - - # reconstitute image or video dimensions - - if is_img_or_video: - out = unpack_one(out, ps, 'b * d') - out = rearrange(out, 'b ... d -> b d ...') - - indices = unpack_one(indices, ps, 'b * c') - - if not self.keep_num_codebooks_dim: - indices = rearrange(indices, '... 1 -> ...') - - commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device) - return out, indices, commit_loss, None - - def inference( - self, - x, - bandwidth: int = None, - ): - x, indices, _, _ = self.forward(x, bandwidth=bandwidth) - - return x, indices, None - - -class BinaryFSQ(FSQ): - def __init__( - self, - levels: List[int], - input_size: Optional[int] = None, - num_codebooks=1, - keep_num_codebooks_dim: Optional[bool] = None, - scale: Optional[float] = None, - rand_num_codebooks: Optional[List] = None, - ): - _levels = torch.tensor(levels, dtype=int32) - assert torch.all(_levels == 2), "BinaryFSQ requires the levels must be 2" - super().__init__( - levels, input_size, num_codebooks, - keep_num_codebooks_dim, scale - ) - self.rand_num_codebooks = rand_num_codebooks - - def output_size(self): - return self.dim - - def bound(self, z: Tensor, eps=1e-3) -> Tensor: - """Bound `z`, an array of shape (..., d).""" - return torch.sigmoid(z) - - def quantize(self, z: Tensor) -> Tensor: - """Quantizes z, returns quantized zhat, same shape as z.""" - quantized = round_ste(self.bound(z)) - return quantized - - def codes_to_indices(self, zhat: Tensor) -> Tensor: - """Converts a `code` to an index in the codebook.""" - assert zhat.shape[-1] == self.codebook_dim - return (zhat * self._basis).sum(dim=-1).to(int32) - - def indices_to_codes( - self, - indices: Tensor, - project_out=True - ) -> Tensor: - """Inverse of `codes_to_indices`.""" - - is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) - - indices = rearrange(indices, '... -> ... 1') - codes = (indices // self._basis) % self._levels - - if self.keep_num_codebooks_dim: - codes = rearrange(codes, '... c d -> ... (c d)') - - if project_out: - codes = self.project_out(codes) - - if is_img_or_video: - codes = rearrange(codes, 'b ... d -> b d ...') - - return codes - - def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - """ - einstein notation - b - batch - n - sequence (or flattened spatial dimensions) - d - feature dimension, which is also log2(codebook size) - c - number of codebook dim - """ - - is_img_or_video = z.ndim >= 4 - - # standardize image or video into (batch, seq, dimension) - - if is_img_or_video: - z = rearrange(z, 'b d ... -> b ... d') - z, ps = pack_one(z, 'b * d') - - assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' - - z = self.project_in(z) - - z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks) - - codes = self.quantize(z) - if self.rand_num_codebooks is not None: - quant_idx = random.choice(self.rand_num_codebooks) - codes[:, :, quant_idx:, :] = 0 - indices = self.codes_to_indices(codes) - - codes = rearrange(codes, 'b n c d -> b n (c d)') - - out = self.project_out(codes) - - # reconstitute image or video dimensions - - if is_img_or_video: - out = unpack_one(out, ps, 'b * d') - out = rearrange(out, 'b ... d -> b d ...') - - indices = unpack_one(indices, ps, 'b * c') - - if not self.keep_num_codebooks_dim: - indices = rearrange(indices, '... 1 -> ...') - - commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device) - return out, indices, commit_loss, None diff --git a/funasr/models/sense_voice/quantizer/lookup_free_quantizer.py b/funasr/models/sense_voice/quantizer/lookup_free_quantizer.py deleted file mode 100644 index 793ae99ae..000000000 --- a/funasr/models/sense_voice/quantizer/lookup_free_quantizer.py +++ /dev/null @@ -1,493 +0,0 @@ -""" -Lookup Free Quantization -Proposed in https://arxiv.org/abs/2310.05737 - -In the simplest setup, each dimension is quantized into {-1, 1}. -An entropy penalty is used to encourage utilization. -""" -import random -from math import log2, ceil -from collections import namedtuple -from typing import Optional, List - -import numpy as np -import torch -from torch import nn, einsum -import torch.nn.functional as F -from torch.nn import Module - -from einops import rearrange, reduce, pack, unpack - -# constants - -Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss']) - -LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment']) - - -# helper functions - -def exists(v): - return v is not None - - -def default(*args): - for arg in args: - if exists(arg): - return arg() if callable(arg) else arg - return None - - -def pack_one(t, pattern): - return pack([t], pattern) - - -def unpack_one(t, ps, pattern): - return unpack(t, ps, pattern)[0] - - -# entropy - -def log(t, eps=1e-5): - return t.clamp(min=eps).log() - - -def entropy(prob): - return (-prob * log(prob)).sum(dim=-1) - - -# class - -class LFQ(Module): - def __init__( - self, - *, - input_size=None, - dim=None, - codebook_size=None, - entropy_loss_weight=0.1, - commitment_loss_weight=0.25, - diversity_gamma=1., - straight_through_activation="identity", - num_codebooks=1, - keep_num_codebooks_dim=None, - codebook_scale=1., # for residual LFQ, codebook scaled down by 2x at each layer - rand_num_codebooks: Optional[List] = None, - sampling_rate=16000, - encoder_hop_length=640, - ): - super().__init__() - - # some assert validations - dim = input_size - assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ' - assert not exists(codebook_size) or log2( - codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})' - - codebook_size = default(codebook_size, lambda: 2 ** dim) - codebook_dim = int(log2(codebook_size)) - - codebook_dims = codebook_dim * num_codebooks - dim = default(dim, codebook_dims) - - has_projections = dim != codebook_dims - self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity() - self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity() - self.has_projections = has_projections - - self.dim = dim - self.codebook_dim = codebook_dim - self.num_codebooks = num_codebooks - - keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) - assert not (num_codebooks > 1 and not keep_num_codebooks_dim) - self.keep_num_codebooks_dim = keep_num_codebooks_dim - - # straight through activation - - if straight_through_activation == "identity": - self.activation = nn.Identity() - elif straight_through_activation == "tanh": - self.activation = nn.Tanh() - else: - raise NotImplementedError("Unsupported activation type, only 'tanh' and 'identity' are supported") - - # entropy aux loss related weights - - self.diversity_gamma = diversity_gamma - self.entropy_loss_weight = entropy_loss_weight - - # codebook scale - - self.codebook_scale = codebook_scale - - # commitment loss - - self.commitment_loss_weight = commitment_loss_weight - - # for no auxiliary loss, during inference - - self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1)) - self.register_buffer('zero', torch.tensor(0.), persistent=False) - - # codes - - all_codes = torch.arange(codebook_size) - bits = ((all_codes[..., None].int() & self.mask) != 0).float() - codebook = self.bits_to_codes(bits) - - self.register_buffer('codebook', codebook, persistent=False) - self.rand_num_codebooks = rand_num_codebooks - self.sampling_rate = sampling_rate - self.encoder_hop_length = encoder_hop_length - - def output_size(self): - return self.dim - - def bits_to_codes(self, bits): - return bits * self.codebook_scale * 2 - self.codebook_scale - - @property - def dtype(self): - return self.codebook.dtype - - def indices_to_codes( - self, - indices, - project_out=True - ): - is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) - - if not self.keep_num_codebooks_dim: - indices = rearrange(indices, '... -> ... 1') - - # indices to codes, which are bits of either -1 or 1 - - bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) - - codes = self.bits_to_codes(bits) - - codes = rearrange(codes, '... c d -> ... (c d)') - - # whether to project codes out to original dimensions - # if the input feature dimensions were not log2(codebook size) - - if project_out: - codes = self.project_out(codes) - - # rearrange codes back to original shape - - if is_img_or_video: - codes = rearrange(codes, 'b ... d -> b d ...') - - return codes - - def keep_first_nq_codes(self, x, nq=None): - if nq is None or nq >= self.num_codebooks: - return x - inv_p = 1.0 / (nq / self.num_codebooks) - x[:, :, nq:] = 0 - x[:, :, :nq] = x[:, :, :nq] * inv_p - - return x - - def random_dropout_codes(self, inputs): - x = torch.clone(inputs) - rand_num = random.choice(self.rand_num_codebooks) - return self.keep_first_nq_codes(x, nq=rand_num) - - def cal_num_quant(self, bite_width): - frame_rate = self.sampling_rate / self.encoder_hop_length - nq = bite_width / frame_rate / self.codebook_dim - return nq - - def forward( - self, - x, - inv_temperature=100., - return_loss_breakdown=False, - mask=None, - bite_width=None, - ): - """ - einstein notation - b - batch - n - sequence (or flattened spatial dimensions) - d - feature dimension, which is also log2(codebook size) - c - number of codebook dim - """ - - is_img_or_video = x.ndim >= 4 - - # standardize image or video into (batch, seq, dimension) - - if is_img_or_video: - x = rearrange(x, 'b d ... -> b ... d') - x, ps = pack_one(x, 'b * d') - - assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}' - - x = self.project_in(x) - x = self.activation(x) - - # split out number of codebooks - - x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks) - - # quantize by eq 3. - - original_input = x - - codebook_value = torch.ones_like(x) * self.codebook_scale - # do quantization - quantized = torch.where(x > 0, codebook_value, -codebook_value) - - # use straight-through gradients (optionally with custom activation fn) if training - - if self.training: - x = x + (quantized - x).detach() - else: - x = quantized - - # calculate indices - - indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum') - - # entropy aux loss - - if self.training: - # the same as euclidean distance up to a constant - distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook) - - prob = (-distance * inv_temperature).softmax(dim=-1) - - per_sample_entropy = entropy(prob).mean() - - # account for mask - - if exists(mask): - prob = prob[mask] - - # distribution over all available tokens in the batch - - avg_prob = reduce(prob, '... c d -> c d', 'mean') - codebook_entropy = entropy(avg_prob).mean() - - # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions - # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch - - entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy - else: - # if not training, just return dummy 0 - entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero - - # commit loss - - if self.training: - commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none') - - if exists(mask): - commit_loss = commit_loss[mask] - - commit_loss = commit_loss.mean() - else: - commit_loss = self.zero - - # randomly dropout codebooks to fit varying bite width - if self.training and self.rand_num_codebooks is not None: - x = self.random_dropout_codes(x) - if bite_width is not None: - x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width)) - - # merge back codebook dim - - x = rearrange(x, 'b n c d -> b n (c d)') - - # project out to feature dimension if needed - - x = self.project_out(x) - - # reconstitute image or video dimensions - - if is_img_or_video: - x = unpack_one(x, ps, 'b * d') - x = rearrange(x, 'b ... d -> b d ...') - - indices = unpack_one(indices, ps, 'b * c') - - # whether to remove single codebook dim - - if not self.keep_num_codebooks_dim: - indices = rearrange(indices, '... 1 -> ...') - - # complete aux loss - - aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight - - if not return_loss_breakdown: - return x, indices, aux_loss, None - - return x, indices, aux_loss, dict( - per_sample_entropy=per_sample_entropy, - codebook_entropy=codebook_entropy, - commit_loss=commit_loss - ) - - def inference( - self, - x, - bandwidth: int = None, - ): - x, indices, _, _ = self.forward(x, bite_width=bandwidth) - - return x, indices, None - - -class ScalableLFQ(LFQ): - def __init__(self, *, input_size=None, dim=None, codebook_size=None, entropy_loss_weight=0.1, - commitment_loss_weight=0.25, diversity_gamma=1., straight_through_activation=nn.Identity(), - num_codebooks=1, keep_num_codebooks_dim=None, codebook_scale=1., - rand_num_codebooks: Optional[List] = None, sampling_rate=16000, hop_length=640, **kwargs): - super().__init__(input_size=input_size, dim=dim, codebook_size=codebook_size, - entropy_loss_weight=entropy_loss_weight, commitment_loss_weight=commitment_loss_weight, - diversity_gamma=diversity_gamma, straight_through_activation=straight_through_activation, - num_codebooks=num_codebooks, keep_num_codebooks_dim=keep_num_codebooks_dim, - codebook_scale=codebook_scale, rand_num_codebooks=rand_num_codebooks, - sampling_rate=sampling_rate, hop_length=hop_length) - codebook_alpha_conf = kwargs.get("codebook_alpha_conf", None) - self.init_codebook_alpha(codebook_alpha_conf) - - def init_codebook_alpha(self, codebook_alpha_conf: dict): - assert codebook_alpha_conf is not None, "codebook_alpha_conf cannot be None" - name = codebook_alpha_conf.get("name", "constant") - if name == "constant": - alphas = codebook_alpha_conf.get("alphas", [1.0] * self.num_codebooks) - assert len(alphas) == self.num_codebooks, \ - f"the length of codebook alphas {len(alphas)} " \ - f"must match num_codebooks {self.num_codebooks}." - alphas = np.array(alphas) - elif name == "exponential": - temp = codebook_alpha_conf.get("temp", 8.0) - alphas = np.exp(-np.arange(0, self.num_codebooks) / temp) - else: - raise TypeError(f"Unknown codebook alpha type {name}.") - codebook_alpha = torch.tensor(alphas/alphas.sum(), dtype=torch.float32).reshape(1, 1, -1, 1) - self.register_buffer("codebook_alpha", codebook_alpha) - - def forward( - self, - x, - inv_temperature=100., - return_loss_breakdown=False, - mask=None, - bite_width=None - ): - """ - einstein notation - b - batch - n - sequence (or flattened spatial dimensions) - d - feature dimension, which is also log2(codebook size) - c - number of codebook dim - """ - - is_img_or_video = x.ndim >= 4 - - # standardize image or video into (batch, seq, dimension) - - if is_img_or_video: - x = rearrange(x, 'b d ... -> b ... d') - x, ps = pack_one(x, 'b * d') - - assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}' - - x = self.project_in(x) - - # split out number of codebooks - x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks) - - # quantize by eq 3. - original_input = x - codebook_value = torch.ones_like(x) * self.codebook_scale - quantized = torch.where(x > 0, codebook_value, -codebook_value) - - # use straight-through gradients (optionally with custom activation fn) if training - if self.training: - x = self.activation(x) - x = x + (quantized - x).detach() - else: - x = quantized - - # calculate indices - indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum') - - # entropy aux loss - if self.training: - # the same as euclidean distance up to a constant - distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook) - prob = (-distance * inv_temperature).softmax(dim=-1) - per_sample_entropy = entropy(prob).mean() - - # account for mask - if exists(mask): - prob = prob[mask] - - # distribution over all available tokens in the batch - avg_prob = reduce(prob, '... c d -> c d', 'mean') - codebook_entropy = entropy(avg_prob).mean() - - # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions - # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch - entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy - else: - # if not training, just return dummy 0 - entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero - - # commit loss - if self.training: - commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none') - - if exists(mask): - commit_loss = commit_loss[mask] - - commit_loss = commit_loss.mean() - else: - commit_loss = self.zero - - # randomly dropout codebooks to fit s bite width - if self.training and self.rand_num_codebooks is not None: - x = self.random_dropout_codes(x) - if bite_width is not None: - x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width)) - - x = x * self.codebook_alpha - - # merge back codebook dim - x = rearrange(x, 'b n c d -> b n (c d)') - - # project out to feature dimension if needed - - x = self.project_out(x) - - # reconstitute image or video dimensions - - if is_img_or_video: - x = unpack_one(x, ps, 'b * d') - x = rearrange(x, 'b ... d -> b d ...') - - indices = unpack_one(indices, ps, 'b * c') - - # whether to remove single codebook dim - if not self.keep_num_codebooks_dim: - indices = rearrange(indices, '... 1 -> ...') - - # complete aux loss - aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight - - if not return_loss_breakdown: - return x, indices, aux_loss, None - - return x, indices, aux_loss, dict( - per_sample_entropy=per_sample_entropy, - codebook_entropy=codebook_entropy, - commit_loss=commit_loss - ) diff --git a/funasr/models/sense_voice/quantizer/quantization/__init__.py b/funasr/models/sense_voice/quantizer/quantization/__init__.py deleted file mode 100644 index 9d174986b..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# flake8: noqa -from funasr.modules.quantization.vq import QuantizedResult, ResidualVectorQuantizer diff --git a/funasr/models/sense_voice/quantizer/quantization/ac.py b/funasr/models/sense_voice/quantizer/quantization/ac.py deleted file mode 100644 index a3c0fd46c..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/ac.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Arithmetic coder.""" - -import io -import math -import random -import typing as tp -import torch - -from funasr.models.sense_voice.quantizer.quantization.binary import BitPacker, BitUnpacker - - -def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int, - roundoff: float = 1e-8, min_range: int = 2, - check: bool = True) -> torch.Tensor: - """Turn the given PDF into a quantized CDF that splits - [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional - to the PDF. - - Args: - pdf (torch.Tensor): probability distribution, shape should be `[N]`. - total_range_bits (int): see `ArithmeticCoder`, the typical range we expect - during the coding process is `[0, 2 ** total_range_bits - 1]`. - roundoff (float): will round the pdf up to that level to remove difference coming - from e.g. evaluating the Language Model on different architectures. - min_range (int): minimum range width. Should always be at least 2 for numerical - stability. Use this to avoid pathological behavior is a value - that is expected to be rare actually happens in real life. - check (bool): if True, checks that nothing bad happened, can be deactivated for speed. - """ - pdf = pdf.detach() - if roundoff: - pdf = (pdf / roundoff).floor() * roundoff - # interpolate with uniform distribution to achieve desired minimum probability. - total_range = 2 ** total_range_bits - cardinality = len(pdf) - alpha = min_range * cardinality / total_range - assert alpha <= 1, "you must reduce min_range" - ranges = (((1 - alpha) * total_range) * pdf).floor().long() - ranges += min_range - quantized_cdf = torch.cumsum(ranges, dim=-1) - if min_range < 2: - raise ValueError("min_range must be at least 2.") - if check: - assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1] - if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: - raise ValueError("You must increase your total_range_bits.") - return quantized_cdf - - -class ArithmeticCoder: - """ArithmeticCoder, - Let us take a distribution `p` over `N` symbols, and assume we have a stream - of random variables `s_t` sampled from `p`. Let us assume that we have a budget - of `B` bits that we can afford to write on device. There are `2**B` possible numbers, - corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single - sequence `(s_t)` by doing the following: - - 1) Initialize the current range to` [0 ** 2 B - 1]`. - 2) For each time step t, split the current range into contiguous chunks, - one for each possible outcome, with size roughly proportional to `p`. - For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks - would be `{[0, 2], [3, 3]}`. - 3) Select the chunk corresponding to `s_t`, and replace the current range with this. - 4) When done encoding all the values, just select any value remaining in the range. - - You will notice that this procedure can fail: for instance if at any point in time - the range is smaller than `N`, then we can no longer assign a non-empty chunk to each - possible outcome. Intuitively, the more likely a value is, the less the range width - will reduce, and the longer we can go on encoding values. This makes sense: for any efficient - coding scheme, likely outcomes would take fewer bits, and more of them can be coded - with a fixed budget. - - In practice, we do not know `B` ahead of time, but we have a way to inject new bits - when the current range decreases below a given limit (given by `total_range_bits`), without - having to redo all the computations. If we encode mostly likely values, we will seldom - need to inject new bits, but a single rare value can deplete our stock of entropy! - - In this explanation, we assumed that the distribution `p` was constant. In fact, the present - code works for any sequence `(p_t)` possibly different for each timestep. - We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller - the KL between the true distribution and `p_t`, the most efficient the coding will be. - - Args: - fo (IO[bytes]): file-like object to which the bytes will be written to. - total_range_bits (int): the range `M` described above is `2 ** total_range_bits. - Any time the current range width fall under this limit, new bits will - be injected to rescale the initial range. - """ - - def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): - assert total_range_bits <= 30 - self.total_range_bits = total_range_bits - self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. - self.low: int = 0 - self.high: int = 0 - self.max_bit: int = -1 - self._dbg: tp.List[tp.Any] = [] - self._dbg2: tp.List[tp.Any] = [] - - @property - def delta(self) -> int: - """Return the current range width.""" - return self.high - self.low + 1 - - def _flush_common_prefix(self): - # If self.low and self.high start with the sames bits, - # those won't change anymore as we always just increase the range - # by powers of 2, and we can flush them out to the bit stream. - assert self.high >= self.low, (self.low, self.high) - assert self.high < 2 ** (self.max_bit + 1) - while self.max_bit >= 0: - b1 = self.low >> self.max_bit - b2 = self.high >> self.max_bit - if b1 == b2: - self.low -= (b1 << self.max_bit) - self.high -= (b1 << self.max_bit) - assert self.high >= self.low, (self.high, self.low, self.max_bit) - assert self.low >= 0 - self.max_bit -= 1 - self.packer.push(b1) - else: - break - - def push(self, symbol: int, quantized_cdf: torch.Tensor): - """Push the given symbol on the stream, flushing out bits - if possible. - - Args: - symbol (int): symbol to encode with the AC. - quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` - to build this from your pdf estimate. - """ - while self.delta < 2 ** self.total_range_bits: - self.low *= 2 - self.high = self.high * 2 + 1 - self.max_bit += 1 - - range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() - range_high = quantized_cdf[symbol].item() - 1 - effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) - effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) - assert self.low <= self.high - self.high = self.low + effective_high - self.low = self.low + effective_low - assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) - self._dbg.append((self.low, self.high)) - self._dbg2.append((self.low, self.high)) - self._flush_common_prefix() - assert self.low <= self.high - assert self.max_bit >= -1 - assert self.max_bit <= 61, self.max_bit - - def flush(self): - """Flush the remaining information to the stream. - """ - while self.max_bit >= 0: - b1 = (self.low >> self.max_bit) & 1 - self.packer.push(b1) - self.max_bit -= 1 - self.packer.flush() - - -class ArithmeticDecoder: - """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. - - Note that this must be called with **exactly** the same parameters and sequence - of quantized cdf as the arithmetic encoder or the wrong values will be decoded. - - If the AC encoder current range is [L, H], with `L` and `H` having the same common - prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. - For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside - `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained - for a specific sequence of symbols and a binary-search allows us to decode those symbols. - At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, - and we will need to read new bits from the stream and repeat the process. - - """ - def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): - self.total_range_bits = total_range_bits - self.low: int = 0 - self.high: int = 0 - self.current: int = 0 - self.max_bit: int = -1 - self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. - # Following is for debugging - self._dbg: tp.List[tp.Any] = [] - self._dbg2: tp.List[tp.Any] = [] - self._last: tp.Any = None - - @property - def delta(self) -> int: - return self.high - self.low + 1 - - def _flush_common_prefix(self): - # Given the current range [L, H], if both have a common prefix, - # we know we can remove it from our representation to avoid handling large numbers. - while self.max_bit >= 0: - b1 = self.low >> self.max_bit - b2 = self.high >> self.max_bit - if b1 == b2: - self.low -= (b1 << self.max_bit) - self.high -= (b1 << self.max_bit) - self.current -= (b1 << self.max_bit) - assert self.high >= self.low - assert self.low >= 0 - self.max_bit -= 1 - else: - break - - def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: - """Pull a symbol, reading as many bits from the stream as required. - This returns `None` when the stream has been exhausted. - - Args: - quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` - to build this from your pdf estimate. This must be **exactly** - the same cdf as the one used at encoding time. - """ - while self.delta < 2 ** self.total_range_bits: - bit = self.unpacker.pull() - if bit is None: - return None - self.low *= 2 - self.high = self.high * 2 + 1 - self.current = self.current * 2 + bit - self.max_bit += 1 - - def bin_search(low_idx: int, high_idx: int): - # Binary search is not just for coding interviews :) - if high_idx < low_idx: - raise RuntimeError("Binary search failed") - mid = (low_idx + high_idx) // 2 - range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 - range_high = quantized_cdf[mid].item() - 1 - effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) - effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) - low = effective_low + self.low - high = effective_high + self.low - if self.current >= low: - if self.current <= high: - return mid, low, high, self.current - else: - return bin_search(mid + 1, high_idx) - else: - return bin_search(low_idx, mid - 1) - - self._last = (self.low, self.high, self.current, self.max_bit) - sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) - self._dbg.append((self.low, self.high, self.current)) - self._flush_common_prefix() - self._dbg2.append((self.low, self.high, self.current)) - - return sym - - -def test(): - torch.manual_seed(1234) - random.seed(1234) - for _ in range(4): - pdfs = [] - cardinality = random.randrange(4000) - steps = random.randrange(100, 500) - fo = io.BytesIO() - encoder = ArithmeticCoder(fo) - symbols = [] - for step in range(steps): - pdf = torch.softmax(torch.randn(cardinality), dim=0) - pdfs.append(pdf) - q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) - symbol = torch.multinomial(pdf, 1).item() - symbols.append(symbol) - encoder.push(symbol, q_cdf) - encoder.flush() - - fo.seek(0) - decoder = ArithmeticDecoder(fo) - for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): - q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) - decoded_symbol = decoder.pull(q_cdf) - assert decoded_symbol == symbol, idx - assert decoder.pull(torch.zeros(1)) is None - - -if __name__ == "__main__": - test() diff --git a/funasr/models/sense_voice/quantizer/quantization/binary.py b/funasr/models/sense_voice/quantizer/quantization/binary.py deleted file mode 100644 index a00862494..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/binary.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" - -import io -import json -import struct -import typing as tp -import ctypes - -# format is `ECDC` magic code, followed by the header size as uint32. -# Then an uint8 indicates the protocol version (0.) -# The header is then provided as json and should contain all required -# informations for decoding. A raw stream of bytes is then provided -# and should be interpretable using the json header. -_encodec_header_struct = struct.Struct('!4sBI') -_ENCODEC_MAGIC = b'ECDC' - - -def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any): - meta_dumped = json.dumps(metadata).encode('utf-8') - version = 0 - header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped)) - fo.write(header) - fo.write(meta_dumped) - fo.flush() - - -def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes: - buf = b"" - while len(buf) < size: - new_buf = fo.read(size) - if not new_buf: - raise EOFError("Impossible to read enough data from the stream, " - f"{size} bytes remaining.") - buf += new_buf - size -= len(new_buf) - return buf - - -def read_ecdc_header(fo: tp.IO[bytes]): - header_bytes = _read_exactly(fo, _encodec_header_struct.size) - magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) - if magic != _ENCODEC_MAGIC: - raise ValueError("File is not in ECDC format.") - if version != 0: - raise ValueError("Version not supported.") - meta_bytes = _read_exactly(fo, meta_size) - return json.loads(meta_bytes.decode('utf-8')) - - -class BitPacker: - """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. - Note that for some bandwidth (1.5, 3), the codebook representation - will not cover an integer number of bytes. - - Args: - bits (int): number of bits per value that will be pushed. - fo (IO[bytes]): file-object to push the bytes to. - """ - def __init__(self, bits: int, fo: tp.IO[bytes]): - self._current_value = 0 - self._current_bits = 0 - self.bits = bits - self.fo = fo - - def push(self, value: int): - """Push a new value to the stream. This will immediately - write as many uint8 as possible to the underlying file-object.""" - value = bin(ctypes.c_uint32.from_buffer(ctypes.c_float(value)).value) - value = ctypes.c_int.from_buffer(ctypes.c_uint32(int(value, 2))).value - self._current_value += (value << self._current_bits) - self._current_bits += self.bits - while self._current_bits >= 8: - lower_8bits = self._current_value & 0xff - self._current_bits -= 8 - self._current_value >>= 8 - self.fo.write(bytes([lower_8bits])) - - def flush(self): - """Flushes the remaining partial uint8, call this at the end - of the stream to encode.""" - if self._current_bits: - self.fo.write(bytes([self._current_value])) - self._current_value = 0 - self._current_bits = 0 - self.fo.flush() - - -class BitUnpacker: - """BitUnpacker does the opposite of `BitPacker`. - - Args: - bits (int): number of bits of the values to decode. - fo (IO[bytes]): file-object to push the bytes to. - """ - def __init__(self, bits: int, fo: tp.IO[bytes]): - self.bits = bits - self.fo = fo - self._mask = (1 << bits) - 1 - self._current_value = 0 - self._current_bits = 0 - - def pull(self) -> tp.Optional[int]: - """ - Pull a single value from the stream, potentially reading some - extra bytes from the underlying file-object. - Returns `None` when reaching the end of the stream. - """ - while self._current_bits < self.bits: - buf = self.fo.read(1) - if not buf: - return None - character = buf[0] - self._current_value += character << self._current_bits - self._current_bits += 8 - - out = self._current_value & self._mask - self._current_value >>= self.bits - self._current_bits -= self.bits - # out = ctypes.c_float.from_buffer(ctypes.c_uint32(int(out, 2))).value - return out - - -def test(): - import torch - torch.manual_seed(1234) - for rep in range(4): - length: int = torch.randint(10, 2_000, (1,)).item() - bits: int = torch.randint(1, 16, (1,)).item() - tokens: tp.List[int] = torch.randint(2 ** bits, (length,)).tolist() - rebuilt: tp.List[int] = [] - buf = io.BytesIO() - packer = BitPacker(bits, buf) - for token in tokens: - packer.push(token) - packer.flush() - buf.seek(0) - unpacker = BitUnpacker(bits, buf) - while True: - value = unpacker.pull() - if value is None: - break - rebuilt.append(value) - assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) - # The flushing mechanism might lead to "ghost" values at the end of the stream. - assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits) - for idx, (a, b) in enumerate(zip(tokens, rebuilt)): - assert a == b, (idx, a, b) - - -if __name__ == '__main__': - test() diff --git a/funasr/models/sense_voice/quantizer/quantization/core_vq.py b/funasr/models/sense_voice/quantizer/quantization/core_vq.py deleted file mode 100644 index e6db4712b..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/core_vq.py +++ /dev/null @@ -1,674 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# This implementation is inspired from -# https://github.com/lucidrains/vector-quantize-pytorch -# which is released under MIT License. Hereafter, the original license: -# MIT License -# -# Copyright (c) 2020 Phil Wang -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Core vector quantization implementation.""" -import logging -import typing as tp -from random import randrange - -import numpy as np -from einops import rearrange, repeat -from math import ceil -import torch -from torch import nn -import torch.nn.functional as F -import random - -from funasr.models.sense_voice.quantizer.quantization import distrib - -def round_up_multiple(num, mult): - return ceil(num / mult) * mult - -def default(val: tp.Any, d: tp.Any) -> tp.Any: - return val if val is not None else d - - -def ema_inplace(moving_avg, new, decay: float): - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): - return (x + epsilon) / (x.sum() + n_categories * epsilon) - - -def uniform_init(*shape: int): - t = torch.empty(shape) - nn.init.kaiming_uniform_(t) - return t - - -def sample_vectors(samples, num: int): - num_samples, device = samples.shape[0], samples.device - - if num_samples >= num: - indices = torch.randperm(num_samples, device=device)[:num] - else: - indices = torch.randint(0, num_samples, (num,), device=device) - - return samples[indices] - - -@torch.no_grad() -def kmeans(samples, num_clusters: int, num_iters: int = 10): - # device = samples.device - # samples = samples.cpu() - dim, dtype = samples.shape[-1], samples.dtype - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - # diffs = rearrange(samples, "n d -> n () d") - rearrange( - # means, "c d -> () c d" - # ) - # dists = -(diffs ** 2).sum(dim=-1) - dists = -( - samples.pow(2).sum(1, keepdim=True) - - 2 * torch.matmul(samples, means.t()) - + means.t().pow(2).sum(0, keepdim=True) - ) - - buckets = dists.max(dim=-1).indices - del dists - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] - - means = torch.where(zero_mask[..., None], means, new_means) - - # means = means.to(device) - return means, bins - - -def preprocess(x): - x = rearrange(x, "... d -> (...) d") - return x - - -def postprocess_emb(embed_ind, shape): - return embed_ind.view(*shape[:-1]) - - -class EuclideanCodebook(nn.Module): - """Codebook with Euclidean distance. - Args: - dim (int): Dimension. - codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. - kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: int = False, - kmeans_iters: int = 10, - decay: float = 0.99, - epsilon: float = 1e-5, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros - embed = init_fn(codebook_size, dim) - - self.codebook_size = codebook_size - - self.kmeans_iters = kmeans_iters - self.epsilon = epsilon - self.threshold_ema_dead_code = threshold_ema_dead_code - - self.register_buffer("inited", torch.Tensor([not kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(codebook_size)) - self.register_buffer("embed", embed) - self.register_buffer("embed_avg", embed.clone()) - self.training = True - - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed.clone()) - self.cluster_size.data.copy_(cluster_size) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - distrib.broadcast_tensors(self.buffers()) - - def replace_(self, samples, mask): - modified_codebook = torch.where( - mask[..., None], sample_vectors(samples, self.codebook_size), self.embed - ) - self.embed.data.copy_(modified_codebook) - - def expire_codes_(self, batch_samples): - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - distrib.broadcast_tensors(self.buffers()) - - def quantize(self, x): - embed = self.embed.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def dequantize(self, embed_ind): - quantize = F.embedding(embed_ind, self.embed) - return quantize - - def encode(self, x): - shape = x.shape - # pre-process - x = preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind): - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x): - shape, dtype = x.shape, x.dtype - x = preprocess(x) - - self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) - embed_ind = postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - - if self.training: - # We do the expiry of code at that point as buffers are in sync - # and all the workers will take the same decision. - self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) - * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) - # Note: after ema update, there is a very small difference between codebooks on GPUs. - # The impact can be very small, ignore it. - - return quantize, embed_ind - - -class SimpleEuclideanCodebook(nn.Module): - """Simple Codebook with Euclidean distance. - Using gradient to update code embeddings instead of EMA. - Args: - dim (int): Dimension. - codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. - kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - """ - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: tp.Union[bool, torch.Tensor] = False, - kmeans_iters: int = 10, - **kwargs - ): - super().__init__() - if isinstance(kmeans_init, bool): - if kmeans_init: - embed = torch.zeros(codebook_size, dim) - inited = False - else: - embed = uniform_init(codebook_size, dim) - inited = True - else: - embed = kmeans_init - inited = True - self.codebook_size = codebook_size - self.kmeans_iters = kmeans_iters - - self.embed = nn.Embedding(codebook_size, dim) - self.embed.weight.data.copy_(embed) - # self.register_parameter("embed", nn.Parameter(embed, requires_grad=True)) - self.register_buffer("inited", torch.Tensor([inited])) - self.training = True - - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - distrib.broadcast_tensors(self.buffers()) - - def quantize(self, x): - embed = self.embed.weight.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def dequantize(self, embed_ind): - quantize = self.embed(embed_ind) - return quantize - - def encode(self, x): - shape = x.shape - # pre-process - x = preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind): - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x): - shape, dtype = x.shape, x.dtype - x = preprocess(x) - - self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_ind = postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - - return quantize, embed_ind - - -class VectorQuantization(nn.Module): - """Vector quantization implementation. - Currently, supports only euclidean distance. - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - commitment_weight (float): Weight for commitment loss. - """ - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - decay: float = 0.99, - epsilon: float = 1e-5, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - commitment_weight: float = 1., - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity()) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) - self.codebook_size = codebook_size - self.training = True - - @property - def codebook(self): - return self._codebook.embed - - def encode(self, x): - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - embed_in = self._codebook.encode(x) - return embed_in - - def decode(self, embed_ind): - quantize = self._codebook.decode(embed_ind) - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def forward(self, x): - device = x.device - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - - quantize, embed_ind = self._codebook(x) - - if self.training: - quantize = x + (quantize - x).detach() - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize, embed_ind, loss - - -class SimpleVectorQuantization(nn.Module): - """Vector quantization implementation with SimpleEuclideanCodebook. - Currently, supports only euclidean distance. - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - commitment_weight (float): Weight for commitment loss. - """ - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - epsilon: float = 1e-5, - kmeans_init: bool = True, - kmeans_iters: int = 50, - commitment_weight: float = 0.25, - codebook_weight: float = 1.0, - **kwargs - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity()) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - self.codebook_weight = codebook_weight - logging.info(f"commitment_weight: {commitment_weight}, codebook_weight: {codebook_weight}.") - - self._codebook = SimpleEuclideanCodebook( - dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters - ) - self.codebook_size = codebook_size - self.training = True - - @property - def codebook(self): - return self._codebook.embed - - def encode(self, x): - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - embed_in = self._codebook.encode(x) - return embed_in - - def decode(self, embed_ind): - quantize = self._codebook.decode(embed_ind) - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def forward(self, x): - device = x.device - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - - quantize, embed_ind = self._codebook(x) - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - # commit loss for codebook - if self.codebook_weight > 0: - codebook_loss = F.mse_loss(quantize, x.detach()) - loss = loss + codebook_loss * self.codebook_weight - - # commit loss for encoder - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize, embed_ind, loss - - -class ResidualVectorQuantization(nn.Module): - """Residual vector quantization implementation. - Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf - """ - def __init__(self, *, - num_quantizers, - quantize_dropout: bool = False, - rand_num_quant: tp.Optional[tp.List] = None, - **kwargs): - super().__init__() - self.layers = nn.ModuleList( - [VectorQuantization(**kwargs) for _ in range(num_quantizers)] - ) - self.quantize_dropout = quantize_dropout - self.rand_num_quant = rand_num_quant - - def forward(self, x, n_q: tp.Optional[int] = None): - quantized_out = 0.0 - residual = x - device = x.device - - all_losses = [] - all_indices = [] - all_sub_quants = [] - n_q = n_q or len(self.layers) - - should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None - if should_quantize_dropout: - rand_quantize_dropout_index = random.choice(self.rand_num_quant) - - null_indices_shape = (x.shape[0], x.shape[2]) - null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long) - null_loss = torch.full((1,), 0., device=device, dtype=x.dtype) - null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype) - - for quantizer_index, layer in enumerate(self.layers[:n_q]): - # dropout except the first quantizer - if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index: - all_indices.append(null_indices) - all_losses.append(null_loss) - all_sub_quants.append(null_sub_quant) - continue - - quantized, indices, loss = layer(residual) - residual = residual - quantized - quantized_out = quantized_out + quantized - - all_indices.append(indices) - all_losses.append(loss) - all_sub_quants.append(quantized) - - out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants)) - return quantized_out, out_indices, out_losses, out_sub_quants - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: - indices = layer.encode(residual) - quantized = layer.decode(indices) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized - return quantized_out - - -class SimpleResidualVectorQuantization(nn.Module): - """Simple Residual vector quantization with gradient to - update codebook instead of EMA - """ - def __init__(self, *, - num_quantizers, - quantize_dropout: bool = False, - rand_num_quant: tp.Optional[tp.List] = None, - **kwargs): - super().__init__() - kmeans_init = raw_kmeans_init = kwargs.pop('kmeans_init', True) - if isinstance(kmeans_init, str): - # use prepared kmeans init - embed = np.load(kmeans_init) - embed = torch.from_numpy(embed) - if embed.dim() == 2: - embed = embed.unsqueeze(0) - kmeans_init = embed - - self.layers = nn.ModuleList([ - SimpleVectorQuantization( - kmeans_init=kmeans_init[i] if isinstance(kmeans_init, torch.Tensor) else kmeans_init, - **kwargs - ) for i in range(num_quantizers) - ]) - kwargs["kmeans_init"] = raw_kmeans_init - self.quantize_dropout = quantize_dropout - self.rand_num_quant = rand_num_quant - - def forward(self, x, n_q: tp.Optional[int] = None): - quantized_out = 0.0 - residual = x - device = x.device - - all_losses = [] - all_indices = [] - all_sub_quants = [] - n_q = n_q or len(self.layers) - - should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None - if should_quantize_dropout: - rand_quantize_dropout_index = random.choice(self.rand_num_quant) - - null_indices_shape = (x.shape[0], x.shape[2]) - null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long) - null_loss = torch.full((1,), 0., device=device, dtype=x.dtype) - null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype) - - for quantizer_index, layer in enumerate(self.layers[:n_q]): - # dropout except the first quantizer - if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index: - all_indices.append(null_indices) - all_losses.append(null_loss) - all_sub_quants.append(null_sub_quant) - continue - - quantized, indices, loss = layer(residual) - residual = residual - quantized - quantized_out = quantized_out + quantized - - all_indices.append(indices) - all_losses.append(loss) - all_sub_quants.append(quantized) - - out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants)) - return quantized_out, out_indices, out_losses, out_sub_quants - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: - indices = layer.encode(residual) - quantized = layer.decode(indices) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized - return quantized_out diff --git a/funasr/models/sense_voice/quantizer/quantization/ddp_core_vq.py b/funasr/models/sense_voice/quantizer/quantization/ddp_core_vq.py deleted file mode 100644 index 4fc31d7de..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/ddp_core_vq.py +++ /dev/null @@ -1,513 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# This implementation is inspired from -# https://github.com/lucidrains/vector-quantize-pytorch -# which is released under MIT License. Hereafter, the original license: -# MIT License -# -# Copyright (c) 2020 Phil Wang -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Core vector quantization implementation.""" -import logging -import random -import typing as tp -from random import randrange - -import numpy as np -from einops import rearrange, repeat -from math import ceil -import torch -from torch import nn -import torch.nn.functional as F -from funasr.utils.hinter import hint_once - -from funasr.models.sense_voice.quantizer.quantization import distrib - -def round_up_multiple(num, mult): - return ceil(num / mult) * mult - -def default(val: tp.Any, d: tp.Any) -> tp.Any: - return val if val is not None else d - - -def ema_inplace(moving_avg, new, decay: float, mask=None): - if mask is None: - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - else: - mask = mask.float() - new_avg = moving_avg * decay + new * (1 - decay) - new_avg = mask * new_avg + (1 - mask) * moving_avg - moving_avg.data.copy_(new_avg.data) - - -def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): - return (x + epsilon) / (x.sum() + n_categories * epsilon) - - -def uniform_init(*shape: int): - t = torch.empty(shape) - nn.init.kaiming_uniform_(t) - return t - - -def sample_vectors(samples, num: int): - num_samples, device = samples.shape[0], samples.device - - if num_samples >= num: - indices = torch.randperm(num_samples, device=device)[:num] - else: - indices = torch.randint(0, num_samples, (num,), device=device) - - return samples[indices] - - -@torch.no_grad() -def kmeans(samples, num_clusters: int, num_iters: int = 10): - # device = samples.device - # samples = samples.cpu() - dim, dtype = samples.shape[-1], samples.dtype - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - # diffs = rearrange(samples, "n d -> n () d") - rearrange( - # means, "c d -> () c d" - # ) - # dists = -(diffs ** 2).sum(dim=-1) - dists = -( - samples.pow(2).sum(1, keepdim=True) - - 2 * torch.matmul(samples, means.t()) - + means.t().pow(2).sum(0, keepdim=True) - ) - - buckets = dists.max(dim=-1).indices - del dists - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] - - means = torch.where(zero_mask[..., None], means, new_means) - - # means = means.to(device) - return means, bins - - -def preprocess(x): - x = rearrange(x, "... d -> (...) d") - return x - - -def postprocess_emb(embed_ind, shape): - return embed_ind.view(*shape[:-1]) - - -class EuclideanCodebook(nn.Module): - """Codebook with Euclidean distance. - Args: - dim (int): Dimension. - codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. - kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: int = False, - kmeans_iters: int = 10, - decay: float = 0.99, - epsilon: float = 1e-5, - threshold_ema_dead_code: int = 2, - sparse_update: bool = False, - normalized_input: bool = False, - **kwargs, - ): - super().__init__() - self.decay = decay - self.codebook_size = codebook_size - self.kmeans_iters = kmeans_iters - self.epsilon = epsilon - self.threshold_ema_dead_code = threshold_ema_dead_code - - self.inited = None - self.cluster_size = None - self.embed = None - self.embed_avg = None - self.training = True - self.sparse_update = sparse_update - self.normalized_input = normalized_input - - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed.clone()) - self.cluster_size.data.copy_(cluster_size) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited]) - - def replace_(self, samples, mask): - modified_codebook = torch.where( - mask[..., None], sample_vectors(samples, self.codebook_size), self.embed - ) - self.embed.data.copy_(modified_codebook) - - def expire_codes_(self, batch_samples): - hint_once(f"threshold_ema_dead_code: {self.threshold_ema_dead_code}.", "threshold_ema_dead_code", rank=0) - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - # sync buffers outside for efficiency - # distrib.broadcast_tensors(self.buffers()) - - def quantize(self, x): - embed = self.embed.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def dequantize(self, embed_ind): - quantize = F.embedding(embed_ind, self.embed) - return quantize - - def encode(self, x, buffers): - self.inited, self.cluster_size, self.embed, self.embed_avg = buffers - - shape = x.shape - # pre-process - x = preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind, buffers): - self.inited, self.cluster_size, self.embed, self.embed_avg = buffers - - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x, buffers): - self.inited, self.cluster_size, self.embed, self.embed_avg = buffers - - shape, dtype = x.shape, x.dtype - x = preprocess(x) - - self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) - embed_ind = postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - - if self.training: - # We do the expiry of code at that point as buffers are in sync - # and all the workers will take the same decision. - self.expire_codes_(x) - if not self.sparse_update: - mask = None - else: - mask = embed_onehot.sum(0) > 0 - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay, mask=mask) - embed_sum = x.t() @ embed_onehot - # if self.normalized_input: - # embed_sum = F.normalize(embed_sum, dim=0) - ema_inplace(self.embed_avg, embed_sum.t(), self.decay, - mask=mask.unsqueeze(-1) if self.sparse_update else None) - if not self.sparse_update: - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) - * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - else: - embed_normalized = self.embed_avg - if self.normalized_input: - embed_normalized = F.normalize(embed_normalized, dim=-1) - self.embed.data.copy_(embed_normalized) - # Note: after ema update, there is a very small difference between codebooks on GPUs. - # The impact can be very small, ignore it. - - return quantize, embed_ind - - -class VectorQuantization(nn.Module): - """Vector quantization implementation. - Currently, supports only euclidean distance. - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - commitment_weight (float): Weight for commitment loss. - """ - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - decay: float = 0.99, - epsilon: float = 1e-5, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - commitment_weight: float = 1., - **kwargs, - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity()) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code, - **kwargs) - self.codebook_size = codebook_size - self.training = True - - @property - def codebook(self): - return self._codebook.embed - - def encode(self, x, buffers): - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - embed_in = self._codebook.encode(x, buffers) - return embed_in - - def decode(self, embed_ind, buffers): - quantize = self._codebook.decode(embed_ind, buffers) - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def forward(self, x, buffers): - device = x.device - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - - quantize, embed_ind = self._codebook(x, buffers) - - if self.training: - quantize = x + (quantize - x).detach() - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize, embed_ind, loss - - -class DistributedResidualVectorQuantization(nn.Module): - """Efficient distributed residual vector quantization implementation. - Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf - """ - def __init__(self, *, - num_quantizers, - quantize_dropout: bool = False, - rand_num_quant: tp.Optional[tp.List] = None, - **kwargs): - super().__init__() - """ - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - """ - codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["dim"] - kmeans_init = kwargs["kmeans_init"] - if isinstance(kmeans_init, bool): - if not kwargs["kmeans_init"]: - # use uniform init - embed = uniform_init(num_quantizers, codebook_size, codebook_dim) - inited = True - cluster_size = 1 - else: - # to perform kmeans init on first batch - embed = torch.zeros(num_quantizers, codebook_size, codebook_dim) - inited = False - cluster_size = 0 - elif isinstance(kmeans_init, str): - # use prepared kmeans init - embed = np.load(kmeans_init) - embed = torch.from_numpy(embed) - if kwargs.get("normalized_input", False): - logging.info("normalize the code embeddings since the input is normalized.") - embed = F.normalize(embed, dim=-1) - if embed.dim() == 2: - embed = embed.repeat(num_quantizers, 1, 1) - inited = True - cluster_size = 1 - else: - raise TypeError("kmeans_init should be either a bool or string path to init weights.") - - self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)])) - self.register_buffer("cluster_size", torch.ones(num_quantizers, codebook_size) * cluster_size) - self.register_buffer("embed", embed) - self.register_buffer("embed_avg", embed.clone()) - - self.q0_ds_ratio = 1 - if "q0_ds_ratio" in kwargs: - self.q0_ds_ratio = kwargs.pop("q0_ds_ratio") - - self.layers = nn.ModuleList() - for i in range(num_quantizers): - vq_args = dict(**kwargs) - vq = VectorQuantization(**vq_args) - self.layers.append(vq) - - self.quantize_dropout = quantize_dropout - self.rand_num_quant = rand_num_quant - - def forward(self, x, n_q: tp.Optional[int] = None): - quantized_out = torch.zeros_like(x) - residual = x - bb, cc, tt = x.shape - device = x.device - - all_losses = [] - all_indices = [] - all_sub_quants = [] - n_q = n_q or len(self.layers) - - should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None - if should_quantize_dropout: - rand_quantize_dropout_index = random.choice(self.rand_num_quant) - - null_indices_shape = (x.shape[0], x.shape[2]) - null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long) - null_loss = torch.full((1,), 0., device=device, dtype=x.dtype) - null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype) - - for quantizer_index, layer in enumerate(self.layers[:n_q]): - # dropout except the first quantizer - if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index: - all_indices.append(null_indices) - all_losses.append(null_loss) - all_sub_quants.append(null_sub_quant) - continue - - quant_in = residual - if self.q0_ds_ratio > 1 and quantizer_index == 0: - quant_in = F.interpolate(quant_in, size=[tt//2]) - quantized, indices, loss = layer(quant_in, [ - self.inited[quantizer_index], - self.cluster_size[quantizer_index], - self.embed[quantizer_index], - self.embed_avg[quantizer_index] - ]) - if self.q0_ds_ratio > 1 and quantizer_index == 0: - quantized = F.interpolate(quantized, size=[tt]) - indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long() - residual = residual - quantized - quantized_out = quantized_out + quantized - - all_indices.append(indices) - all_losses.append(loss) - all_sub_quants.append(quantized) - - # sync buffers after one forward step - distrib.broadcast_tensors(self.buffers()) - out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants)) - - return quantized_out, out_indices, out_losses, out_sub_quants - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - for i, layer in enumerate(self.layers[:n_q]): - indices = layer.encode(residual, [ - self.inited[i], - self.cluster_size[i], - self.embed[i], - self.embed_avg[i] - ]) - quantized = layer.decode(indices, [ - self.inited[i], - self.cluster_size[i], - self.embed[i], - self.embed_avg[i] - ]) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices, [ - self.inited[i], - self.cluster_size[i], - self.embed[i], - self.embed_avg[i] - ]) - quantized_out = quantized_out + quantized - return quantized_out diff --git a/funasr/models/sense_voice/quantizer/quantization/distrib.py b/funasr/models/sense_voice/quantizer/quantization/distrib.py deleted file mode 100644 index 57c38d773..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/distrib.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Torch distributed utilities.""" - -import typing as tp -import time, logging -import torch - - -def rank(): - if torch.distributed.is_initialized(): - return torch.distributed.get_rank() - else: - return 0 - - -def world_size(): - if torch.distributed.is_initialized(): - return torch.distributed.get_world_size() - else: - return 1 - - -def is_distributed(): - return world_size() > 1 - - -def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): - if is_distributed(): - return torch.distributed.all_reduce(tensor, op) - - -def _is_complex_or_float(tensor): - return torch.is_floating_point(tensor) or torch.is_complex(tensor) - - -def _check_number_of_params(params: tp.List[torch.Tensor]): - # utility function to check that the number of params in all workers is the same, - # and thus avoid a deadlock with distributed all reduce. - if not is_distributed() or not params: - return - tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) - all_reduce(tensor) - if tensor.item() != len(params) * world_size(): - # If not all the workers have the same number, for at least one of them, - # this inequality will be verified. - raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " - "at least one worker has a different one.") - - -def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): - """Broadcast the tensors from the given parameters to all workers. - This can be used to ensure that all workers have the same model to start with. - """ - # start = time.time() - if not is_distributed(): - return - tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] - _check_number_of_params(tensors) - handles = [] - for tensor in tensors: - handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) - handles.append(handle) - for handle in handles: - handle.wait() - # logging.info(f"Time of broadcast {(time.time()-start):.4f}") - - -def sync_buffer(buffers, average=True): - """ - Sync grad for buffers. If average is False, broadcast instead of averaging. - """ - if not is_distributed(): - return - handles = [] - for buffer in buffers: - if torch.is_floating_point(buffer.data): - if average: - handle = torch.distributed.all_reduce( - buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) - else: - handle = torch.distributed.broadcast( - buffer.data, src=0, async_op=True) - handles.append((buffer, handle)) - for buffer, handle in handles: - handle.wait() - if average: - buffer.data /= world_size - - -def sync_grad(params): - """ - Simpler alternative to DistributedDataParallel, that doesn't rely - on any black magic. For simple models it can also be as fast. - Just call this on your model parameters after the call to backward! - """ - if not is_distributed(): - return - handles = [] - for p in params: - if p.grad is not None: - handle = torch.distributed.all_reduce( - p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) - handles.append((p, handle)) - for p, handle in handles: - handle.wait() - p.grad.data /= world_size() - - -def average_metrics(metrics: tp.Dict[str, float], count=1.): - """Average a dictionary of metrics across all workers, using the optional - `count` as unormalized weight. - """ - if not is_distributed(): - return metrics - keys, values = zip(*metrics.items()) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) - tensor *= count - all_reduce(tensor) - averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() - return dict(zip(keys, averaged)) diff --git a/funasr/models/sense_voice/quantizer/quantization/vq.py b/funasr/models/sense_voice/quantizer/quantization/vq.py deleted file mode 100644 index a841e4795..000000000 --- a/funasr/models/sense_voice/quantizer/quantization/vq.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Residual vector quantizer implementation.""" -import logging -from dataclasses import dataclass, field -import math -import typing as tp - -import torch -from torch import nn -from funasr.models.sense_voice.quantizer.quantization import distrib -from funasr.models.sense_voice.quantizer.quantization.core_vq import SimpleResidualVectorQuantization -from funasr.models.sense_voice.quantizer.quantization.ddp_core_vq import DistributedResidualVectorQuantization - -@dataclass -class QuantizedResult: - quantized: torch.Tensor - codes: torch.Tensor - bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. - penalty: tp.Optional[torch.Tensor] = None - metrics: dict = field(default_factory=dict) - sub_quants: torch.Tensor = None - - -class ResidualVectorQuantizer(nn.Module): - """Residual Vector Quantizer. - Args: - dimension (int): Dimension of the codebooks. - n_q (int): Number of residual vector quantizers used. - bins (int): Codebook size. - decay (float): Decay for exponential moving average over the codebooks. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - def __init__( - self, - dimension: int = 256, - n_q: int = 8, - bins: int = 1024, - decay: float = 0.99, - kmeans_init: tp.Union[bool, str] = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - quantize_dropout: bool = False, - rand_num_quant: tp.Optional[tp.List] = None, - encoder_hop_length: int = 320, - use_ddp: bool = True, - q0_ds_ratio: int = 1, - **kwargs, - ): - super().__init__() - self.n_q = n_q - self.dimension = dimension - self.bins = bins - self.decay = decay - self.kmeans_init = kmeans_init - self.kmeans_iters = kmeans_iters - self.threshold_ema_dead_code = threshold_ema_dead_code - self.encoder_hop_length = encoder_hop_length - self.training = True - if use_ddp: - rvq_class = DistributedResidualVectorQuantization - logging.info("Using distributed residual vector quantization.") - else: - rvq_class = SimpleResidualVectorQuantization - logging.warning("Using simple residual vector quantization") - self.model = rvq_class( - dim=self.dimension, - codebook_size=self.bins, - num_quantizers=self.n_q, - decay=self.decay, - kmeans_init=self.kmeans_init, - kmeans_iters=self.kmeans_iters, - threshold_ema_dead_code=self.threshold_ema_dead_code, - quantize_dropout=quantize_dropout, - rand_num_quant=rand_num_quant, - q0_ds_ratio=q0_ds_ratio, - **kwargs - ) - - def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult: - """Residual vector quantization on the given input tensor. - Args: - x (torch.Tensor): Input tensor in the shape of (B, C, T). - sample_rate (int): Sample rate of the input tensor. - bandwidth (float): Target bandwidth. - Returns: - QuantizedResult: - The quantized (or approximately quantized) representation with - the associated bandwidth and any penalty term for the loss. - """ - bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) - n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) - quantized, codes, commit_loss, sub_quants = self.model(x, n_q=n_q) - bw = torch.tensor(n_q * bw_per_q).to(x) - return QuantizedResult(quantized, codes, bw, - penalty=torch.mean(commit_loss), - sub_quants=sub_quants) - - def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: - """Return n_q based on specified target bandwidth. - """ - bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) - n_q = self.n_q - if bandwidth and bandwidth > 0.: - n_q = int(max(1, math.floor(bandwidth / bw_per_q))) - return n_q - - def get_bandwidth_per_quantizer(self, sample_rate: int): - """Return bandwidth per quantizer for a given input sample rate. - """ - return math.log2(self.bins) * sample_rate / self.encoder_hop_length - - def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: - """Encode a given input tensor with the specified sample rate at the given bandwidth. - The RVQ encode method sets the appropriate number of quantizer to use - and returns indices for each quantizer. - """ - n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) - codes = self.model.encode(x, n_q=n_q) - return codes - - def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation. - """ - quantized = self.model.decode(codes) - return quantized