This commit is contained in:
游雁 2024-12-26 10:47:33 +08:00
parent 90910c4a31
commit 756367d0cc
11 changed files with 0 additions and 2821 deletions

View File

@ -1,126 +0,0 @@
import torch
from funasr.models.sense_voice.quantizer.quantization.vq import ResidualVectorQuantizer
import typing as tp
class CostumeQuantizer(torch.nn.Module):
def __init__(
self,
input_size: int = 512,
codebook_size: int = 1024,
num_quantizers: int = 8,
ema_decay: float = 0.95,
kmeans_init: tp.Union[bool, str] = False,
sampling_rate: int = 16_000,
quantize_dropout: bool = False,
rand_num_quant: tp.Optional[tp.List] = None,
encoder_hop_length: int = 320,
use_ddp: bool = True,
q0_ds_ratio: int = 1,
codec_dim: int = None,
codec_range: float = None,
threshold_ema_dead_code=2,
**kwargs,
):
super().__init__()
if codec_dim is None:
codec_dim = input_size
self.input_proj, self.output_proj = None, None
if codec_dim != input_size:
self.input_proj = torch.nn.Linear(input_size, codec_dim)
self.output_proj = torch.nn.Linear(codec_dim, input_size)
self.input_act, self.codec_range = None, None
if codec_range is not None:
self.input_act = torch.nn.Tanh()
self.codec_range = codec_range
self.rq = ResidualVectorQuantizer(
dimension=codec_dim,
n_q=num_quantizers,
bins=codebook_size,
decay=ema_decay,
kmeans_init=kmeans_init,
quantize_dropout=quantize_dropout,
rand_num_quant=rand_num_quant,
encoder_hop_length=encoder_hop_length,
use_ddp=use_ddp,
q0_ds_ratio=q0_ds_ratio,
threshold_ema_dead_code=threshold_ema_dead_code,
**kwargs,
)
self.code_dim = codec_dim
self.sampling_rate = sampling_rate
self.bandwidth: tp.Optional[float] = None
self.encoder_hop_length = encoder_hop_length
self.codebook_size = codebook_size
def forward(
self,
x,
bandwidth: int = None,
):
# x: input tensor in the shape of (B, T, C)
# rq requires inputs in (B, C, T)
if self.input_proj is not None:
x = self.input_proj(x)
if self.input_act is not None:
x = self.input_act(x) * self.codec_range
qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
x, indices, commit_loss, sub_quants = qv.quantized, qv.codes, qv.penalty, qv.sub_quants
x = x.permute(0, 2, 1)
if self.output_proj is not None:
x = self.output_proj(x)
return x, indices, commit_loss, sub_quants
def inference(
self,
x,
bandwidth: int = None,
):
# x: input tensor in the shape of (B, T, C)
# rq requires inputs in (B, C, T)
if self.input_proj is not None:
x = self.input_proj(x)
if self.input_act is not None:
x = self.input_act(x) * self.codec_range
qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
x, indices, sub_quants = qv.quantized, qv.codes, qv.sub_quants
x = x.permute(0, 2, 1)
if self.output_proj is not None:
x = self.output_proj(x)
return x, indices, sub_quants
def encode(
self,
x,
bandwidth: int = None,
):
# x: input tensor in the shape of (B, T, C)
# rq requires inputs in (B, C, T)
if self.input_proj is not None:
x = self.input_proj(x)
if self.input_act is not None:
x = self.input_act(x) * self.codec_range
indices = self.rq.encode(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
# return value in n_q x B x T
return indices
def decode(self, indices):
quantized_out = self.rq.decode(indices)
# quantized_out in B x D x T
if self.output_proj is not None:
quantized_out = self.output_proj(quantized_out.transpose(1, 2)).transpose(1, 2)
return quantized_out
def output_size(self):
return self.code_dim

View File

@ -1,299 +0,0 @@
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""
import random
from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from einops import rearrange, pack, unpack
# helper functions
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# tensor helpers
def round_ste(z: Tensor) -> Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
# main class
class FSQ(Module):
def __init__(
self,
levels: List[int],
input_size: Optional[int] = None,
num_codebooks=1,
keep_num_codebooks_dim: Optional[bool] = None,
scale: Optional[float] = None
):
super().__init__()
_levels = torch.tensor(levels, dtype=int32)
self.register_buffer("_levels", _levels, persistent=False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
self.register_buffer("_basis", _basis, persistent=False)
self.scale = scale
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = default(input_size, len(_levels) * num_codebooks)
has_projections = self.dim != effective_codebook_dim
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
def output_size(self):
return self.dim
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 - eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
return (z + shift).tanh() * half_l - offset
def quantize(self, z: Tensor) -> Tensor:
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width
def codes_to_indices(self, zhat: Tensor) -> Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_codes(
self,
indices: Tensor,
project_out=True
) -> Tensor:
"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
indices = rearrange(indices, '... -> ... 1')
codes_non_centered = (indices // self._basis) % self._levels
codes = self._scale_and_shift_inverse(codes_non_centered)
if self.keep_num_codebooks_dim:
codes = rearrange(codes, '... c d -> ... (c d)')
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
"""
is_img_or_video = z.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
z = rearrange(z, 'b d ... -> b ... d')
z, ps = pack_one(z, 'b * d')
assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
z = self.project_in(z)
z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
# reconstitute image or video dimensions
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device)
return out, indices, commit_loss, None
def inference(
self,
x,
bandwidth: int = None,
):
x, indices, _, _ = self.forward(x, bandwidth=bandwidth)
return x, indices, None
class BinaryFSQ(FSQ):
def __init__(
self,
levels: List[int],
input_size: Optional[int] = None,
num_codebooks=1,
keep_num_codebooks_dim: Optional[bool] = None,
scale: Optional[float] = None,
rand_num_codebooks: Optional[List] = None,
):
_levels = torch.tensor(levels, dtype=int32)
assert torch.all(_levels == 2), "BinaryFSQ requires the levels must be 2"
super().__init__(
levels, input_size, num_codebooks,
keep_num_codebooks_dim, scale
)
self.rand_num_codebooks = rand_num_codebooks
def output_size(self):
return self.dim
def bound(self, z: Tensor, eps=1e-3) -> Tensor:
"""Bound `z`, an array of shape (..., d)."""
return torch.sigmoid(z)
def quantize(self, z: Tensor) -> Tensor:
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
return quantized
def codes_to_indices(self, zhat: Tensor) -> Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_codes(
self,
indices: Tensor,
project_out=True
) -> Tensor:
"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
indices = rearrange(indices, '... -> ... 1')
codes = (indices // self._basis) % self._levels
if self.keep_num_codebooks_dim:
codes = rearrange(codes, '... c d -> ... (c d)')
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
"""
is_img_or_video = z.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
z = rearrange(z, 'b d ... -> b ... d')
z, ps = pack_one(z, 'b * d')
assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
z = self.project_in(z)
z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)
codes = self.quantize(z)
if self.rand_num_codebooks is not None:
quant_idx = random.choice(self.rand_num_codebooks)
codes[:, :, quant_idx:, :] = 0
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
# reconstitute image or video dimensions
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device)
return out, indices, commit_loss, None

View File

@ -1,493 +0,0 @@
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""
import random
from math import log2, ceil
from collections import namedtuple
from typing import Optional, List
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from einops import rearrange, reduce, pack, unpack
# constants
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
# helper functions
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg() if callable(arg) else arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# entropy
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)
# class
class LFQ(Module):
def __init__(
self,
*,
input_size=None,
dim=None,
codebook_size=None,
entropy_loss_weight=0.1,
commitment_loss_weight=0.25,
diversity_gamma=1.,
straight_through_activation="identity",
num_codebooks=1,
keep_num_codebooks_dim=None,
codebook_scale=1., # for residual LFQ, codebook scaled down by 2x at each layer
rand_num_codebooks: Optional[List] = None,
sampling_rate=16000,
encoder_hop_length=640,
):
super().__init__()
# some assert validations
dim = input_size
assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
assert not exists(codebook_size) or log2(
codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
codebook_size = default(codebook_size, lambda: 2 ** dim)
codebook_dim = int(log2(codebook_size))
codebook_dims = codebook_dim * num_codebooks
dim = default(dim, codebook_dims)
has_projections = dim != codebook_dims
self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.dim = dim
self.codebook_dim = codebook_dim
self.num_codebooks = num_codebooks
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
# straight through activation
if straight_through_activation == "identity":
self.activation = nn.Identity()
elif straight_through_activation == "tanh":
self.activation = nn.Tanh()
else:
raise NotImplementedError("Unsupported activation type, only 'tanh' and 'identity' are supported")
# entropy aux loss related weights
self.diversity_gamma = diversity_gamma
self.entropy_loss_weight = entropy_loss_weight
# codebook scale
self.codebook_scale = codebook_scale
# commitment loss
self.commitment_loss_weight = commitment_loss_weight
# for no auxiliary loss, during inference
self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
self.register_buffer('zero', torch.tensor(0.), persistent=False)
# codes
all_codes = torch.arange(codebook_size)
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
codebook = self.bits_to_codes(bits)
self.register_buffer('codebook', codebook, persistent=False)
self.rand_num_codebooks = rand_num_codebooks
self.sampling_rate = sampling_rate
self.encoder_hop_length = encoder_hop_length
def output_size(self):
return self.dim
def bits_to_codes(self, bits):
return bits * self.codebook_scale * 2 - self.codebook_scale
@property
def dtype(self):
return self.codebook.dtype
def indices_to_codes(
self,
indices,
project_out=True
):
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... -> ... 1')
# indices to codes, which are bits of either -1 or 1
bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
codes = self.bits_to_codes(bits)
codes = rearrange(codes, '... c d -> ... (c d)')
# whether to project codes out to original dimensions
# if the input feature dimensions were not log2(codebook size)
if project_out:
codes = self.project_out(codes)
# rearrange codes back to original shape
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
def keep_first_nq_codes(self, x, nq=None):
if nq is None or nq >= self.num_codebooks:
return x
inv_p = 1.0 / (nq / self.num_codebooks)
x[:, :, nq:] = 0
x[:, :, :nq] = x[:, :, :nq] * inv_p
return x
def random_dropout_codes(self, inputs):
x = torch.clone(inputs)
rand_num = random.choice(self.rand_num_codebooks)
return self.keep_first_nq_codes(x, nq=rand_num)
def cal_num_quant(self, bite_width):
frame_rate = self.sampling_rate / self.encoder_hop_length
nq = bite_width / frame_rate / self.codebook_dim
return nq
def forward(
self,
x,
inv_temperature=100.,
return_loss_breakdown=False,
mask=None,
bite_width=None,
):
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
"""
is_img_or_video = x.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
x = rearrange(x, 'b d ... -> b ... d')
x, ps = pack_one(x, 'b * d')
assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
x = self.project_in(x)
x = self.activation(x)
# split out number of codebooks
x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)
# quantize by eq 3.
original_input = x
codebook_value = torch.ones_like(x) * self.codebook_scale
# do quantization
quantized = torch.where(x > 0, codebook_value, -codebook_value)
# use straight-through gradients (optionally with custom activation fn) if training
if self.training:
x = x + (quantized - x).detach()
else:
x = quantized
# calculate indices
indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
# entropy aux loss
if self.training:
# the same as euclidean distance up to a constant
distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
prob = (-distance * inv_temperature).softmax(dim=-1)
per_sample_entropy = entropy(prob).mean()
# account for mask
if exists(mask):
prob = prob[mask]
# distribution over all available tokens in the batch
avg_prob = reduce(prob, '... c d -> c d', 'mean')
codebook_entropy = entropy(avg_prob).mean()
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
else:
# if not training, just return dummy 0
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
# commit loss
if self.training:
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none')
if exists(mask):
commit_loss = commit_loss[mask]
commit_loss = commit_loss.mean()
else:
commit_loss = self.zero
# randomly dropout codebooks to fit varying bite width
if self.training and self.rand_num_codebooks is not None:
x = self.random_dropout_codes(x)
if bite_width is not None:
x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width))
# merge back codebook dim
x = rearrange(x, 'b n c d -> b n (c d)')
# project out to feature dimension if needed
x = self.project_out(x)
# reconstitute image or video dimensions
if is_img_or_video:
x = unpack_one(x, ps, 'b * d')
x = rearrange(x, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
# whether to remove single codebook dim
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
# complete aux loss
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
if not return_loss_breakdown:
return x, indices, aux_loss, None
return x, indices, aux_loss, dict(
per_sample_entropy=per_sample_entropy,
codebook_entropy=codebook_entropy,
commit_loss=commit_loss
)
def inference(
self,
x,
bandwidth: int = None,
):
x, indices, _, _ = self.forward(x, bite_width=bandwidth)
return x, indices, None
class ScalableLFQ(LFQ):
def __init__(self, *, input_size=None, dim=None, codebook_size=None, entropy_loss_weight=0.1,
commitment_loss_weight=0.25, diversity_gamma=1., straight_through_activation=nn.Identity(),
num_codebooks=1, keep_num_codebooks_dim=None, codebook_scale=1.,
rand_num_codebooks: Optional[List] = None, sampling_rate=16000, hop_length=640, **kwargs):
super().__init__(input_size=input_size, dim=dim, codebook_size=codebook_size,
entropy_loss_weight=entropy_loss_weight, commitment_loss_weight=commitment_loss_weight,
diversity_gamma=diversity_gamma, straight_through_activation=straight_through_activation,
num_codebooks=num_codebooks, keep_num_codebooks_dim=keep_num_codebooks_dim,
codebook_scale=codebook_scale, rand_num_codebooks=rand_num_codebooks,
sampling_rate=sampling_rate, hop_length=hop_length)
codebook_alpha_conf = kwargs.get("codebook_alpha_conf", None)
self.init_codebook_alpha(codebook_alpha_conf)
def init_codebook_alpha(self, codebook_alpha_conf: dict):
assert codebook_alpha_conf is not None, "codebook_alpha_conf cannot be None"
name = codebook_alpha_conf.get("name", "constant")
if name == "constant":
alphas = codebook_alpha_conf.get("alphas", [1.0] * self.num_codebooks)
assert len(alphas) == self.num_codebooks, \
f"the length of codebook alphas {len(alphas)} " \
f"must match num_codebooks {self.num_codebooks}."
alphas = np.array(alphas)
elif name == "exponential":
temp = codebook_alpha_conf.get("temp", 8.0)
alphas = np.exp(-np.arange(0, self.num_codebooks) / temp)
else:
raise TypeError(f"Unknown codebook alpha type {name}.")
codebook_alpha = torch.tensor(alphas/alphas.sum(), dtype=torch.float32).reshape(1, 1, -1, 1)
self.register_buffer("codebook_alpha", codebook_alpha)
def forward(
self,
x,
inv_temperature=100.,
return_loss_breakdown=False,
mask=None,
bite_width=None
):
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
"""
is_img_or_video = x.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
x = rearrange(x, 'b d ... -> b ... d')
x, ps = pack_one(x, 'b * d')
assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
x = self.project_in(x)
# split out number of codebooks
x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)
# quantize by eq 3.
original_input = x
codebook_value = torch.ones_like(x) * self.codebook_scale
quantized = torch.where(x > 0, codebook_value, -codebook_value)
# use straight-through gradients (optionally with custom activation fn) if training
if self.training:
x = self.activation(x)
x = x + (quantized - x).detach()
else:
x = quantized
# calculate indices
indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
# entropy aux loss
if self.training:
# the same as euclidean distance up to a constant
distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
prob = (-distance * inv_temperature).softmax(dim=-1)
per_sample_entropy = entropy(prob).mean()
# account for mask
if exists(mask):
prob = prob[mask]
# distribution over all available tokens in the batch
avg_prob = reduce(prob, '... c d -> c d', 'mean')
codebook_entropy = entropy(avg_prob).mean()
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
else:
# if not training, just return dummy 0
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
# commit loss
if self.training:
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none')
if exists(mask):
commit_loss = commit_loss[mask]
commit_loss = commit_loss.mean()
else:
commit_loss = self.zero
# randomly dropout codebooks to fit s bite width
if self.training and self.rand_num_codebooks is not None:
x = self.random_dropout_codes(x)
if bite_width is not None:
x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width))
x = x * self.codebook_alpha
# merge back codebook dim
x = rearrange(x, 'b n c d -> b n (c d)')
# project out to feature dimension if needed
x = self.project_out(x)
# reconstitute image or video dimensions
if is_img_or_video:
x = unpack_one(x, ps, 'b * d')
x = rearrange(x, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
# whether to remove single codebook dim
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
# complete aux loss
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
if not return_loss_breakdown:
return x, indices, aux_loss, None
return x, indices, aux_loss, dict(
per_sample_entropy=per_sample_entropy,
codebook_entropy=codebook_entropy,
commit_loss=commit_loss
)

View File

@ -1,8 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from funasr.modules.quantization.vq import QuantizedResult, ResidualVectorQuantizer

View File

@ -1,291 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
import typing as tp
import torch
from funasr.models.sense_voice.quantizer.quantization.binary import BitPacker, BitUnpacker
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
roundoff: float = 1e-8, min_range: int = 2,
check: bool = True) -> torch.Tensor:
"""Turn the given PDF into a quantized CDF that splits
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
to the PDF.
Args:
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
during the coding process is `[0, 2 ** total_range_bits - 1]`.
roundoff (float): will round the pdf up to that level to remove difference coming
from e.g. evaluating the Language Model on different architectures.
min_range (int): minimum range width. Should always be at least 2 for numerical
stability. Use this to avoid pathological behavior is a value
that is expected to be rare actually happens in real life.
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
pdf = pdf.detach()
if roundoff:
pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2 ** total_range_bits
cardinality = len(pdf)
alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range"
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
ranges += min_range
quantized_cdf = torch.cumsum(ranges, dim=-1)
if min_range < 2:
raise ValueError("min_range must be at least 2.")
if check:
assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
sequence `(s_t)` by doing the following:
1) Initialize the current range to` [0 ** 2 B - 1]`.
2) For each time step t, split the current range into contiguous chunks,
one for each possible outcome, with size roughly proportional to `p`.
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
would be `{[0, 2], [3, 3]}`.
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
4) When done encoding all the values, just select any value remaining in the range.
You will notice that this procedure can fail: for instance if at any point in time
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
possible outcome. Intuitively, the more likely a value is, the less the range width
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
coding scheme, likely outcomes would take fewer bits, and more of them can be coded
with a fixed budget.
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
when the current range decreases below a given limit (given by `total_range_bits`), without
having to redo all the computations. If we encode mostly likely values, we will seldom
need to inject new bits, but a single rare value can deplete our stock of entropy!
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
code works for any sequence `(p_t)` possibly different for each timestep.
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
the KL between the true distribution and `p_t`, the most efficient the coding will be.
Args:
fo (IO[bytes]): file-like object to which the bytes will be written to.
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
Any time the current range width fall under this limit, new bits will
be injected to rescale the initial range.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
assert total_range_bits <= 30
self.total_range_bits = total_range_bits
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
self.low: int = 0
self.high: int = 0
self.max_bit: int = -1
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
@property
def delta(self) -> int:
"""Return the current range width."""
return self.high - self.low + 1
def _flush_common_prefix(self):
# If self.low and self.high start with the sames bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
assert self.high < 2 ** (self.max_bit + 1)
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= (b1 << self.max_bit)
self.high -= (b1 << self.max_bit)
assert self.high >= self.low, (self.high, self.low, self.max_bit)
assert self.low >= 0
self.max_bit -= 1
self.packer.push(b1)
else:
break
def push(self, symbol: int, quantized_cdf: torch.Tensor):
"""Push the given symbol on the stream, flushing out bits
if possible.
Args:
symbol (int): symbol to encode with the AC.
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
while self.delta < 2 ** self.total_range_bits:
self.low *= 2
self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
self._flush_common_prefix()
assert self.low <= self.high
assert self.max_bit >= -1
assert self.max_bit <= 61, self.max_bit
def flush(self):
"""Flush the remaining information to the stream.
"""
while self.max_bit >= 0:
b1 = (self.low >> self.max_bit) & 1
self.packer.push(b1)
self.max_bit -= 1
self.packer.flush()
class ArithmeticDecoder:
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
If the AC encoder current range is [L, H], with `L` and `H` having the same common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
`[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
self.total_range_bits = total_range_bits
self.low: int = 0
self.high: int = 0
self.current: int = 0
self.max_bit: int = -1
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
# Following is for debugging
self._dbg: tp.List[tp.Any] = []
self._dbg2: tp.List[tp.Any] = []
self._last: tp.Any = None
@property
def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
# Given the current range [L, H], if both have a common prefix,
# we know we can remove it from our representation to avoid handling large numbers.
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= (b1 << self.max_bit)
self.high -= (b1 << self.max_bit)
self.current -= (b1 << self.max_bit)
assert self.high >= self.low
assert self.low >= 0
self.max_bit -= 1
else:
break
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
This returns `None` when the stream has been exhausted.
Args:
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. This must be **exactly**
the same cdf as the one used at encoding time.
"""
while self.delta < 2 ** self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
self.low *= 2
self.high = self.high * 2 + 1
self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
# Binary search is not just for coding interviews :)
if high_idx < low_idx:
raise RuntimeError("Binary search failed")
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
if self.current <= high:
return mid, low, high, self.current
else:
return bin_search(mid + 1, high_idx)
else:
return bin_search(low_idx, mid - 1)
self._last = (self.low, self.high, self.current, self.max_bit)
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
return sym
def test():
torch.manual_seed(1234)
random.seed(1234)
for _ in range(4):
pdfs = []
cardinality = random.randrange(4000)
steps = random.randrange(100, 500)
fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
symbols = []
for step in range(steps):
pdf = torch.softmax(torch.randn(cardinality), dim=0)
pdfs.append(pdf)
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
symbol = torch.multinomial(pdf, 1).item()
symbols.append(symbol)
encoder.push(symbol, q_cdf)
encoder.flush()
fo.seek(0)
decoder = ArithmeticDecoder(fo)
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
decoded_symbol = decoder.pull(q_cdf)
assert decoded_symbol == symbol, idx
assert decoder.pull(torch.zeros(1)) is None
if __name__ == "__main__":
test()

View File

@ -1,157 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io
import json
import struct
import typing as tp
import ctypes
# format is `ECDC` magic code, followed by the header size as uint32.
# Then an uint8 indicates the protocol version (0.)
# The header is then provided as json and should contain all required
# informations for decoding. A raw stream of bytes is then provided
# and should be interpretable using the json header.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
meta_dumped = json.dumps(metadata).encode('utf-8')
version = 0
header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped))
fo.write(header)
fo.write(meta_dumped)
fo.flush()
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
buf = b""
while len(buf) < size:
new_buf = fo.read(size)
if not new_buf:
raise EOFError("Impossible to read enough data from the stream, "
f"{size} bytes remaining.")
buf += new_buf
size -= len(new_buf)
return buf
def read_ecdc_header(fo: tp.IO[bytes]):
header_bytes = _read_exactly(fo, _encodec_header_struct.size)
magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
if magic != _ENCODEC_MAGIC:
raise ValueError("File is not in ECDC format.")
if version != 0:
raise ValueError("Version not supported.")
meta_bytes = _read_exactly(fo, meta_size)
return json.loads(meta_bytes.decode('utf-8'))
class BitPacker:
"""Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
Note that for some bandwidth (1.5, 3), the codebook representation
will not cover an integer number of bytes.
Args:
bits (int): number of bits per value that will be pushed.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: tp.IO[bytes]):
self._current_value = 0
self._current_bits = 0
self.bits = bits
self.fo = fo
def push(self, value: int):
"""Push a new value to the stream. This will immediately
write as many uint8 as possible to the underlying file-object."""
value = bin(ctypes.c_uint32.from_buffer(ctypes.c_float(value)).value)
value = ctypes.c_int.from_buffer(ctypes.c_uint32(int(value, 2))).value
self._current_value += (value << self._current_bits)
self._current_bits += self.bits
while self._current_bits >= 8:
lower_8bits = self._current_value & 0xff
self._current_bits -= 8
self._current_value >>= 8
self.fo.write(bytes([lower_8bits]))
def flush(self):
"""Flushes the remaining partial uint8, call this at the end
of the stream to encode."""
if self._current_bits:
self.fo.write(bytes([self._current_value]))
self._current_value = 0
self._current_bits = 0
self.fo.flush()
class BitUnpacker:
"""BitUnpacker does the opposite of `BitPacker`.
Args:
bits (int): number of bits of the values to decode.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: tp.IO[bytes]):
self.bits = bits
self.fo = fo
self._mask = (1 << bits) - 1
self._current_value = 0
self._current_bits = 0
def pull(self) -> tp.Optional[int]:
"""
Pull a single value from the stream, potentially reading some
extra bytes from the underlying file-object.
Returns `None` when reaching the end of the stream.
"""
while self._current_bits < self.bits:
buf = self.fo.read(1)
if not buf:
return None
character = buf[0]
self._current_value += character << self._current_bits
self._current_bits += 8
out = self._current_value & self._mask
self._current_value >>= self.bits
self._current_bits -= self.bits
# out = ctypes.c_float.from_buffer(ctypes.c_uint32(int(out, 2))).value
return out
def test():
import torch
torch.manual_seed(1234)
for rep in range(4):
length: int = torch.randint(10, 2_000, (1,)).item()
bits: int = torch.randint(1, 16, (1,)).item()
tokens: tp.List[int] = torch.randint(2 ** bits, (length,)).tolist()
rebuilt: tp.List[int] = []
buf = io.BytesIO()
packer = BitPacker(bits, buf)
for token in tokens:
packer.push(token)
packer.flush()
buf.seek(0)
unpacker = BitUnpacker(bits, buf)
while True:
value = unpacker.pull()
if value is None:
break
rebuilt.append(value)
assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
# The flushing mechanism might lead to "ghost" values at the end of the stream.
assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits)
for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
assert a == b, (idx, a, b)
if __name__ == '__main__':
test()

View File

@ -1,674 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import logging
import typing as tp
from random import randrange
import numpy as np
from einops import rearrange, repeat
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F
import random
from funasr.models.sense_voice.quantizer.quantization import distrib
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
# device = samples.device
# samples = samples.cpu()
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
# diffs = rearrange(samples, "n d -> n () d") - rearrange(
# means, "c d -> () c d"
# )
# dists = -(diffs ** 2).sum(dim=-1)
dists = -(
samples.pow(2).sum(1, keepdim=True)
- 2 * torch.matmul(samples, means.t())
+ means.t().pow(2).sum(0, keepdim=True)
)
buckets = dists.max(dim=-1).indices
del dists
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
# means = means.to(device)
return means, bins
def preprocess(x):
x = rearrange(x, "... d -> (...) d")
return x
def postprocess_emb(embed_ind, shape):
return embed_ind.view(*shape[:-1])
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: int = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
self.training = True
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
distrib.broadcast_tensors(self.buffers())
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
# Note: after ema update, there is a very small difference between codebooks on GPUs.
# The impact can be very small, ignore it.
return quantize, embed_ind
class SimpleEuclideanCodebook(nn.Module):
"""Simple Codebook with Euclidean distance.
Using gradient to update code embeddings instead of EMA.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: tp.Union[bool, torch.Tensor] = False,
kmeans_iters: int = 10,
**kwargs
):
super().__init__()
if isinstance(kmeans_init, bool):
if kmeans_init:
embed = torch.zeros(codebook_size, dim)
inited = False
else:
embed = uniform_init(codebook_size, dim)
inited = True
else:
embed = kmeans_init
inited = True
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.embed = nn.Embedding(codebook_size, dim)
self.embed.weight.data.copy_(embed)
# self.register_parameter("embed", nn.Parameter(embed, requires_grad=True))
self.register_buffer("inited", torch.Tensor([inited]))
self.training = True
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors(self.buffers())
def quantize(self, x):
embed = self.embed.weight.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def dequantize(self, embed_ind):
quantize = self.embed(embed_ind)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_ind = postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently, supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
decay=decay, epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code)
self.codebook_size = codebook_size
self.training = True
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class SimpleVectorQuantization(nn.Module):
"""Vector quantization implementation with SimpleEuclideanCodebook.
Currently, supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
commitment_weight: float = 0.25,
codebook_weight: float = 1.0,
**kwargs
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self.codebook_weight = codebook_weight
logging.info(f"commitment_weight: {commitment_weight}, codebook_weight: {codebook_weight}.")
self._codebook = SimpleEuclideanCodebook(
dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters
)
self.codebook_size = codebook_size
self.training = True
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
# commit loss for codebook
if self.codebook_weight > 0:
codebook_loss = F.mse_loss(quantize, x.detach())
loss = loss + codebook_loss * self.codebook_weight
# commit loss for encoder
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *,
num_quantizers,
quantize_dropout: bool = False,
rand_num_quant: tp.Optional[tp.List] = None,
**kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
self.quantize_dropout = quantize_dropout
self.rand_num_quant = rand_num_quant
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = 0.0
residual = x
device = x.device
all_losses = []
all_indices = []
all_sub_quants = []
n_q = n_q or len(self.layers)
should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
if should_quantize_dropout:
rand_quantize_dropout_index = random.choice(self.rand_num_quant)
null_indices_shape = (x.shape[0], x.shape[2])
null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
for quantizer_index, layer in enumerate(self.layers[:n_q]):
# dropout except the first quantizer
if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
all_sub_quants.append(null_sub_quant)
continue
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
all_sub_quants.append(quantized)
out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
return quantized_out, out_indices, out_losses, out_sub_quants
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
class SimpleResidualVectorQuantization(nn.Module):
"""Simple Residual vector quantization with gradient to
update codebook instead of EMA
"""
def __init__(self, *,
num_quantizers,
quantize_dropout: bool = False,
rand_num_quant: tp.Optional[tp.List] = None,
**kwargs):
super().__init__()
kmeans_init = raw_kmeans_init = kwargs.pop('kmeans_init', True)
if isinstance(kmeans_init, str):
# use prepared kmeans init
embed = np.load(kmeans_init)
embed = torch.from_numpy(embed)
if embed.dim() == 2:
embed = embed.unsqueeze(0)
kmeans_init = embed
self.layers = nn.ModuleList([
SimpleVectorQuantization(
kmeans_init=kmeans_init[i] if isinstance(kmeans_init, torch.Tensor) else kmeans_init,
**kwargs
) for i in range(num_quantizers)
])
kwargs["kmeans_init"] = raw_kmeans_init
self.quantize_dropout = quantize_dropout
self.rand_num_quant = rand_num_quant
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = 0.0
residual = x
device = x.device
all_losses = []
all_indices = []
all_sub_quants = []
n_q = n_q or len(self.layers)
should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
if should_quantize_dropout:
rand_quantize_dropout_index = random.choice(self.rand_num_quant)
null_indices_shape = (x.shape[0], x.shape[2])
null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
for quantizer_index, layer in enumerate(self.layers[:n_q]):
# dropout except the first quantizer
if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
all_sub_quants.append(null_sub_quant)
continue
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
all_sub_quants.append(quantized)
out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
return quantized_out, out_indices, out_losses, out_sub_quants
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out

View File

@ -1,513 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import logging
import random
import typing as tp
from random import randrange
import numpy as np
from einops import rearrange, repeat
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F
from funasr.utils.hinter import hint_once
from funasr.models.sense_voice.quantizer.quantization import distrib
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float, mask=None):
if mask is None:
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
else:
mask = mask.float()
new_avg = moving_avg * decay + new * (1 - decay)
new_avg = mask * new_avg + (1 - mask) * moving_avg
moving_avg.data.copy_(new_avg.data)
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
# device = samples.device
# samples = samples.cpu()
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
# diffs = rearrange(samples, "n d -> n () d") - rearrange(
# means, "c d -> () c d"
# )
# dists = -(diffs ** 2).sum(dim=-1)
dists = -(
samples.pow(2).sum(1, keepdim=True)
- 2 * torch.matmul(samples, means.t())
+ means.t().pow(2).sum(0, keepdim=True)
)
buckets = dists.max(dim=-1).indices
del dists
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
# means = means.to(device)
return means, bins
def preprocess(x):
x = rearrange(x, "... d -> (...) d")
return x
def postprocess_emb(embed_ind, shape):
return embed_ind.view(*shape[:-1])
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: int = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
sparse_update: bool = False,
normalized_input: bool = False,
**kwargs,
):
super().__init__()
self.decay = decay
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.inited = None
self.cluster_size = None
self.embed = None
self.embed_avg = None
self.training = True
self.sparse_update = sparse_update
self.normalized_input = normalized_input
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
hint_once(f"threshold_ema_dead_code: {self.threshold_ema_dead_code}.", "threshold_ema_dead_code", rank=0)
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
# sync buffers outside for efficiency
# distrib.broadcast_tensors(self.buffers())
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x, buffers):
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
shape = x.shape
# pre-process
x = preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind, buffers):
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x, buffers):
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
shape, dtype = x.shape, x.dtype
x = preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
if not self.sparse_update:
mask = None
else:
mask = embed_onehot.sum(0) > 0
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay, mask=mask)
embed_sum = x.t() @ embed_onehot
# if self.normalized_input:
# embed_sum = F.normalize(embed_sum, dim=0)
ema_inplace(self.embed_avg, embed_sum.t(), self.decay,
mask=mask.unsqueeze(-1) if self.sparse_update else None)
if not self.sparse_update:
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
else:
embed_normalized = self.embed_avg
if self.normalized_input:
embed_normalized = F.normalize(embed_normalized, dim=-1)
self.embed.data.copy_(embed_normalized)
# Note: after ema update, there is a very small difference between codebooks on GPUs.
# The impact can be very small, ignore it.
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently, supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.,
**kwargs,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
decay=decay, epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
**kwargs)
self.codebook_size = codebook_size
self.training = True
@property
def codebook(self):
return self._codebook.embed
def encode(self, x, buffers):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x, buffers)
return embed_in
def decode(self, embed_ind, buffers):
quantize = self._codebook.decode(embed_ind, buffers)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x, buffers):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x, buffers)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class DistributedResidualVectorQuantization(nn.Module):
"""Efficient distributed residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *,
num_quantizers,
quantize_dropout: bool = False,
rand_num_quant: tp.Optional[tp.List] = None,
**kwargs):
super().__init__()
"""
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
"""
codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["dim"]
kmeans_init = kwargs["kmeans_init"]
if isinstance(kmeans_init, bool):
if not kwargs["kmeans_init"]:
# use uniform init
embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
inited = True
cluster_size = 1
else:
# to perform kmeans init on first batch
embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
inited = False
cluster_size = 0
elif isinstance(kmeans_init, str):
# use prepared kmeans init
embed = np.load(kmeans_init)
embed = torch.from_numpy(embed)
if kwargs.get("normalized_input", False):
logging.info("normalize the code embeddings since the input is normalized.")
embed = F.normalize(embed, dim=-1)
if embed.dim() == 2:
embed = embed.repeat(num_quantizers, 1, 1)
inited = True
cluster_size = 1
else:
raise TypeError("kmeans_init should be either a bool or string path to init weights.")
self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
self.register_buffer("cluster_size", torch.ones(num_quantizers, codebook_size) * cluster_size)
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
self.q0_ds_ratio = 1
if "q0_ds_ratio" in kwargs:
self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")
self.layers = nn.ModuleList()
for i in range(num_quantizers):
vq_args = dict(**kwargs)
vq = VectorQuantization(**vq_args)
self.layers.append(vq)
self.quantize_dropout = quantize_dropout
self.rand_num_quant = rand_num_quant
def forward(self, x, n_q: tp.Optional[int] = None):
quantized_out = torch.zeros_like(x)
residual = x
bb, cc, tt = x.shape
device = x.device
all_losses = []
all_indices = []
all_sub_quants = []
n_q = n_q or len(self.layers)
should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
if should_quantize_dropout:
rand_quantize_dropout_index = random.choice(self.rand_num_quant)
null_indices_shape = (x.shape[0], x.shape[2])
null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
for quantizer_index, layer in enumerate(self.layers[:n_q]):
# dropout except the first quantizer
if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
all_sub_quants.append(null_sub_quant)
continue
quant_in = residual
if self.q0_ds_ratio > 1 and quantizer_index == 0:
quant_in = F.interpolate(quant_in, size=[tt//2])
quantized, indices, loss = layer(quant_in, [
self.inited[quantizer_index],
self.cluster_size[quantizer_index],
self.embed[quantizer_index],
self.embed_avg[quantizer_index]
])
if self.q0_ds_ratio > 1 and quantizer_index == 0:
quantized = F.interpolate(quantized, size=[tt])
indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
all_sub_quants.append(quantized)
# sync buffers after one forward step
distrib.broadcast_tensors(self.buffers())
out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
return quantized_out, out_indices, out_losses, out_sub_quants
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
for i, layer in enumerate(self.layers[:n_q]):
indices = layer.encode(residual, [
self.inited[i],
self.cluster_size[i],
self.embed[i],
self.embed_avg[i]
])
quantized = layer.decode(indices, [
self.inited[i],
self.cluster_size[i],
self.embed[i],
self.embed_avg[i]
])
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices, [
self.inited[i],
self.cluster_size[i],
self.embed[i],
self.embed_avg[i]
])
quantized_out = quantized_out + quantized
return quantized_out

View File

@ -1,126 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp
import time, logging
import torch
def rank():
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
def world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1
def is_distributed():
return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
if is_distributed():
return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: tp.List[torch.Tensor]):
# utility function to check that the number of params in all workers is the same,
# and thus avoid a deadlock with distributed all reduce.
if not is_distributed() or not params:
return
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
all_reduce(tensor)
if tensor.item() != len(params) * world_size():
# If not all the workers have the same number, for at least one of them,
# this inequality will be verified.
raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
"at least one worker has a different one.")
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
"""Broadcast the tensors from the given parameters to all workers.
This can be used to ensure that all workers have the same model to start with.
"""
# start = time.time()
if not is_distributed():
return
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
_check_number_of_params(tensors)
handles = []
for tensor in tensors:
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
# logging.info(f"Time of broadcast {(time.time()-start):.4f}")
def sync_buffer(buffers, average=True):
"""
Sync grad for buffers. If average is False, broadcast instead of averaging.
"""
if not is_distributed():
return
handles = []
for buffer in buffers:
if torch.is_floating_point(buffer.data):
if average:
handle = torch.distributed.all_reduce(
buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
else:
handle = torch.distributed.broadcast(
buffer.data, src=0, async_op=True)
handles.append((buffer, handle))
for buffer, handle in handles:
handle.wait()
if average:
buffer.data /= world_size
def sync_grad(params):
"""
Simpler alternative to DistributedDataParallel, that doesn't rely
on any black magic. For simple models it can also be as fast.
Just call this on your model parameters after the call to backward!
"""
if not is_distributed():
return
handles = []
for p in params:
if p.grad is not None:
handle = torch.distributed.all_reduce(
p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
handles.append((p, handle))
for p, handle in handles:
handle.wait()
p.grad.data /= world_size()
def average_metrics(metrics: tp.Dict[str, float], count=1.):
"""Average a dictionary of metrics across all workers, using the optional
`count` as unormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))

View File

@ -1,134 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
import logging
from dataclasses import dataclass, field
import math
import typing as tp
import torch
from torch import nn
from funasr.models.sense_voice.quantizer.quantization import distrib
from funasr.models.sense_voice.quantizer.quantization.core_vq import SimpleResidualVectorQuantization
from funasr.models.sense_voice.quantizer.quantization.ddp_core_vq import DistributedResidualVectorQuantization
@dataclass
class QuantizedResult:
quantized: torch.Tensor
codes: torch.Tensor
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
penalty: tp.Optional[torch.Tensor] = None
metrics: dict = field(default_factory=dict)
sub_quants: torch.Tensor = None
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
dimension (int): Dimension of the codebooks.
n_q (int): Number of residual vector quantizers used.
bins (int): Codebook size.
decay (float): Decay for exponential moving average over the codebooks.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: tp.Union[bool, str] = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
quantize_dropout: bool = False,
rand_num_quant: tp.Optional[tp.List] = None,
encoder_hop_length: int = 320,
use_ddp: bool = True,
q0_ds_ratio: int = 1,
**kwargs,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
self.encoder_hop_length = encoder_hop_length
self.training = True
if use_ddp:
rvq_class = DistributedResidualVectorQuantization
logging.info("Using distributed residual vector quantization.")
else:
rvq_class = SimpleResidualVectorQuantization
logging.warning("Using simple residual vector quantization")
self.model = rvq_class(
dim=self.dimension,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
quantize_dropout=quantize_dropout,
rand_num_quant=rand_num_quant,
q0_ds_ratio=q0_ds_ratio,
**kwargs
)
def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor in the shape of (B, C, T).
sample_rate (int): Sample rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
quantized, codes, commit_loss, sub_quants = self.model(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(quantized, codes, bw,
penalty=torch.mean(commit_loss),
sub_quants=sub_quants)
def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
"""Return n_q based on specified target bandwidth.
"""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.n_q
if bandwidth and bandwidth > 0.:
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
return n_q
def get_bandwidth_per_quantizer(self, sample_rate: int):
"""Return bandwidth per quantizer for a given input sample rate.
"""
return math.log2(self.bins) * sample_rate / self.encoder_hop_length
def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
"""
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
codes = self.model.encode(x, n_q=n_q)
return codes
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""Decode the given codes to the quantized representation.
"""
quantized = self.model.decode(codes)
return quantized