mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
insert VQ into sensevoice encoder
This commit is contained in:
parent
6fdba0822e
commit
e9557a0ee7
@ -14,6 +14,7 @@ from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.utils.hinter import hint_once
|
||||
|
||||
|
||||
class SinusoidalPositionEncoder(torch.nn.Module):
|
||||
@ -1604,6 +1605,130 @@ class SenseVoiceEncoder(nn.Module):
|
||||
return x, olens
|
||||
|
||||
|
||||
@tables.register("encoder_classes", "SenseVoiceQuantizedEncoder")
class SenseVoiceQuantizedEncoder(SenseVoiceEncoder):
    """SenseVoice encoder with a vector-quantization (VQ) bottleneck.

    Behaves like the parent encoder, except that after transformer block
    ``quantize_layer_idx`` the hidden states are passed through a quantizer
    (residual VQ, lookup-free, or finite-scalar, selected by
    ``quantizer_config["name"]``) before the remaining blocks run.
    """

    def __init__(
        self,
        input_size,
        linear_units: int,
        attention_heads: int,
        num_blocks: int,
        quantize_layer_idx: int,
        normalized_quant_input: bool,
        quantizer_config: dict,
        **kwargs,
    ):
        """
        Args:
            input_size: dimension of the input features.
            linear_units: encoder hidden size; also used as the quantizer input size.
            attention_heads: number of attention heads per block.
            num_blocks: number of transformer blocks.
            quantize_layer_idx: index of the block after whose output quantization
                is applied.
            normalized_quant_input: if True, L2-normalize features before quantizing.
            quantizer_config: config dict consumed by :meth:`build_quantizer`;
                may be ``None`` to disable quantization.
        """
        super().__init__(input_size, linear_units, attention_heads, num_blocks, **kwargs)
        self.linear_units = linear_units
        self.quantize_layer_idx = quantize_layer_idx
        self.normalized_quant_input = normalized_quant_input
        self.quantizer = self.build_quantizer(quantizer_config)

    def build_quantizer(self, vq_config):
        """Instantiate the quantizer described by ``vq_config``.

        ``vq_config["name"]`` selects the implementation; all remaining keys are
        forwarded to its constructor.  The popped ``name`` key is always restored
        so the caller's config dict is left unchanged.  Returns ``None`` when no
        config is given.

        Raises:
            NotImplementedError: if ``name`` is not a known quantizer type.
        """
        if vq_config is None:
            return None
        name = vq_config.pop("name", "costume_quantizer")
        try:
            if name == "costume_quantizer":
                from funasr.models.sense_voice.quantizer.costume_quantizer import CostumeQuantizer
                quantizer_cls = CostumeQuantizer
            elif name == "lookup_free_quantizer":
                from funasr.models.sense_voice.quantizer.lookup_free_quantizer import LFQ
                quantizer_cls = LFQ
            elif name == "finite_scalar_quantizer":
                from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import FSQ
                quantizer_cls = FSQ
            else:
                # BUG FIX: was ``raise NotImplemented(...)`` -- NotImplemented is a
                # sentinel value, not an exception class, so raising it would itself
                # fail with a TypeError instead of reporting the bad config.
                raise NotImplementedError("quantizer {} not implemented".format(name))
            return quantizer_cls(
                input_size=self.linear_units,
                **vq_config,
            )
        finally:
            # restore the popped key (on the error path too) so vq_config is intact
            vq_config["name"] = name

    def quantize_enc_outs(self, x):
        """Quantize encoder hidden states.

        Returns the quantized tensor and a dict holding the pre-/post-quantization
        features (``quant_in`` / ``quant_out``), code ``indices`` and the
        quantizer's commitment loss (``quant_loss``).
        """
        ret_dict = {}

        if self.normalized_quant_input:
            # L2-normalize features before quantization when configured
            x = F.normalize(x, dim=-1)
        ret_dict["quant_in"] = x
        x, indices, commit_loss, sub_quants = self.quantizer(x)
        ret_dict["quant_out"] = x
        ret_dict["indices"] = indices
        ret_dict["quant_loss"] = commit_loss

        return x, ret_dict

    def forward(
        self,
        x: torch.Tensor,
        ilens: torch.Tensor = None,
        **kwargs,
    ):
        """Encode ``x``, applying the quantizer after ``quantize_layer_idx``.

        If ``kwargs["only_extract_tokens"]`` is truthy, returns
        ``((quantized, ret_dict), olens)`` immediately after quantization;
        otherwise mirrors the parent encoder: returns ``x`` when ``ilens`` is
        None, else ``(x, olens)``.
        """
        use_padmask = self.use_padmask
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)
        only_extract_tokens = kwargs.get("only_extract_tokens", False)

        n_frames = x.size(1)
        max_pos = n_frames

        if ilens is not None:
            # propagate the sequence lengths through the (possibly strided) convs
            if self.downsample_rate == 4:
                olens = (
                    1
                    + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
                    // self.conv1.stride[0]
                )
            else:
                olens = ilens
            olens = (
                1
                + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
                // self.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if use_padmask and olens is not None:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device)
        else:
            padding_mask = None

        device = x.device
        seq_length = x.shape[1]
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

        for layer, block in enumerate(self.blocks):
            x = block(x, mask=padding_mask, position_ids=position_ids)
            if self.quantize_layer_idx is not None and self.quantizer is not None:
                if layer == self.quantize_layer_idx:
                    # typo fix in the one-time hint message: "wit" -> "with"
                    hint_once(f"Quantization at layer {layer} with {self.quantizer}",
                              "normalize_quant_enc_out", rank=0)
                    x, ret_dict = self.quantize_enc_outs(x)
                    if only_extract_tokens:
                        return (x, ret_dict), olens

        x = self.ln_post(x)

        if ilens is None:
            return x
        else:
            return x, olens
|
||||
|
||||
|
||||
import types
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
0
funasr/models/sense_voice/quantizer/__init__.py
Normal file
0
funasr/models/sense_voice/quantizer/__init__.py
Normal file
126
funasr/models/sense_voice/quantizer/costume_quantizer.py
Normal file
126
funasr/models/sense_voice/quantizer/costume_quantizer.py
Normal file
@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from funasr.models.sense_voice.quantizer.quantization.vq import ResidualVectorQuantizer
|
||||
import typing as tp
|
||||
|
||||
|
||||
class CostumeQuantizer(torch.nn.Module):
    """Residual-VQ wrapper with optional dimension projection and a bounded
    (tanh) input activation.

    Wraps :class:`ResidualVectorQuantizer`: encoder features of shape
    (B, T, C) are optionally projected into the codec dimension, quantized,
    and projected back on the way out.
    """

    def __init__(
        self,
        input_size: int = 512,
        codebook_size: int = 1024,
        num_quantizers: int = 8,
        ema_decay: float = 0.95,
        kmeans_init: tp.Union[bool, str] = False,
        sampling_rate: int = 16_000,
        quantize_dropout: bool = False,
        rand_num_quant: tp.Optional[tp.List] = None,
        encoder_hop_length: int = 320,
        use_ddp: bool = True,
        q0_ds_ratio: int = 1,
        codec_dim: int = None,
        codec_range: float = None,
        threshold_ema_dead_code=2,
        **kwargs,
    ):
        super().__init__()
        # quantize in the encoder dimension unless a codec dimension is given
        codec_dim = input_size if codec_dim is None else codec_dim

        needs_proj = codec_dim != input_size
        self.input_proj = torch.nn.Linear(input_size, codec_dim) if needs_proj else None
        self.output_proj = torch.nn.Linear(codec_dim, input_size) if needs_proj else None

        # optionally squash inputs into [-codec_range, codec_range]
        if codec_range is None:
            self.input_act, self.codec_range = None, None
        else:
            self.input_act = torch.nn.Tanh()
            self.codec_range = codec_range

        self.rq = ResidualVectorQuantizer(
            dimension=codec_dim,
            n_q=num_quantizers,
            bins=codebook_size,
            decay=ema_decay,
            kmeans_init=kmeans_init,
            quantize_dropout=quantize_dropout,
            rand_num_quant=rand_num_quant,
            encoder_hop_length=encoder_hop_length,
            use_ddp=use_ddp,
            q0_ds_ratio=q0_ds_ratio,
            threshold_ema_dead_code=threshold_ema_dead_code,
            **kwargs,
        )
        self.code_dim = codec_dim
        self.sampling_rate = sampling_rate
        self.bandwidth: tp.Optional[float] = None
        self.encoder_hop_length = encoder_hop_length
        self.codebook_size = codebook_size

    def _transform_input(self, x):
        """Apply the optional projection and bounded activation to (B, T, C) input."""
        if self.input_proj is not None:
            x = self.input_proj(x)
        if self.input_act is not None:
            x = self.input_act(x) * self.codec_range
        return x

    def _transform_output(self, x):
        """Project (B, T, C) quantized features back to the encoder dimension."""
        if self.output_proj is not None:
            x = self.output_proj(x)
        return x

    def forward(
        self,
        x,
        bandwidth: int = None,
    ):
        """Quantize ``x`` (B, T, C).

        Returns ``(quantized, codes, commit_loss, sub_quants)``.
        """
        x = self._transform_input(x)
        # the residual quantizer consumes (B, C, T)
        qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
        quantized = self._transform_output(qv.quantized.permute(0, 2, 1))
        return quantized, qv.codes, qv.penalty, qv.sub_quants

    def inference(
        self,
        x,
        bandwidth: int = None,
    ):
        """Quantize ``x`` (B, T, C); returns ``(quantized, codes, sub_quants)``."""
        x = self._transform_input(x)
        # the residual quantizer consumes (B, C, T)
        qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
        quantized = self._transform_output(qv.quantized.permute(0, 2, 1))
        return quantized, qv.codes, qv.sub_quants

    def encode(
        self,
        x,
        bandwidth: int = None,
    ):
        """Encode ``x`` (B, T, C) into code indices (n_q x B x T)."""
        x = self._transform_input(x)
        # return value in n_q x B x T
        return self.rq.encode(x.permute(0, 2, 1), self.sampling_rate, bandwidth)

    def decode(self, indices):
        """Decode code indices back to features; result is (B, D, T)."""
        quantized_out = self.rq.decode(indices)
        if self.output_proj is not None:
            quantized_out = self.output_proj(quantized_out.transpose(1, 2)).transpose(1, 2)
        return quantized_out

    def output_size(self):
        """Dimensionality of the quantized representation."""
        return self.code_dim
|
||||
299
funasr/models/sense_voice/quantizer/finite_scalar_quantizer.py
Normal file
299
funasr/models/sense_voice/quantizer/finite_scalar_quantizer.py
Normal file
@ -0,0 +1,299 @@
|
||||
"""
|
||||
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
||||
Code adapted from Jax version in Appendix A.1
|
||||
"""
|
||||
import random
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Module
|
||||
from torch import Tensor, int32
|
||||
|
||||
from einops import rearrange, pack, unpack
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(v):
    """Return True when *v* holds a value, i.e. is not None."""
    if v is None:
        return False
    return True
|
||||
|
||||
|
||||
def default(*args):
    """Return the first argument that is not None; None when all are None."""
    return next((candidate for candidate in args if candidate is not None), None)
|
||||
|
||||
|
||||
def pack_one(t, pattern):
    """Pack a single tensor with einops.pack, returning (packed, shapes)."""
    packed, shapes = pack([t], pattern)
    return packed, shapes
|
||||
|
||||
|
||||
def unpack_one(t, ps, pattern):
    """Undo :func:`pack_one`, returning the single unpacked tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
||||
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def round_ste(z: Tensor) -> Tensor:
    """Round with straight-through gradients.

    Forward pass: nearest-integer rounding.  Backward pass: identity, because
    the rounding residual is detached from the autograd graph.
    """
    residual = (z.round() - z).detach()
    return z + residual
|
||||
|
||||
|
||||
# main class
|
||||
|
||||
class FSQ(Module):
    """Finite Scalar Quantization (https://arxiv.org/abs/2309.15505).

    Each of ``len(levels)`` feature dimensions is bounded and rounded to one of
    ``levels[i]`` integer levels; the joint code index is the mixed-radix
    combination of the per-dimension levels (basis = cumprod of levels).
    """

    def __init__(
        self,
        levels: List[int],
        input_size: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None
    ):
        super().__init__()
        # per-dimension number of quantization levels
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)

        # mixed-radix basis used to flatten per-dimension codes into one index
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        # keep the codebook axis in returned indices iff there are multiple codebooks
        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(input_size, len(_levels) * num_codebooks)

        # project in/out only when the feature dim differs from the codebook dims
        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        # materialize every possible code once (non-persistent buffer)
        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    def output_size(self):
        """Dimensionality of the (projected) output features."""
        return self.dim

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d)."""
        # tanh-squash each dim into its level range; the shift centers
        # even-level dims so rounding yields the intended integer set
        half_l = (self._levels - 1) * (1 - eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        # map normalized codes in [-1, 1] to integer levels [0, levels - 1]
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        # map integer levels [0, levels - 1] back to normalized codes in [-1, 1]
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        # mixed-radix flatten of the per-dimension integer levels
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out=True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        # mixed-radix digit extraction, then re-center into [-1, 1]
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Returns (quantized, indices, commit_loss, None); FSQ has no learned
        codebook, so commit_loss is always a zero tensor.  ``bandwidth`` is
        accepted for interface compatibility but unused here.
        """

        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)

        codes = self.quantize(z)
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions

        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # dummy loss keeps the return signature aligned with other quantizers
        commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device)
        return out, indices, commit_loss, None

    def inference(
        self,
        x,
        bandwidth: int = None,
    ):
        """Quantize ``x``; returns (quantized, indices, None)."""
        x, indices, _, _ = self.forward(x, bandwidth=bandwidth)

        return x, indices, None
|
||||
|
||||
|
||||
class BinaryFSQ(FSQ):
    """FSQ restricted to binary (2-level) dimensions.

    Quantizes via a sigmoid to {0, 1} (no re-centering), and optionally
    zeroes a random suffix of codebooks during the forward pass
    (structured codebook dropout) when ``rand_num_codebooks`` is given.
    """

    def __init__(
        self,
        levels: List[int],
        input_size: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        rand_num_codebooks: Optional[List] = None,
    ):
        # validate before delegating: every dimension must be binary
        _levels = torch.tensor(levels, dtype=int32)
        assert torch.all(_levels == 2), "BinaryFSQ requires the levels must be 2"
        super().__init__(
            levels, input_size, num_codebooks,
            keep_num_codebooks_dim, scale
        )
        # candidate counts of codebooks to randomly keep each forward pass
        self.rand_num_codebooks = rand_num_codebooks

    def output_size(self):
        """Dimensionality of the (projected) output features."""
        return self.dim

    def bound(self, z: Tensor, eps=1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d)."""
        # squash into (0, 1); rounding then yields the binary code
        return torch.sigmoid(z)

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        # no renormalization: codes stay in {0, 1}
        quantized = round_ste(self.bound(z))
        return quantized

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        # codes are already {0, 1}, so no scale-and-shift is needed
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out=True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        # binary digit extraction; codes stay in {0, 1} (no inverse shift)
        codes = (indices // self._basis) % self._levels

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Returns (quantized, indices, commit_loss, None); commit_loss is a zero
        tensor.  ``bandwidth`` is accepted for interface compatibility but unused.
        """

        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)

        codes = self.quantize(z)
        if self.rand_num_codebooks is not None:
            # structured dropout: zero a random suffix of codebooks
            # NOTE(review): this mutates the STE output in place and happens in
            # eval mode too (no self.training guard) -- confirm both are intended.
            quant_idx = random.choice(self.rand_num_codebooks)
            codes[:, :, quant_idx:, :] = 0
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions

        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # dummy loss keeps the return signature aligned with other quantizers
        commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device)
        return out, indices, commit_loss, None
|
||||
493
funasr/models/sense_voice/quantizer/lookup_free_quantizer.py
Normal file
493
funasr/models/sense_voice/quantizer/lookup_free_quantizer.py
Normal file
@ -0,0 +1,493 @@
|
||||
"""
|
||||
Lookup Free Quantization
|
||||
Proposed in https://arxiv.org/abs/2310.05737
|
||||
|
||||
In the simplest setup, each dimension is quantized into {-1, 1}.
|
||||
An entropy penalty is used to encourage utilization.
|
||||
"""
|
||||
import random
|
||||
from math import log2, ceil
|
||||
from collections import namedtuple
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module
|
||||
|
||||
from einops import rearrange, reduce, pack, unpack
|
||||
|
||||
# constants
|
||||
|
||||
# Result container for quantizer calls: the quantized tensor, the integer
# code indices, and the auxiliary entropy loss term.
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

# Itemized breakdown of the auxiliary losses, useful for logging/inspection.
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(v):
    """True when *v* is not None."""
    return not (v is None)
|
||||
|
||||
|
||||
def default(*args):
    """Return the first non-None argument, invoking it first when callable;
    None when every argument is None."""
    for candidate in args:
        if candidate is None:
            continue
        return candidate() if callable(candidate) else candidate
    return None
|
||||
|
||||
|
||||
def pack_one(t, pattern):
    """Wrap a single tensor in a list and einops-pack it; returns (packed, shapes)."""
    packed, shapes = pack([t], pattern)
    return packed, shapes
|
||||
|
||||
|
||||
def unpack_one(t, ps, pattern):
    """Inverse of :func:`pack_one`: unpack and return the single tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
||||
|
||||
|
||||
# entropy
|
||||
|
||||
def log(t, eps=1e-5):
    """Numerically safe natural log: inputs are clamped below at ``eps``."""
    return torch.log(t.clamp(min=eps))


def entropy(prob):
    """Shannon entropy of ``prob`` along the last dimension (natural log)."""
    return -(prob * log(prob)).sum(dim=-1)
|
||||
|
||||
|
||||
# class
|
||||
|
||||
class LFQ(Module):
    """Lookup-Free Quantization (https://arxiv.org/abs/2310.05737).

    Each codebook dimension is quantized by sign to
    {-codebook_scale, +codebook_scale}; an entropy auxiliary loss encourages
    codebook utilization and a commitment loss pulls inputs toward the
    quantized values.
    """

    def __init__(
        self,
        *,
        input_size=None,
        dim=None,
        codebook_size=None,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1.,
        straight_through_activation="identity",
        num_codebooks=1,
        keep_num_codebooks_dim=None,
        codebook_scale=1.,  # for residual LFQ, codebook scaled down by 2x at each layer
        rand_num_codebooks: Optional[List] = None,
        sampling_rate=16000,
        encoder_hop_length=640,
    ):
        super().__init__()

        # some assert validations
        # NOTE(review): ``dim`` is unconditionally overwritten by ``input_size``,
        # so the ``dim`` argument is effectively ignored -- confirm this is the
        # intended adaptation to the funasr ``input_size`` convention.
        dim = input_size
        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(
            codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        codebook_dim = int(log2(codebook_size))

        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        # project in/out only when the feature dim differs from total codebook dims
        has_projections = dim != codebook_dims
        self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        # keep the codebook axis in returned indices iff multiple codebooks
        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # straight through activation applied before sign quantization
        if straight_through_activation == "identity":
            self.activation = nn.Identity()
        elif straight_through_activation == "tanh":
            self.activation = nn.Tanh()
        else:
            raise NotImplementedError("Unsupported activation type, only 'tanh' and 'identity' are supported")

        # entropy aux loss related weights
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale
        self.codebook_scale = codebook_scale

        # commitment loss
        self.commitment_loss_weight = commitment_loss_weight

        # for no auxiliary loss, during inference
        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent=False)

        # enumerate all codes once so training-time distances can be computed
        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)

        self.register_buffer('codebook', codebook, persistent=False)
        self.rand_num_codebooks = rand_num_codebooks
        self.sampling_rate = sampling_rate
        self.encoder_hop_length = encoder_hop_length

    def output_size(self):
        """Dimensionality of the (projected) output features."""
        return self.dim

    def bits_to_codes(self, bits):
        """Map {0, 1} bits to {-codebook_scale, +codebook_scale} code values."""
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        # dtype of the enumerated codebook buffer
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out=True
    ):
        """Map integer code indices back to code vectors."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # indices to codes, which are bits of either -1 or 1
        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = rearrange(codes, '... c d -> ... (c d)')

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)
        if project_out:
            codes = self.project_out(codes)

        # rearrange codes back to original shape
        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def keep_first_nq_codes(self, x, nq=None):
        """Keep only the first ``nq`` codebooks of ``x`` (b, n, c, d), zeroing
        the rest and rescaling the kept ones to preserve expected magnitude."""
        if nq is None or nq >= self.num_codebooks:
            return x
        # BUG FIX: ``cal_num_quant`` returns a float, and a float cannot be
        # used as a slice index (``x[:, :, nq:]`` raises TypeError) -- floor it.
        # NOTE(review): values < 1 floor to 0 and would divide by zero below;
        # callers should ensure bite_width allows at least one codebook.
        nq = int(nq)
        inv_p = 1.0 / (nq / self.num_codebooks)
        x[:, :, nq:] = 0
        x[:, :, :nq] = x[:, :, :nq] * inv_p

        return x

    def random_dropout_codes(self, inputs):
        """Randomly keep a prefix of codebooks (structured codebook dropout)."""
        x = torch.clone(inputs)
        rand_num = random.choice(self.rand_num_codebooks)
        return self.keep_first_nq_codes(x, nq=rand_num)

    def cal_num_quant(self, bite_width):
        """Number of codebooks needed to reach ``bite_width`` (bits/s) at the
        model frame rate; may be fractional."""
        frame_rate = self.sampling_rate / self.encoder_hop_length
        nq = bite_width / frame_rate / self.codebook_dim
        return nq

    def forward(
        self,
        x,
        inv_temperature=100.,
        return_loss_breakdown=False,
        mask=None,
        bite_width=None,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Returns (quantized, indices, aux_loss, breakdown), where breakdown is
        None unless ``return_loss_breakdown`` is set.
        """

        is_img_or_video = x.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)
        x = self.activation(x)

        # split out number of codebooks

        x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)

        # quantize by eq 3.

        original_input = x

        codebook_value = torch.ones_like(x) * self.codebook_scale
        # do quantization: the sign of each dim selects +/- codebook_scale
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # use straight-through gradients (optionally with custom activation fn) if training

        if self.training:
            x = x + (quantized - x).detach()
        else:
            x = quantized

        # calculate indices: pack the sign bits into integers

        indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy aux loss

        if self.training:
            # the same as euclidean distance up to a constant
            distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)

            prob = (-distance * inv_temperature).softmax(dim=-1)

            per_sample_entropy = entropy(prob).mean()

            # account for mask

            if exists(mask):
                prob = prob[mask]

            # distribution over all available tokens in the batch

            avg_prob = reduce(prob, '... c d -> c d', 'mean')
            codebook_entropy = entropy(avg_prob).mean()

            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

        # commit loss

        if self.training:
            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # randomly dropout codebooks to fit varying bite width
        if self.training and self.rand_num_codebooks is not None:
            x = self.random_dropout_codes(x)
        if bite_width is not None:
            x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width))

        # merge back codebook dim

        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if is_img_or_video:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # complete aux loss

        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        if not return_loss_breakdown:
            return x, indices, aux_loss, None

        return x, indices, aux_loss, dict(
            per_sample_entropy=per_sample_entropy,
            codebook_entropy=codebook_entropy,
            commit_loss=commit_loss
        )

    def inference(
        self,
        x,
        bandwidth: int = None,
    ):
        """Quantize without auxiliary losses; returns (quantized, indices, None)."""
        x, indices, _, _ = self.forward(x, bite_width=bandwidth)

        return x, indices, None
|
||||
|
||||
|
||||
class ScalableLFQ(LFQ):
    """Lookup-free quantizer whose per-codebook codes are re-weighted by a
    fixed `codebook_alpha` buffer before being merged and projected out.

    The alpha profile is either a user-supplied constant vector or an
    exponentially decaying one, normalized to sum to 1 over codebooks.
    """

    def __init__(self, *, input_size=None, dim=None, codebook_size=None, entropy_loss_weight=0.1,
                 commitment_loss_weight=0.25, diversity_gamma=1., straight_through_activation=nn.Identity(),
                 num_codebooks=1, keep_num_codebooks_dim=None, codebook_scale=1.,
                 rand_num_codebooks: Optional[List] = None, sampling_rate=16000, hop_length=640, **kwargs):
        super().__init__(input_size=input_size, dim=dim, codebook_size=codebook_size,
                         entropy_loss_weight=entropy_loss_weight, commitment_loss_weight=commitment_loss_weight,
                         diversity_gamma=diversity_gamma, straight_through_activation=straight_through_activation,
                         num_codebooks=num_codebooks, keep_num_codebooks_dim=keep_num_codebooks_dim,
                         codebook_scale=codebook_scale, rand_num_codebooks=rand_num_codebooks,
                         sampling_rate=sampling_rate, hop_length=hop_length)
        # alpha profile configuration is passed through kwargs; init asserts it exists.
        codebook_alpha_conf = kwargs.get("codebook_alpha_conf", None)
        self.init_codebook_alpha(codebook_alpha_conf)

    def init_codebook_alpha(self, codebook_alpha_conf: dict):
        """Build the (1, 1, num_codebooks, 1) `codebook_alpha` buffer.

        Supported `name` values:
          * "constant": explicit per-codebook `alphas` (defaults to all-ones).
          * "exponential": alpha_i = exp(-i / temp), `temp` defaults to 8.0.
        The alphas are normalized to sum to 1 before registration.
        """
        assert codebook_alpha_conf is not None, "codebook_alpha_conf cannot be None"
        name = codebook_alpha_conf.get("name", "constant")
        if name == "constant":
            alphas = codebook_alpha_conf.get("alphas", [1.0] * self.num_codebooks)
            assert len(alphas) == self.num_codebooks, \
                f"the length of codebook alphas {len(alphas)} " \
                f"must match num_codebooks {self.num_codebooks}."
            alphas = np.array(alphas)
        elif name == "exponential":
            temp = codebook_alpha_conf.get("temp", 8.0)
            alphas = np.exp(-np.arange(0, self.num_codebooks) / temp)
        else:
            raise TypeError(f"Unknown codebook alpha type {name}.")
        # Shape (1, 1, c, 1) so it broadcasts over (b, n, c, d) quantized codes.
        codebook_alpha = torch.tensor(alphas/alphas.sum(), dtype=torch.float32).reshape(1, 1, -1, 1)
        self.register_buffer("codebook_alpha", codebook_alpha)

    def forward(
        self,
        x,
        inv_temperature=100.,
        return_loss_breakdown=False,
        mask=None,
        bite_width=None
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Same contract as LFQ.forward, but quantized codes are scaled by
        `codebook_alpha` before being merged back and projected out.
        NOTE(review): `bite_width` presumably means "bit width"; name kept
        for interface compatibility.
        """

        is_img_or_video = x.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # split out number of codebooks
        x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)

        # quantize by eq 3.
        original_input = x
        codebook_value = torch.ones_like(x) * self.codebook_scale
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # use straight-through gradients (optionally with custom activation fn) if training
        if self.training:
            x = self.activation(x)
            x = x + (quantized - x).detach()
        else:
            x = quantized

        # calculate indices
        indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy aux loss
        if self.training:
            # the same as euclidean distance up to a constant
            distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
            prob = (-distance * inv_temperature).softmax(dim=-1)
            per_sample_entropy = entropy(prob).mean()

            # account for mask
            if exists(mask):
                prob = prob[mask]

            # distribution over all available tokens in the batch
            avg_prob = reduce(prob, '... c d -> c d', 'mean')
            codebook_entropy = entropy(avg_prob).mean()

            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

        # commit loss
        if self.training:
            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # randomly dropout codebooks to fit s bite width
        if self.training and self.rand_num_codebooks is not None:
            x = self.random_dropout_codes(x)
        # NOTE(review): assumed `bite_width` truncation also applies at inference
        # (the encode path calls forward with an explicit bite_width) — confirm nesting.
        if bite_width is not None:
            x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width))

        # scale each codebook's contribution by its fixed alpha weight
        x = x * self.codebook_alpha

        # merge back codebook dim
        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if is_img_or_video:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim
        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # complete aux loss
        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        if not return_loss_breakdown:
            return x, indices, aux_loss, None

        return x, indices, aux_loss, dict(
            per_sample_entropy=per_sample_entropy,
            codebook_entropy=codebook_entropy,
            commit_loss=commit_loss
        )
|
||||
@ -0,0 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# flake8: noqa
|
||||
from funasr.modules.quantization.vq import QuantizedResult, ResidualVectorQuantizer
|
||||
291
funasr/models/sense_voice/quantizer/quantization/ac.py
Normal file
291
funasr/models/sense_voice/quantizer/quantization/ac.py
Normal file
@ -0,0 +1,291 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Arithmetic coder."""
|
||||
|
||||
import io
|
||||
import math
|
||||
import random
|
||||
import typing as tp
|
||||
import torch
|
||||
|
||||
from funasr.models.sense_voice.quantizer.quantization.binary import BitPacker, BitUnpacker
|
||||
|
||||
|
||||
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
                               roundoff: float = 1e-8, min_range: int = 2,
                               check: bool = True) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove difference coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior is a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Raises:
        ValueError: if `min_range < 2`, or (when `check`) if the resulting CDF
            leaves some symbol a range narrower than `min_range`.
    """
    # Fail fast on an invalid min_range instead of validating it only after
    # the whole CDF has been computed (the previous check was unreachable
    # useful work away, and an alpha assert could fire before it).
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        # Quantize the pdf itself so tiny float differences across machines
        # cannot change the resulting integer CDF.
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2 ** total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
|
||||
|
||||
|
||||
class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we encode mostly likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the most efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
            Any time the current range width fall under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0          # inclusive lower bound of the current range
        self.high: int = 0         # inclusive upper bound of the current range
        self.max_bit: int = -1     # index of the highest bit currently in play
        # Following buffers are only for debugging the coder/decoder in sync.
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                # Top bit agrees: emit it and drop it from both bounds.
                self.low -= (b1 << self.max_bit)
                self.high -= (b1 << self.max_bit)
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        # Inject fresh low-order bits until the range is wide enough to
        # split among all symbols (see class docstring).
        while self.delta < 2 ** self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        # Sub-range of [0, 2**total_range_bits) assigned to `symbol` by the CDF,
        # rescaled to the current [low, high] range.
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
        effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit

    def flush(self):
        """Flush the remaining information to the stream.
        """
        # Emit the bits of `low` one at a time; any value in [low, high]
        # would decode identically, `low` is simply the canonical choice.
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()
|
||||
|
||||
|
||||
class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    """
    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        self.low: int = 0        # mirrors the encoder's lower bound
        self.high: int = 0       # mirrors the encoder's upper bound
        self.current: int = 0    # bits read so far, interpreted as a position in [low, high]
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        """Return the current range width (mirrors `ArithmeticCoder.delta`)."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= (b1 << self.max_bit)
                self.high -= (b1 << self.max_bit)
                # `current` must stay aligned with the bounds it lives between.
                self.current -= (b1 << self.max_bit)
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        # Widen the range exactly as the encoder did, feeding the freshly
        # read bits into `current`.
        while self.delta < 2 ** self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
            # Finds the unique symbol whose rescaled sub-range contains `current`.
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
            effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return mid, low, high, self.current
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym
|
||||
|
||||
|
||||
def test():
    """Round-trip self-test: encode random symbol streams with `ArithmeticCoder`
    and verify `ArithmeticDecoder` recovers them exactly, for several random
    alphabet sizes and stream lengths."""
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        # NOTE(review): randrange(4000) can return 0, yielding an empty pdf —
        # presumably never hit with this fixed seed; confirm.
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for step in range(steps):
            # A fresh random distribution each step: the coder supports a
            # different pdf at every timestep.
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            # Decoding must use exactly the same quantized cdf as encoding.
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        # Once the stream is exhausted, pull must signal it with None.
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()
|
||||
157
funasr/models/sense_voice/quantizer/quantization/binary.py
Normal file
157
funasr/models/sense_voice/quantizer/quantization/binary.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
import typing as tp
|
||||
import ctypes
|
||||
|
||||
# format is `ECDC` magic code, followed by the header size as uint32.
|
||||
# Then an uint8 indicates the protocol version (0.)
|
||||
# The header is then provided as json and should contain all required
|
||||
# informations for decoding. A raw stream of bytes is then provided
|
||||
# and should be interpretable using the json header.
|
||||
_encodec_header_struct = struct.Struct('!4sBI')
|
||||
_ENCODEC_MAGIC = b'ECDC'
|
||||
|
||||
|
||||
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Write an ECDC header to *fo*: magic + protocol version (0) + payload
    length, followed by the JSON-serialized *metadata*."""
    payload = json.dumps(metadata).encode('utf-8')
    header = _encodec_header_struct.pack(_ENCODEC_MAGIC, 0, len(payload))
    fo.write(header)
    fo.write(payload)
    fo.flush()
|
||||
|
||||
|
||||
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
|
||||
buf = b""
|
||||
while len(buf) < size:
|
||||
new_buf = fo.read(size)
|
||||
if not new_buf:
|
||||
raise EOFError("Impossible to read enough data from the stream, "
|
||||
f"{size} bytes remaining.")
|
||||
buf += new_buf
|
||||
size -= len(new_buf)
|
||||
return buf
|
||||
|
||||
|
||||
def read_ecdc_header(fo: tp.IO[bytes]):
    """Read and validate an ECDC header from *fo*.

    Returns the decoded JSON metadata (as written by `write_ecdc_header`).

    Raises:
        ValueError: on a wrong magic code or unsupported protocol version.
        EOFError: if the stream ends mid-header (via `_read_exactly`).
    """
    header_bytes = _read_exactly(fo, _encodec_header_struct.size)
    magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
    if magic != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if version != 0:
        raise ValueError("Version not supported.")
    meta_bytes = _read_exactly(fo, meta_size)
    return json.loads(meta_bytes.decode('utf-8'))
|
||||
|
||||
|
||||
class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Values are accumulated LSB-first into an integer buffer and written out
    one uint8 at a time as soon as 8 bits are available.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self._current_value = 0   # pending bits, packed LSB-first
        self._current_bits = 0    # number of valid pending bits (always < 8 + bits)
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object.

        Args:
            value (int): integer token in `[0, 2 ** self.bits)`.
        """
        # Bug fix: the previous implementation first bit-cast `value` through
        # a float32 via ctypes (e.g. 1 -> 0x3f800000) before packing.  That
        # corrupts integer tokens and is not mirrored by `BitUnpacker.pull`
        # (the reverse conversion there is commented out), so packed streams
        # could not round-trip and `ArithmeticCoder`'s 1-bit pushes were lost.
        assert value < 2 ** self.bits, f"{value} does not fit in {self.bits} bits"
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        while self._current_bits >= 8:
            lower_8bits = self._current_value & 0xff
            self._current_bits -= 8
            self._current_value >>= 8
            self.fo.write(bytes([lower_8bits]))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode."""
        if self._current_bits:
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()
|
||||
|
||||
|
||||
class BitUnpacker:
    """BitUnpacker does the opposite of `BitPacker`.

    Bytes are folded LSB-first into an integer accumulator, from which
    fixed-width values are peeled off one at a time.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        self._value_mask = (1 << bits) - 1
        self._acc = 0        # buffered bits, LSB-first
        self._acc_bits = 0   # number of valid bits currently buffered

    def pull(self) -> tp.Optional[int]:
        """
        Pull a single value from the stream, potentially reading some
        extra bytes from the underlying file-object.
        Returns `None` when reaching the end of the stream.
        """
        # Refill the accumulator one byte at a time until a whole value
        # is available.
        while self._acc_bits < self.bits:
            raw = self.fo.read(1)
            if not raw:
                return None
            self._acc += raw[0] << self._acc_bits
            self._acc_bits += 8

        out = self._acc & self._value_mask
        self._acc >>= self.bits
        self._acc_bits -= self.bits
        return out
|
||||
|
||||
|
||||
def test():
    """Round-trip self-test for `BitPacker`/`BitUnpacker` over random widths
    and lengths: pack random tokens, unpack, and compare."""
    import torch
    torch.manual_seed(1234)
    for rep in range(4):
        length: int = torch.randint(10, 2_000, (1,)).item()
        bits: int = torch.randint(1, 16, (1,)).item()
        tokens: tp.List[int] = torch.randint(2 ** bits, (length,)).tolist()
        rebuilt: tp.List[int] = []
        buf = io.BytesIO()
        packer = BitPacker(bits, buf)
        for token in tokens:
            packer.push(token)
        packer.flush()
        buf.seek(0)
        unpacker = BitUnpacker(bits, buf)
        while True:
            value = unpacker.pull()
            if value is None:
                break
            rebuilt.append(value)
        assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits)
        # Prefix must match exactly; ghost values beyond len(tokens) are ignored.
        for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
            assert a == b, (idx, a, b)


if __name__ == '__main__':
    test()
|
||||
674
funasr/models/sense_voice/quantizer/quantization/core_vq.py
Normal file
674
funasr/models/sense_voice/quantizer/quantization/core_vq.py
Normal file
@ -0,0 +1,674 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# This implementation is inspired from
|
||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
||||
# which is released under MIT License. Hereafter, the original license:
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Phil Wang
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""Core vector quantization implementation."""
|
||||
import logging
|
||||
import typing as tp
|
||||
from random import randrange
|
||||
|
||||
import numpy as np
|
||||
from einops import rearrange, repeat
|
||||
from math import ceil
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import random
|
||||
|
||||
from funasr.models.sense_voice.quantizer.quantization import distrib
|
||||
|
||||
def round_up_multiple(num, mult):
    """Round *num* up to the nearest multiple of *mult*."""
    full_chunks = ceil(num / mult)
    return full_chunks * mult
|
||||
|
||||
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay: float):
    """In-place EMA update: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    scaled = moving_avg.data.mul_(decay)
    scaled.add_(new, alpha=1 - decay)
|
||||
|
||||
|
||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additive (Laplace) smoothing of counts *x*, normalized to a distribution."""
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
|
||||
|
||||
|
||||
def uniform_init(*shape: int):
    """Allocate a tensor of *shape* filled with Kaiming-uniform noise
    (used to initialize codebook embeddings)."""
    out = torch.empty(shape)
    nn.init.kaiming_uniform_(out)
    return out
|
||||
|
||||
|
||||
def sample_vectors(samples, num: int):
    """Draw *num* rows from *samples*: without replacement when the pool is
    large enough, otherwise uniformly with replacement."""
    pool_size = samples.shape[0]
    device = samples.device

    if pool_size >= num:
        chosen = torch.randperm(pool_size, device=device)[:num]
    else:
        chosen = torch.randint(0, pool_size, (num,), device=device)

    return samples[chosen]
|
||||
|
||||
|
||||
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Lightweight Lloyd's k-means over `samples` (shape (N, dim)).

    Centroids are seeded with `sample_vectors`; clusters that become empty
    keep their previous centroid.  Returns `(means, bins)`: centroids of
    shape (num_clusters, dim) and per-cluster counts of the final assignment.
    """
    # device = samples.device
    # samples = samples.cpu()
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # diffs = rearrange(samples, "n d -> n () d") - rearrange(
        #     means, "c d -> () c d"
        # )
        # dists = -(diffs ** 2).sum(dim=-1)
        # Expanded form of negative squared Euclidean distance: avoids
        # materializing the (N, C, dim) difference tensor.
        dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        # max over negative distances == nearest centroid
        buckets = dists.max(dim=-1).indices
        del dists
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty-cluster counts to 1 to avoid division by zero below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Empty clusters keep their previous centroid instead of collapsing to 0.
        means = torch.where(zero_mask[..., None], means, new_means)

    # means = means.to(device)
    return means, bins
|
||||
|
||||
|
||||
def preprocess(x):
    """Flatten every leading dimension of *x* into one, giving a 2-D
    (N, d) matrix where d is the trailing feature dimension."""
    return x.reshape(-1, x.shape[-1])
|
||||
|
||||
|
||||
def postprocess_emb(embed_ind, shape):
    """Fold flat embedding indices back to the leading dimensions of the
    original input *shape* (all dims except the trailing feature dim)."""
    leading_dims = tuple(shape[:-1])
    return embed_ind.view(leading_dims)
|
||||
|
||||
|
||||
class EuclideanCodebook(nn.Module):
|
||||
"""Codebook with Euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension.
|
||||
codebook_size (int): Codebook size.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
||||
If set to true, run the k-means algorithm on the first training batch and use
|
||||
the learned centroids as initialization.
|
||||
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
kmeans_init: int = False,
|
||||
kmeans_iters: int = 10,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.epsilon = epsilon
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
|
||||
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
||||
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
self.training = True
|
||||
|
||||
def init_embed_(self, data):
|
||||
if self.inited:
|
||||
return
|
||||
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed.clone())
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
distrib.broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
if self.threshold_ema_dead_code == 0:
|
||||
return
|
||||
|
||||
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
distrib.broadcast_tensors(self.buffers())
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
def dequantize(self, embed_ind):
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
return quantize
|
||||
|
||||
def encode(self, x):
|
||||
shape = x.shape
|
||||
# pre-process
|
||||
x = preprocess(x)
|
||||
# quantize
|
||||
embed_ind = self.quantize(x)
|
||||
# post-process
|
||||
embed_ind = postprocess_emb(embed_ind, shape)
|
||||
return embed_ind
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self.dequantize(embed_ind)
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
    """Quantize ``x`` and, in training mode, update the codebook by EMA.

    Returns ``(quantize, embed_ind)``: the dequantized vectors (same shape
    as ``x``) and the selected code indices.
    """
    shape, dtype = x.shape, x.dtype
    x = preprocess(x)

    # k-means initialization on the first training batch (no-op afterwards).
    self.init_embed_(x)

    embed_ind = self.quantize(x)
    # One-hot assignments are needed below for the EMA statistics.
    embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
    embed_ind = postprocess_emb(embed_ind, shape)
    quantize = self.dequantize(embed_ind)

    if self.training:
        # We do the expiry of code at that point as buffers are in sync
        # and all the workers will take the same decision.
        self.expire_codes_(x)
        # EMA update of per-code usage counts and per-code vector sums.
        ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
        embed_sum = x.t() @ embed_onehot
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
        # Laplace smoothing keeps the denominator non-zero for unused codes.
        cluster_size = (
            laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
            * self.cluster_size.sum()
        )
        embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
        self.embed.data.copy_(embed_normalized)
        # Note: after ema update, there is a very small difference between codebooks on GPUs.
        # The impact can be very small, ignore it.

    return quantize, embed_ind
class SimpleEuclideanCodebook(nn.Module):
    """Simple Codebook with Euclidean distance.

    Using gradient to update code embeddings instead of EMA.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization. A pre-computed codebook tensor of
            shape (codebook_size, dim) may also be passed directly.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: tp.Union[bool, torch.Tensor] = False,
        kmeans_iters: int = 10,
        **kwargs
    ):
        super().__init__()
        # Three init modes: zeros + deferred k-means (kmeans_init=True),
        # immediate uniform init (kmeans_init=False), or a pre-computed tensor.
        if isinstance(kmeans_init, bool):
            if kmeans_init:
                embed = torch.zeros(codebook_size, dim)
                inited = False
            else:
                embed = uniform_init(codebook_size, dim)
                inited = True
        else:
            embed = kmeans_init
            inited = True
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters

        # Codebook is a learnable nn.Embedding (gradient-updated, unlike the EMA variant).
        self.embed = nn.Embedding(codebook_size, dim)
        self.embed.weight.data.copy_(embed)
        # self.register_parameter("embed", nn.Parameter(embed, requires_grad=True))
        self.register_buffer("inited", torch.Tensor([inited]))
        self.training = True

    def init_embed_(self, data):
        """Lazily initialize codebook weights from the first batch via k-means."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        # BUGFIX: ``self.embed`` is an nn.Embedding module, which has no ``.data``
        # attribute (previously raised AttributeError when kmeans_init=True);
        # copy into its weight tensor instead.
        self.embed.weight.data.copy_(embed)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return the nearest code index (L2) for each row of ``x`` (N, dim)."""
        embed = self.embed.weight.t()
        # Negated squared distance, expanded to avoid an (N, codebook_size, dim) tensor.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up code vectors for the given indices."""
        quantize = self.embed(embed_ind)
        return quantize

    def encode(self, x):
        """Map inputs of shape (..., dim) to code indices of shape (...)."""
        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Inverse of ``encode``: indices back to code vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize ``x``; returns (quantize, embed_ind). The codebook itself
        is trained by gradient (no EMA statistics are kept here)."""
        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.

    Currently, supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Project in/out only when the codebook dim differs from the model dim.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size
        self.training = True

    @property
    def codebook(self):
        # Raw code embedding buffer of the underlying EMA codebook.
        return self._codebook.embed

    def encode(self, x):
        """Encode ``x`` of shape (batch, dim, time) to code indices."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Decode code indices back to (batch, dim, time) vectors."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        """Quantize ``x`` (batch, dim, time); returns (quantize, embed_ind, loss)."""
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            # Straight-through estimator: gradients flow to x, not the codebook.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                # Pull encoder outputs toward their (frozen) assigned codes.
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
class SimpleVectorQuantization(nn.Module):
    """Vector quantization implementation with SimpleEuclideanCodebook.

    Currently, supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        commitment_weight: float = 0.25,
        codebook_weight: float = 1.0,
        **kwargs
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self.codebook_weight = codebook_weight
        logging.info(f"commitment_weight: {commitment_weight}, codebook_weight: {codebook_weight}.")

        self._codebook = SimpleEuclideanCodebook(
            dim=_codebook_dim, codebook_size=codebook_size,
            kmeans_init=kmeans_init, kmeans_iters=kmeans_iters
        )
        self.codebook_size = codebook_size
        self.training = True

    @property
    def codebook(self):
        # NOTE(review): this returns the nn.Embedding module itself, not its
        # weight tensor (the EMA variant returns a tensor) — confirm callers
        # expect a module here.
        return self._codebook.embed

    def encode(self, x):
        """Encode ``x`` of shape (batch, dim, time) to code indices."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Decode code indices back to (batch, dim, time) vectors."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        """Quantize ``x`` (batch, dim, time); returns (quantize, embed_ind, loss).

        Unlike the EMA variant there is no straight-through estimator here:
        the codebook is trained via the codebook loss and the encoder via the
        commitment loss.
        """
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            # commit loss for codebook
            if self.codebook_weight > 0:
                codebook_loss = F.mse_loss(quantize, x.detach())
                loss = loss + codebook_loss * self.codebook_weight

            # commit loss for encoder
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )
        # When quantize_dropout is on, a random number of quantizers (drawn
        # from rand_num_quant) is used per training step.
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with up to ``n_q`` residual stages.

        Returns (quantized_out, out_indices, out_losses, out_sub_quants); the
        latter three are stacked along a leading per-quantizer axis. Stages
        dropped by quantize-dropout contribute -1 indices and zero loss.
        """
        quantized_out = 0.0
        residual = x
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            # Index of the last active quantizer for this step.
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

            # Placeholders appended for dropped stages so the stacked outputs
            # keep a fixed shape.
            null_indices_shape = (x.shape[0], x.shape[2])
            null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            # Each stage quantizes what the previous stages could not explain.
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
        return quantized_out, out_indices, out_losses, out_sub_quants

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode ``x`` with the first ``n_q`` quantizers; returns stacked indices."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded contribution of every quantizer stage."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
class SimpleResidualVectorQuantization(nn.Module):
    """Simple Residual vector quantization with gradient to
    update codebook instead of EMA
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        # ``kmeans_init`` may be a bool or a path to a saved numpy codebook.
        kmeans_init = raw_kmeans_init = kwargs.pop('kmeans_init', True)
        if isinstance(kmeans_init, str):
            # use prepared kmeans init
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if embed.dim() == 2:
                # NOTE(review): a 2-D file yields a single (1, size, dim) codebook;
                # with num_quantizers > 1 the per-layer indexing below would fail —
                # confirm prepared files are 3-D for multi-quantizer setups.
                embed = embed.unsqueeze(0)
            kmeans_init = embed

        self.layers = nn.ModuleList([
            SimpleVectorQuantization(
                kmeans_init=kmeans_init[i] if isinstance(kmeans_init, torch.Tensor) else kmeans_init,
                **kwargs
            ) for i in range(num_quantizers)
        ])
        # Restore the popped entry so the caller's kwargs dict is unchanged.
        kwargs["kmeans_init"] = raw_kmeans_init
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with up to ``n_q`` residual stages.

        Returns (quantized_out, out_indices, out_losses, out_sub_quants);
        stages dropped by quantize-dropout contribute -1 indices, zero loss.
        """
        quantized_out = 0.0
        residual = x
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

            # Placeholders for dropped stages keep stacked outputs a fixed shape.
            null_indices_shape = (x.shape[0], x.shape[2])
            null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
        return quantized_out, out_indices, out_losses, out_sub_quants

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode ``x`` with the first ``n_q`` quantizers; returns stacked indices."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded contribution of every quantizer stage."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
513
funasr/models/sense_voice/quantizer/quantization/ddp_core_vq.py
Normal file
513
funasr/models/sense_voice/quantizer/quantization/ddp_core_vq.py
Normal file
@ -0,0 +1,513 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# This implementation is inspired from
|
||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
||||
# which is released under MIT License. Hereafter, the original license:
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Phil Wang
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""Core vector quantization implementation."""
|
||||
import logging
|
||||
import random
|
||||
import typing as tp
|
||||
from random import randrange
|
||||
|
||||
import numpy as np
|
||||
from einops import rearrange, repeat
|
||||
from math import ceil
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from funasr.utils.hinter import hint_once
|
||||
|
||||
from funasr.models.sense_voice.quantizer.quantization import distrib
|
||||
|
||||
def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``."""
    quotient = ceil(num / mult)
    return quotient * mult
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
def ema_inplace(moving_avg, new, decay: float, mask=None):
    """Update ``moving_avg`` in place with an exponential moving average of ``new``.

    With ``mask`` given, positions where the mask is zero keep their old value;
    without it the whole tensor is updated.
    """
    if mask is None:
        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
        return
    gate = mask.float()
    blended = moving_avg * decay + new * (1 - decay)
    gated = gate * blended + (1 - gate) * moving_avg
    moving_avg.data.copy_(gated.data)
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Laplace-smooth the counts in ``x`` so no category gets zero mass."""
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
def uniform_init(*shape: int):
    """Return a tensor of the given shape filled by Kaiming-uniform init."""
    out = torch.empty(shape)
    nn.init.kaiming_uniform_(out)
    return out
def sample_vectors(samples, num: int):
    """Pick ``num`` rows from ``samples``: without replacement when enough rows
    exist, with replacement otherwise."""
    total = samples.shape[0]
    device = samples.device

    if total >= num:
        chosen = torch.randperm(total, device=device)[:num]
    else:
        chosen = torch.randint(0, total, (num,), device=device)

    return samples[chosen]
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run Lloyd's k-means on ``samples`` of shape (n, dim).

    Returns ``(means, bins)``: the (num_clusters, dim) centroids and the
    per-cluster assignment counts from the final iteration.
    """
    # device = samples.device
    # samples = samples.cpu()
    dim, dtype = samples.shape[-1], samples.dtype

    # Initialize centroids with randomly sampled input rows.
    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # diffs = rearrange(samples, "n d -> n () d") - rearrange(
        #     means, "c d -> () c d"
        # )
        # dists = -(diffs ** 2).sum(dim=-1)
        # Expanded negated squared distance; avoids materializing the (n, c, d)
        # difference tensor of the commented-out version above.
        dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        buckets = dists.max(dim=-1).indices
        del dists
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty-cluster counts to 1 to avoid division by zero below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Keep the previous centroid for clusters that received no samples.
        means = torch.where(zero_mask[..., None], means, new_means)

    # means = means.to(device)
    return means, bins
def preprocess(x):
    """Flatten all leading dimensions: (..., d) -> (N, d)."""
    x = rearrange(x, "... d -> (...) d")
    return x
def postprocess_emb(embed_ind, shape):
    """Restore flattened code indices to the leading dims of ``shape``."""
    leading = shape[:-1]
    return embed_ind.view(*leading)
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    DDP-friendly variant: the state tensors (``inited``, ``cluster_size``,
    ``embed``, ``embed_avg``) are NOT registered here; they are owned by the
    caller and passed into ``encode``/``decode``/``forward`` as ``buffers``,
    which lets one holder broadcast/sync them for several codebooks at once.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        sparse_update: bool = False,
        normalized_input: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Filled in from the externally-owned ``buffers`` on every call.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        self.training = True
        # sparse_update: only EMA-update codes actually used in the batch.
        self.sparse_update = sparse_update
        self.normalized_input = normalized_input

    def init_embed_(self, data):
        """Lazily initialize the codebook from the first batch via k-means."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])

    def replace_(self, samples, mask):
        """Overwrite codes selected by ``mask`` with vectors sampled from ``samples``."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace "dead" codes (EMA usage below threshold) with batch samples."""
        hint_once(f"threshold_ema_dead_code: {self.threshold_ema_dead_code}.", "threshold_ema_dead_code", rank=0)
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return the nearest code index (L2) for each row of ``x`` (N, dim)."""
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up code vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x, buffers):
        """Map inputs (..., dim) to indices (...) using externally-owned state."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind, buffers):
        """Inverse of ``encode`` using externally-owned state."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, buffers):
        """Quantize ``x`` and, in training, EMA-update the external buffers.

        Returns ``(quantize, embed_ind)``.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            if not self.sparse_update:
                mask = None
            else:
                # Only touch codes that were selected at least once this batch.
                mask = embed_onehot.sum(0) > 0
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay, mask=mask)
            embed_sum = x.t() @ embed_onehot
            # if self.normalized_input:
            #     embed_sum = F.normalize(embed_sum, dim=0)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay,
                        mask=mask.unsqueeze(-1) if self.sparse_update else None)
            if not self.sparse_update:
                cluster_size = (
                    laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                    * self.cluster_size.sum()
                )
                embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            else:
                embed_normalized = self.embed_avg
            if self.normalized_input:
                # Keep codes on the unit sphere when inputs are normalized.
                embed_normalized = F.normalize(embed_normalized, dim=-1)
            self.embed.data.copy_(embed_normalized)
            # Note: after ema update, there is a very small difference between codebooks on GPUs.
            # The impact can be very small, ignore it.

        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.

    Currently, supports only euclidean distance. DDP-friendly variant: the
    codebook state is owned externally and threaded through as ``buffers``.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.,
        **kwargs,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Project in/out only when the codebook dim differs from the model dim.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code,
                                           **kwargs)
        self.codebook_size = codebook_size
        self.training = True

    @property
    def codebook(self):
        # Current code embedding tensor (None until buffers are threaded in).
        return self._codebook.embed

    def encode(self, x, buffers):
        """Encode ``x`` of shape (batch, dim, time) to code indices."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x, buffers)
        return embed_in

    def decode(self, embed_ind, buffers):
        """Decode code indices back to (batch, dim, time) vectors."""
        quantize = self._codebook.decode(embed_ind, buffers)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x, buffers):
        """Quantize ``x`` (batch, dim, time); returns (quantize, embed_ind, loss)."""
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x, buffers)

        if self.training:
            # Straight-through estimator: gradients flow to x, not the codebook.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                # Pull encoder outputs toward their (frozen) assigned codes.
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
class DistributedResidualVectorQuantization(nn.Module):
    """Efficient distributed residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    EMA codebook state (inited / cluster_size / embed / embed_avg) is held
    here as stacked per-quantizer buffers and handed to each
    VectorQuantization layer on every call; buffers are broadcast across
    workers after each forward step.
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        """
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        """
        codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["dim"]
        kmeans_init = kwargs["kmeans_init"]
        # kmeans_init is either a bool (run/skip kmeans on the first batch)
        # or a path to a precomputed .npy codebook.
        if isinstance(kmeans_init, bool):
            if not kwargs["kmeans_init"]:
                # use uniform init
                embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
                inited = True
                cluster_size = 1
            else:
                # to perform kmeans init on first batch
                embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
                inited = False
                cluster_size = 0
        elif isinstance(kmeans_init, str):
            # use prepared kmeans init
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if kwargs.get("normalized_input", False):
                logging.info("normalize the code embeddings since the input is normalized.")
                embed = F.normalize(embed, dim=-1)
            if embed.dim() == 2:
                # A single shared codebook: replicate it for every quantizer.
                embed = embed.repeat(num_quantizers, 1, 1)
            inited = True
            cluster_size = 1
        else:
            raise TypeError("kmeans_init should be either a bool or string path to init weights.")

        # Stacked per-quantizer EMA state; row i is sliced out and passed to
        # layer i on every encode/decode/forward call.
        self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
        self.register_buffer("cluster_size", torch.ones(num_quantizers, codebook_size) * cluster_size)
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

        # Optional temporal downsampling applied to the first quantizer only.
        self.q0_ds_ratio = 1
        if "q0_ds_ratio" in kwargs:
            # pop() so the key is not forwarded to VectorQuantization below.
            self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")

        self.layers = nn.ModuleList()
        for i in range(num_quantizers):
            vq_args = dict(**kwargs)
            vq = VectorQuantization(**vq_args)
            self.layers.append(vq)

        # When enabled (training only), a random number of quantizers drawn
        # from rand_num_quant is used per step; the rest emit null outputs.
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Residually quantize `x` (b, c, t) with up to `n_q` quantizers.

        Returns (quantized_out, indices, losses, sub_quants), the last three
        stacked along a leading per-quantizer axis; dropped-out quantizers
        contribute -1 indices, zero loss and -1-filled sub-quants.
        """
        quantized_out = torch.zeros_like(x)
        residual = x
        bb, cc, tt = x.shape
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            # Number of quantizers kept this step; the rest are dropped.
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

        # Placeholder outputs for dropped quantizers (note: -1. fill with a
        # long dtype for indices is intentional; it stores as integer -1).
        null_indices_shape = (x.shape[0], x.shape[2])
        null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
        null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
        null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            quant_in = residual
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # NOTE(review): downsample size is hard-coded to tt//2 rather
                # than tt // self.q0_ds_ratio — only correct for ratio 2;
                # confirm intended behavior for other ratios.
                quant_in = F.interpolate(quant_in, size=[tt//2])
            quantized, indices, loss = layer(quant_in, [
                self.inited[quantizer_index],
                self.cluster_size[quantizer_index],
                self.embed[quantizer_index],
                self.embed_avg[quantizer_index]
            ])
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Upsample the quantized signal (and indices, via
                # nearest-style float interpolation) back to full length.
                quantized = F.interpolate(quantized, size=[tt])
                indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        # sync buffers after one forward step
        distrib.broadcast_tensors(self.buffers())
        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))

        return quantized_out, out_indices, out_losses, out_sub_quants

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return stacked codebook indices for `x` from the first `n_q` layers."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            indices = layer.encode(residual, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            # Decode immediately so the next layer quantizes the residual.
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum per-quantizer decodings of `q_indices` (leading axis = quantizer)."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized_out = quantized_out + quantized
        return quantized_out
126
funasr/models/sense_voice/quantizer/quantization/distrib.py
Normal file
126
funasr/models/sense_voice/quantizer/quantization/distrib.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Torch distributed utilities."""
|
||||
|
||||
import typing as tp
|
||||
import time, logging
|
||||
import torch
|
||||
|
||||
|
||||
def rank():
    """Rank of this process in the default process group; 0 when not distributed."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
||||
def world_size():
    """Number of workers in the default process group; 1 when not distributed."""
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
||||
def is_distributed():
    """Whether more than one worker is participating in training."""
    n_workers = world_size()
    return n_workers > 1
||||
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce `tensor` in place across workers; a no-op when not distributed."""
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
||||
def _is_complex_or_float(tensor):
|
||||
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
||||
|
||||
|
||||
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify every worker holds the same number of params.

    A mismatch would otherwise deadlock the subsequent distributed
    all-reduce, so fail fast with a clear error instead.
    """
    if not is_distributed() or not params:
        return
    count = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(count)
    expected_total = len(params) * world_size()
    if count.item() != expected_total:
        # If any worker disagrees, the summed count cannot match.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")
||||
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast float/complex tensors from worker `src` to all workers.

    Useful to ensure every worker starts from identical weights/buffers.
    No-op when not distributed.
    """
    if not is_distributed():
        return
    # Integer/bool tensors are skipped; only float/complex data is synced.
    to_sync = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(to_sync)
    # Launch all broadcasts asynchronously first so transfers overlap,
    # then wait for completion.
    pending = [torch.distributed.broadcast(t.data, src=src, async_op=True)
               for t in to_sync]
    for handle in pending:
        handle.wait()
||||
def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.

    Averages floating-point buffers across workers (sum then divide by the
    world size), or broadcasts them from rank 0 when average is False.
    No-op when not distributed.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        # Only floating-point buffers can be meaningfully reduced/broadcast.
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # Bug fix: the original divided by `world_size` (the function
            # object itself, a TypeError at runtime) instead of calling it,
            # as sync_grad() below correctly does.
            buffer.data /= world_size()
||||
def sync_grad(params):
    """Average gradients across workers after the backward pass.

    A minimal alternative to DistributedDataParallel that relies on no
    black magic: one async all-reduce per gradient, then wait and divide
    by the world size. No-op when not distributed.
    """
    if not is_distributed():
        return
    pending = []
    for param in params:
        if param.grad is None:
            continue
        work = torch.distributed.all_reduce(
            param.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
        pending.append((param, work))
    for param, work in pending:
        work.wait()
        param.grad.data /= world_size()
||||
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a metrics dict across workers, weighting each by `count`.

    Returns the input unchanged when not distributed.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pack the per-worker weight alongside the values so a single
    # all-reduce carries both the weighted sums and the normalizer.
    packed = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) * count
    all_reduce(packed)
    averaged = (packed[:-1] / packed[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
134
funasr/models/sense_voice/quantizer/quantization/vq.py
Normal file
134
funasr/models/sense_voice/quantizer/quantization/vq.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Residual vector quantizer implementation."""
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
import math
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from funasr.models.sense_voice.quantizer.quantization import distrib
|
||||
from funasr.models.sense_voice.quantizer.quantization.core_vq import SimpleResidualVectorQuantization
|
||||
from funasr.models.sense_voice.quantizer.quantization.ddp_core_vq import DistributedResidualVectorQuantization
|
||||
|
||||
@dataclass
class QuantizedResult:
    """Container for the outputs of a residual-vector-quantizer forward pass."""
    # Quantized (approximately reconstructed) representation of the input.
    quantized: torch.Tensor
    # Codebook indices selected by the quantizer(s).
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    # Training penalty (e.g. commitment loss) to add to the objective, if any.
    penalty: tp.Optional[torch.Tensor] = None
    # Free-form metrics reported by the quantizer.
    metrics: dict = field(default_factory=dict)
    # Per-quantizer partial reconstructions; None when not provided.
    sub_quants: tp.Optional[torch.Tensor] = None
||||
class ResidualVectorQuantizer(nn.Module):
|
||||
"""Residual Vector Quantizer.
|
||||
Args:
|
||||
dimension (int): Dimension of the codebooks.
|
||||
n_q (int): Number of residual vector quantizers used.
|
||||
bins (int): Codebook size.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int = 256,
|
||||
n_q: int = 8,
|
||||
bins: int = 1024,
|
||||
decay: float = 0.99,
|
||||
kmeans_init: tp.Union[bool, str] = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
quantize_dropout: bool = False,
|
||||
rand_num_quant: tp.Optional[tp.List] = None,
|
||||
encoder_hop_length: int = 320,
|
||||
use_ddp: bool = True,
|
||||
q0_ds_ratio: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_q = n_q
|
||||
self.dimension = dimension
|
||||
self.bins = bins
|
||||
self.decay = decay
|
||||
self.kmeans_init = kmeans_init
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.encoder_hop_length = encoder_hop_length
|
||||
self.training = True
|
||||
if use_ddp:
|
||||
rvq_class = DistributedResidualVectorQuantization
|
||||
logging.info("Using distributed residual vector quantization.")
|
||||
else:
|
||||
rvq_class = SimpleResidualVectorQuantization
|
||||
logging.warning("Using simple residual vector quantization")
|
||||
self.model = rvq_class(
|
||||
dim=self.dimension,
|
||||
codebook_size=self.bins,
|
||||
num_quantizers=self.n_q,
|
||||
decay=self.decay,
|
||||
kmeans_init=self.kmeans_init,
|
||||
kmeans_iters=self.kmeans_iters,
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
quantize_dropout=quantize_dropout,
|
||||
rand_num_quant=rand_num_quant,
|
||||
q0_ds_ratio=q0_ds_ratio,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
|
||||
"""Residual vector quantization on the given input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor in the shape of (B, C, T).
|
||||
sample_rate (int): Sample rate of the input tensor.
|
||||
bandwidth (float): Target bandwidth.
|
||||
Returns:
|
||||
QuantizedResult:
|
||||
The quantized (or approximately quantized) representation with
|
||||
the associated bandwidth and any penalty term for the loss.
|
||||
"""
|
||||
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
|
||||
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
|
||||
quantized, codes, commit_loss, sub_quants = self.model(x, n_q=n_q)
|
||||
bw = torch.tensor(n_q * bw_per_q).to(x)
|
||||
return QuantizedResult(quantized, codes, bw,
|
||||
penalty=torch.mean(commit_loss),
|
||||
sub_quants=sub_quants)
|
||||
|
||||
def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
|
||||
"""Return n_q based on specified target bandwidth.
|
||||
"""
|
||||
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
|
||||
n_q = self.n_q
|
||||
if bandwidth and bandwidth > 0.:
|
||||
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
|
||||
return n_q
|
||||
|
||||
def get_bandwidth_per_quantizer(self, sample_rate: int):
|
||||
"""Return bandwidth per quantizer for a given input sample rate.
|
||||
"""
|
||||
return math.log2(self.bins) * sample_rate / self.encoder_hop_length
|
||||
|
||||
def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
"""
|
||||
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
|
||||
codes = self.model.encode(x, n_q=n_q)
|
||||
return codes
|
||||
|
||||
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode the given codes to the quantized representation.
|
||||
"""
|
||||
quantized = self.model.decode(codes)
|
||||
return quantized
|
||||
Loading…
Reference in New Issue
Block a user