mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add
This commit is contained in:
parent
90910c4a31
commit
756367d0cc
@ -1,126 +0,0 @@
|
|||||||
import torch
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization.vq import ResidualVectorQuantizer
|
|
||||||
import typing as tp
|
|
||||||
|
|
||||||
|
|
||||||
class CostumeQuantizer(torch.nn.Module):
    """Residual-vector-quantizer wrapper with optional I/O projections.

    Wraps ``ResidualVectorQuantizer`` and optionally adds:
      * linear input/output projections when the codec dimension differs
        from the model dimension, and
      * a tanh squashing stage that limits codec inputs to ``+/- codec_range``.

    NOTE(review): the class name is presumably a typo for ``CustomQuantizer``;
    it is kept because external callers reference it.
    """

    def __init__(
        self,
        input_size: int = 512,
        codebook_size: int = 1024,
        num_quantizers: int = 8,
        ema_decay: float = 0.95,
        kmeans_init: tp.Union[bool, str] = False,
        sampling_rate: int = 16_000,
        quantize_dropout: bool = False,
        rand_num_quant: tp.Optional[tp.List] = None,
        encoder_hop_length: int = 320,
        use_ddp: bool = True,
        q0_ds_ratio: int = 1,
        codec_dim: int = None,
        codec_range: float = None,
        threshold_ema_dead_code=2,
        **kwargs,
    ):
        super().__init__()
        # The internal codec dimension defaults to the model dimension.
        if codec_dim is None:
            codec_dim = input_size

        # Projections are only instantiated when a dimension change is needed.
        self.input_proj, self.output_proj = None, None
        if codec_dim != input_size:
            self.input_proj = torch.nn.Linear(input_size, codec_dim)
            self.output_proj = torch.nn.Linear(codec_dim, input_size)

        # Optional tanh + scale keeps codec inputs inside [-codec_range, codec_range].
        self.input_act, self.codec_range = None, None
        if codec_range is not None:
            self.input_act = torch.nn.Tanh()
            self.codec_range = codec_range

        self.rq = ResidualVectorQuantizer(
            dimension=codec_dim,
            n_q=num_quantizers,
            bins=codebook_size,
            decay=ema_decay,
            kmeans_init=kmeans_init,
            quantize_dropout=quantize_dropout,
            rand_num_quant=rand_num_quant,
            encoder_hop_length=encoder_hop_length,
            use_ddp=use_ddp,
            q0_ds_ratio=q0_ds_ratio,
            threshold_ema_dead_code=threshold_ema_dead_code,
            **kwargs,
        )
        self.code_dim = codec_dim
        self.sampling_rate = sampling_rate
        self.bandwidth: tp.Optional[float] = None
        self.encoder_hop_length = encoder_hop_length
        self.codebook_size = codebook_size

    def forward(self, x, bandwidth: int = None):
        """Quantize ``x`` and return (reconstruction, indices, commit_loss, sub_quants).

        ``x`` is (B, T, C); the residual quantizer operates on (B, C, T),
        hence the permutes around the ``self.rq`` call.
        """
        if self.input_proj is not None:
            x = self.input_proj(x)
        if self.input_act is not None:
            x = self.input_act(x) * self.codec_range

        qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
        x, indices, commit_loss, sub_quants = qv.quantized, qv.codes, qv.penalty, qv.sub_quants

        x = x.permute(0, 2, 1)
        if self.output_proj is not None:
            x = self.output_proj(x)

        return x, indices, commit_loss, sub_quants

    def inference(self, x, bandwidth: int = None):
        """Same as ``forward`` but without the commitment-loss term."""
        # x: (B, T, C); rq requires (B, C, T).
        if self.input_proj is not None:
            x = self.input_proj(x)
        if self.input_act is not None:
            x = self.input_act(x) * self.codec_range

        qv = self.rq(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
        x, indices, sub_quants = qv.quantized, qv.codes, qv.sub_quants

        x = x.permute(0, 2, 1)
        if self.output_proj is not None:
            x = self.output_proj(x)

        return x, indices, sub_quants

    def encode(self, x, bandwidth: int = None):
        """Return code indices only, shaped n_q x B x T."""
        # x: (B, T, C); rq requires (B, C, T).
        if self.input_proj is not None:
            x = self.input_proj(x)
        if self.input_act is not None:
            x = self.input_act(x) * self.codec_range

        indices = self.rq.encode(x.permute(0, 2, 1), self.sampling_rate, bandwidth)
        return indices

    def decode(self, indices):
        """Reconstruct features from code ``indices``; output is (B, D, T)."""
        quantized_out = self.rq.decode(indices)
        if self.output_proj is not None:
            # output_proj expects channel-last, so transpose around it.
            quantized_out = self.output_proj(quantized_out.transpose(1, 2)).transpose(1, 2)
        return quantized_out

    def output_size(self):
        """Dimension of the quantized representation (the codec dimension)."""
        return self.code_dim
|
|
||||||
@ -1,299 +0,0 @@
|
|||||||
"""
|
|
||||||
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
|
||||||
Code adapted from Jax version in Appendix A.1
|
|
||||||
"""
|
|
||||||
import random
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn import Module
|
|
||||||
from torch import Tensor, int32
|
|
||||||
|
|
||||||
from einops import rearrange, pack, unpack
|
|
||||||
|
|
||||||
|
|
||||||
# helper functions
|
|
||||||
|
|
||||||
def exists(v):
    """Return True when *v* is not None (treats any other value, falsy or not, as existing)."""
    return v is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(*args):
    """Return the first argument that is not None; None when every argument is None."""
    for candidate in args:
        if candidate is not None:
            return candidate
    return None
|
|
||||||
|
|
||||||
|
|
||||||
def pack_one(t, pattern):
    """Pack a single tensor with einops.pack, returning (packed, packed_shapes)."""
    return pack([t], pattern)
|
|
||||||
|
|
||||||
|
|
||||||
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack with einops.unpack and return the single tensor."""
    return unpack(t, ps, pattern)[0]
|
|
||||||
|
|
||||||
|
|
||||||
# tensor helpers
|
|
||||||
|
|
||||||
def round_ste(z: Tensor) -> Tensor:
    """Round *z* to the nearest integer with straight-through gradients.

    The forward value equals ``z.round()`` while the backward pass treats the
    rounding as identity (the ``.detach()`` blocks the non-differentiable part).
    """
    rounded = z.round()
    return z + (rounded - z).detach()
|
|
||||||
|
|
||||||
|
|
||||||
# main class
|
|
||||||
|
|
||||||
class FSQ(Module):
    """Finite Scalar Quantization (https://arxiv.org/abs/2309.15505).

    Each latent dimension is bounded and rounded onto a small integer grid
    whose sizes are given by ``levels``; the product of the levels is the
    implicit codebook size. Adapted from the Jax reference in Appendix A.1.
    """

    def __init__(
        self,
        levels: List[int],
        input_size: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None
    ):
        super().__init__()
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)

        # Mixed-radix basis: flattens per-dimension codes into a single index.
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(input_size, len(_levels) * num_codebooks)

        # Project in/out only when the model dim differs from the codec dim.
        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        # Materialize the implicit codebook once (codes for every index).
        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    def output_size(self):
        """Dimension of the (projected) quantizer output."""
        return self.dim

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d), so rounding lands inside the grid."""
        half_l = (self._levels - 1) * (1 - eps) / 2
        # Even level counts need a half-step offset so the grid is symmetric.
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantize z; returns zhat of the same shape, renormalized to [-1, 1]."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        # Map [-1, 1]-normalized codes to non-negative integer grid positions.
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        # Map integer grid positions back to [-1, 1]-normalized codes.
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out=True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        # Undo the mixed-radix flattening, then recenter to [-1, 1].
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        ``bandwidth`` is accepted for interface parity with other quantizers
        but is not used by FSQ. Returns (out, indices, commit_loss, None)
        where commit_loss is always a zero tensor.
        """
        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)
        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)

        codes = self.quantize(z)
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions
        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')
            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # FSQ has no commitment loss; return a zero placeholder for API parity.
        commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device)
        return out, indices, commit_loss, None

    def inference(
        self,
        x,
        bandwidth: int = None,
    ):
        """Forward pass without loss terms; returns (quantized, indices, None)."""
        x, indices, _, _ = self.forward(x, bandwidth=bandwidth)
        return x, indices, None
|
|
||||||
|
|
||||||
|
|
||||||
class BinaryFSQ(FSQ):
    """FSQ specialization where every level is 2 (one bit per dimension).

    Codes live in {0, 1} via a sigmoid bound instead of FSQ's tanh grid,
    and no center/rescale step is needed when converting to/from indices.
    Optionally zeroes a random suffix of codebooks each forward pass
    (``rand_num_codebooks``) to train for variable bit rates.
    """

    def __init__(
        self,
        levels: List[int],
        input_size: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        rand_num_codebooks: Optional[List] = None,
    ):
        _levels = torch.tensor(levels, dtype=int32)
        assert torch.all(_levels == 2), "BinaryFSQ requires the levels must be 2"
        super().__init__(
            levels, input_size, num_codebooks,
            keep_num_codebooks_dim, scale
        )
        self.rand_num_codebooks = rand_num_codebooks

    def output_size(self):
        """Dimension of the (projected) quantizer output."""
        return self.dim

    def bound(self, z: Tensor, eps=1e-3) -> Tensor:
        """Bound `z` into (0, 1) via sigmoid. ``eps`` is unused; kept for FSQ signature compatibility."""
        return torch.sigmoid(z)

    def quantize(self, z: Tensor) -> Tensor:
        """Quantize z to {0, 1} with straight-through rounding; same shape as z."""
        quantized = round_ste(self.bound(z))
        return quantized

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a binary `code` to an index (no scale/shift needed for {0, 1} codes)."""
        assert zhat.shape[-1] == self.codebook_dim
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out=True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        # Binary codes are already in {0, 1}; no recentering required.
        codes = (indices // self._basis) % self._levels

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def forward(self, z: Tensor, bandwidth: int = None,) -> [Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """
        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)
        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)

        codes = self.quantize(z)
        # Randomly keep only a prefix of the codebooks (variable-bit-rate training).
        if self.rand_num_codebooks is not None:
            quant_idx = random.choice(self.rand_num_codebooks)
            codes[:, :, quant_idx:, :] = 0
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions
        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')
            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # No commitment loss for FSQ-style quantizers; zero placeholder.
        commit_loss = torch.tensor([0], dtype=torch.float32, device=z.device)
        return out, indices, commit_loss, None
|
|
||||||
@ -1,493 +0,0 @@
|
|||||||
"""
|
|
||||||
Lookup Free Quantization
|
|
||||||
Proposed in https://arxiv.org/abs/2310.05737
|
|
||||||
|
|
||||||
In the simplest setup, each dimension is quantized into {-1, 1}.
|
|
||||||
An entropy penalty is used to encourage utilization.
|
|
||||||
"""
|
|
||||||
import random
|
|
||||||
from math import log2, ceil
|
|
||||||
from collections import namedtuple
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn, einsum
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn import Module
|
|
||||||
|
|
||||||
from einops import rearrange, reduce, pack, unpack
|
|
||||||
|
|
||||||
# constants
|
|
||||||
|
|
||||||
# Light-weight result containers for the LFQ quantizer.
# ``Return`` bundles the quantized output, code indices, and the entropy
# auxiliary loss; ``LossBreakdown`` itemizes the loss components.
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
|
|
||||||
|
|
||||||
|
|
||||||
# helper functions
|
|
||||||
|
|
||||||
def exists(v):
    """Return True when *v* is not None (any other value, falsy or not, counts as existing)."""
    return v is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(*args):
    """Return the first non-None argument, invoking it first when it is callable.

    Callables act as lazy defaults: they are only evaluated when reached.
    Returns None when every argument is None.
    """
    for candidate in args:
        if candidate is not None:
            return candidate() if callable(candidate) else candidate
    return None
|
|
||||||
|
|
||||||
|
|
||||||
def pack_one(t, pattern):
    """Pack a single tensor with einops.pack, returning (packed, packed_shapes)."""
    return pack([t], pattern)
|
|
||||||
|
|
||||||
|
|
||||||
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack with einops.unpack and return the single tensor."""
    return unpack(t, ps, pattern)[0]
|
|
||||||
|
|
||||||
|
|
||||||
# entropy
|
|
||||||
|
|
||||||
def log(t, eps=1e-5):
    """Numerically safe natural log: values below *eps* are clamped first."""
    return t.clamp(min=eps).log()
|
|
||||||
|
|
||||||
|
|
||||||
def entropy(prob):
    """Shannon entropy (nats) of *prob* along the last dimension.

    Probabilities are clamped at 1e-5 before the log for numerical safety,
    matching the file-local ``log`` helper.
    """
    safe_log = prob.clamp(min=1e-5).log()
    return (-prob * safe_log).sum(dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# class
|
|
||||||
|
|
||||||
class LFQ(Module):
    """Lookup-Free Quantization (https://arxiv.org/abs/2310.05737).

    Each latent dimension is quantized to {-codebook_scale, +codebook_scale};
    an entropy auxiliary loss encourages confident, uniformly-used codes.
    Supports dropping trailing codebooks (randomly during training, or driven
    by a target ``bite_width``) for variable-bit-rate operation.
    """

    def __init__(
        self,
        *,
        input_size=None,
        dim=None,
        codebook_size=None,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1.,
        straight_through_activation="identity",
        num_codebooks=1,
        keep_num_codebooks_dim=None,
        codebook_scale=1.,  # for residual LFQ, codebook scaled down by 2x at each layer
        rand_num_codebooks: Optional[List] = None,
        sampling_rate=16000,
        encoder_hop_length=640,
    ):
        super().__init__()

        # NOTE(review): ``dim`` is unconditionally overwritten by ``input_size``,
        # so a caller-supplied ``dim`` is ignored. Preserved for compatibility.
        dim = input_size
        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(
            codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        codebook_dim = int(log2(codebook_size))

        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        # Project in/out only when the model dim differs from the codec dim.
        has_projections = dim != codebook_dims
        self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # Activation applied before the straight-through rounding.
        if straight_through_activation == "identity":
            self.activation = nn.Identity()
        elif straight_through_activation == "tanh":
            self.activation = nn.Tanh()
        else:
            raise NotImplementedError("Unsupported activation type, only 'tanh' and 'identity' are supported")

        # Entropy auxiliary-loss weights.
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        self.codebook_scale = codebook_scale
        self.commitment_loss_weight = commitment_loss_weight

        # Bit mask (MSB first) mapping sign patterns to integer indices;
        # ``zero`` is the dummy loss value used outside training.
        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent=False)

        # Enumerate the full implicit codebook once for the entropy loss.
        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)

        self.register_buffer('codebook', codebook, persistent=False)
        self.rand_num_codebooks = rand_num_codebooks
        self.sampling_rate = sampling_rate
        self.encoder_hop_length = encoder_hop_length

    def output_size(self):
        """Dimension of the (projected) quantizer output."""
        return self.dim

    def bits_to_codes(self, bits):
        """Map {0, 1} bits to {-codebook_scale, +codebook_scale} code values."""
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        # Follow the codebook buffer's dtype (tracks .to()/autocast moves).
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out=True
    ):
        """Map integer code indices back to code vectors (inverse of quantization)."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # indices to codes, which are bits of either -1 or 1
        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = rearrange(codes, '... c d -> ... (c d)')

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)
        if project_out:
            codes = self.project_out(codes)

        # rearrange codes back to original shape
        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def keep_first_nq_codes(self, x, nq=None):
        """Zero all but the first ``nq`` codebooks of ``x`` (b, n, c, d), rescaling
        the kept codebooks by num_codebooks/nq to preserve expected magnitude.
        """
        if nq is None or nq >= self.num_codebooks:
            return x
        # BUG FIX: ``cal_num_quant`` returns a float, but tensor slicing needs an
        # integer bound (x[:, :, 2.5:] raises TypeError); also guard nq == 0,
        # which would divide by zero when computing the rescale factor.
        nq = max(int(nq), 1)
        if nq >= self.num_codebooks:
            return x
        inv_p = 1.0 / (nq / self.num_codebooks)
        x[:, :, nq:] = 0
        x[:, :, :nq] = x[:, :, :nq] * inv_p

        return x

    def random_dropout_codes(self, inputs):
        """Clone ``inputs`` and keep a randomly chosen number of leading codebooks."""
        x = torch.clone(inputs)
        rand_num = random.choice(self.rand_num_codebooks)
        return self.keep_first_nq_codes(x, nq=rand_num)

    def cal_num_quant(self, bite_width):
        """Number of codebooks (possibly fractional) needed to hit ``bite_width``
        bits/second at this model's frame rate; callers should truncate to int.
        """
        frame_rate = self.sampling_rate / self.encoder_hop_length
        nq = bite_width / frame_rate / self.codebook_dim
        return nq

    def forward(
        self,
        x,
        inv_temperature=100.,
        return_loss_breakdown=False,
        mask=None,
        bite_width=None,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Returns (out, indices, aux_loss, breakdown_or_None).
        """
        is_img_or_video = x.ndim >= 4

        # standardize image or video into (batch, seq, dimension)
        if is_img_or_video:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)
        x = self.activation(x)

        # split out number of codebooks
        x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)

        # quantize by eq 3.
        original_input = x

        codebook_value = torch.ones_like(x) * self.codebook_scale
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # use straight-through gradients (optionally with custom activation fn) if training
        if self.training:
            x = x + (quantized - x).detach()
        else:
            x = quantized

        # pack each codebook's sign pattern into one integer index
        indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy aux loss
        if self.training:
            # the same as euclidean distance up to a constant
            distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)

            prob = (-distance * inv_temperature).softmax(dim=-1)

            per_sample_entropy = entropy(prob).mean()

            # account for mask
            if exists(mask):
                prob = prob[mask]

            # distribution over all available tokens in the batch
            avg_prob = reduce(prob, '... c d -> c d', 'mean')
            codebook_entropy = entropy(avg_prob).mean()

            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

        # commit loss
        if self.training:
            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # randomly dropout codebooks to fit varying bite width
        if self.training and self.rand_num_codebooks is not None:
            x = self.random_dropout_codes(x)
        if bite_width is not None:
            x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width))

        # merge back codebook dim
        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed
        x = self.project_out(x)

        # reconstitute image or video dimensions
        if is_img_or_video:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim
        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # complete aux loss
        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        if not return_loss_breakdown:
            return x, indices, aux_loss, None

        return x, indices, aux_loss, dict(
            per_sample_entropy=per_sample_entropy,
            codebook_entropy=codebook_entropy,
            commit_loss=commit_loss
        )

    def inference(
        self,
        x,
        bandwidth: int = None,
    ):
        """Forward pass without loss terms; ``bandwidth`` maps to ``bite_width``."""
        x, indices, _, _ = self.forward(x, bite_width=bandwidth)

        return x, indices, None
|
|
||||||
|
|
||||||
|
|
||||||
class ScalableLFQ(LFQ):
|
|
||||||
def __init__(self, *, input_size=None, dim=None, codebook_size=None, entropy_loss_weight=0.1,
|
|
||||||
commitment_loss_weight=0.25, diversity_gamma=1., straight_through_activation=nn.Identity(),
|
|
||||||
num_codebooks=1, keep_num_codebooks_dim=None, codebook_scale=1.,
|
|
||||||
rand_num_codebooks: Optional[List] = None, sampling_rate=16000, hop_length=640, **kwargs):
|
|
||||||
super().__init__(input_size=input_size, dim=dim, codebook_size=codebook_size,
|
|
||||||
entropy_loss_weight=entropy_loss_weight, commitment_loss_weight=commitment_loss_weight,
|
|
||||||
diversity_gamma=diversity_gamma, straight_through_activation=straight_through_activation,
|
|
||||||
num_codebooks=num_codebooks, keep_num_codebooks_dim=keep_num_codebooks_dim,
|
|
||||||
codebook_scale=codebook_scale, rand_num_codebooks=rand_num_codebooks,
|
|
||||||
sampling_rate=sampling_rate, hop_length=hop_length)
|
|
||||||
codebook_alpha_conf = kwargs.get("codebook_alpha_conf", None)
|
|
||||||
self.init_codebook_alpha(codebook_alpha_conf)
|
|
||||||
|
|
||||||
def init_codebook_alpha(self, codebook_alpha_conf: dict):
|
|
||||||
assert codebook_alpha_conf is not None, "codebook_alpha_conf cannot be None"
|
|
||||||
name = codebook_alpha_conf.get("name", "constant")
|
|
||||||
if name == "constant":
|
|
||||||
alphas = codebook_alpha_conf.get("alphas", [1.0] * self.num_codebooks)
|
|
||||||
assert len(alphas) == self.num_codebooks, \
|
|
||||||
f"the length of codebook alphas {len(alphas)} " \
|
|
||||||
f"must match num_codebooks {self.num_codebooks}."
|
|
||||||
alphas = np.array(alphas)
|
|
||||||
elif name == "exponential":
|
|
||||||
temp = codebook_alpha_conf.get("temp", 8.0)
|
|
||||||
alphas = np.exp(-np.arange(0, self.num_codebooks) / temp)
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unknown codebook alpha type {name}.")
|
|
||||||
codebook_alpha = torch.tensor(alphas/alphas.sum(), dtype=torch.float32).reshape(1, 1, -1, 1)
|
|
||||||
self.register_buffer("codebook_alpha", codebook_alpha)
|
|
||||||
|
|
||||||
    def forward(
        self,
        x,
        inv_temperature=100.,
        return_loss_breakdown=False,
        mask=None,
        bite_width=None
    ):
        """Lookup-free (sign) quantization forward pass.

        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Args:
            x: input features; rank >= 4 is treated as image/video laid out as
               (b, d, ...), otherwise as (b, n, d).
            inv_temperature: sharpness of the softmax used for the entropy loss.
            return_loss_breakdown: if True, also return the individual loss terms.
            mask: optional boolean mask applied before averaging the losses.
            bite_width: optional target bit width; keeps only the first few
               codebooks via ``keep_first_nq_codes`` (spelling kept for API
               compatibility — presumably "bit_width").

        Returns:
            (x, indices, aux_loss, breakdown-or-None) where ``x`` is the
            (de)quantized feature, ``indices`` the integer codes, and
            ``aux_loss`` the weighted entropy + commitment loss.
        """

        is_img_or_video = x.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # split out number of codebooks

        x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)

        # quantize by eq 3: each dimension independently snaps to +/- codebook_scale
        original_input = x
        codebook_value = torch.ones_like(x) * self.codebook_scale
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # use straight-through gradients (optionally with custom activation fn) if training
        if self.training:
            x = self.activation(x)
            x = x + (quantized - x).detach()
        else:
            x = quantized

        # calculate indices: interpret the sign pattern as a binary code via self.mask
        indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy aux loss
        if self.training:
            # the same as euclidean distance up to a constant
            distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)

            prob = (-distance * inv_temperature).softmax(dim=-1)

            per_sample_entropy = entropy(prob).mean()

            # account for mask
            if exists(mask):
                prob = prob[mask]

            # distribution over all available tokens in the batch
            avg_prob = reduce(prob, '... c d -> c d', 'mean')
            codebook_entropy = entropy(avg_prob).mean()

            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

        # commit loss
        if self.training:
            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction='none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # randomly dropout codebooks to fit a (bit) width during training
        if self.training and self.rand_num_codebooks is not None:
            x = self.random_dropout_codes(x)
        if bite_width is not None:
            x = self.keep_first_nq_codes(x, self.cal_num_quant(bite_width))

        # weight each codebook's contribution (buffer registered by init_codebook_alpha)
        x = x * self.codebook_alpha

        # merge back codebook dim

        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if is_img_or_video:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim
        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # complete aux loss
        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        if not return_loss_breakdown:
            return x, indices, aux_loss, None

        return x, indices, aux_loss, dict(
            per_sample_entropy=per_sample_entropy,
            codebook_entropy=codebook_entropy,
            commit_loss=commit_loss
        )
|
|
||||||
@ -1,8 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
# flake8: noqa
|
|
||||||
from funasr.modules.quantization.vq import QuantizedResult, ResidualVectorQuantizer
|
|
||||||
@ -1,291 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
"""Arithmetic coder."""
|
|
||||||
|
|
||||||
import io
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import typing as tp
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization.binary import BitPacker, BitUnpacker
|
|
||||||
|
|
||||||
|
|
||||||
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
                               roundoff: float = 1e-8, min_range: int = 2,
                               check: bool = True) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove difference coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior is a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Raises:
        ValueError: if `min_range` < 2, or (when `check`) if the resulting CDF
            is not representable with `total_range_bits`.
    """
    if min_range < 2:
        # BUGFIX (ordering): this was previously validated only after min_range
        # had already been used to compute the ranges; fail fast instead.
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        # Quantize the pdf itself so different backends produce identical cdfs.
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2 ** total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
|
|
||||||
|
|
||||||
|
|
||||||
class ArithmeticCoder:
    """ArithmeticCoder,

    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. With a budget of `B` bits we can
    map each of the `2 ** B` numbers in `[0, 2 ** B - 1]` to a sequence `(s_t)`:
    keep a current range, split it into chunks roughly proportional to `p`,
    select the chunk of `s_t`, and repeat; finally emit any value left in the
    range. Likely symbols shrink the range slowly and therefore cost few bits.
    In practice `B` is unknown ahead of time: new bits are injected whenever the
    current range width falls below `2 ** total_range_bits`. The distribution
    may differ at every timestep; efficiency degrades gracefully with the KL
    divergence between the true distribution and the modeled one.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the working range is `2 ** total_range_bits`.
            Any time the current range width fall under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        # Current range is [low, high], tracked with max_bit significant bits.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                # Emit the shared top bit and drop it from both bounds.
                self.low -= (b1 << self.max_bit)
                self.high -= (b1 << self.max_bit)
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        # Inject fresh bits until the range is wide enough to subdivide safely.
        while self.delta < 2 ** self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        # [range_low, range_high] is the symbol's chunk in cdf coordinates;
        # rescale it into the current [low, high] range.
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
        effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit

    def flush(self):
        """Flush the remaining information to the stream.
        """
        # Emit the bits of `low`, any value in [low, high] would decode correctly.
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    """
    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        # Mirror of the encoder state: [low, high] range, the value read so far,
        # and the number of significant bits currently tracked.
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        # Current range width.
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= (b1 << self.max_bit)
                self.high -= (b1 << self.max_bit)
                self.current -= (b1 << self.max_bit)
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        # Widen the range (reading bits) until it can be subdivided like the encoder did.
        while self.delta < 2 ** self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search over symbols for the chunk containing `current`.
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            # Same rescaling as ArithmeticCoder.push — must match bit-for-bit.
            effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
            effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return mid, low, high, self.current
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym
|
|
||||||
|
|
||||||
|
|
||||||
def test():
    """Round-trip self-test: encode random symbol streams, then decode and compare."""
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for _step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            assert decoder.pull(q_cdf) == symbol, idx
        # After consuming everything, the stream must be exhausted.
        assert decoder.pull(torch.zeros(1)) is None
|
|
||||||
|
|
||||||
|
|
||||||
# Run the encode/decode round-trip self-test when executed as a script.
if __name__ == "__main__":
    test()
|
|
||||||
@ -1,157 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
|
|
||||||
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import struct
|
|
||||||
import typing as tp
|
|
||||||
import ctypes
|
|
||||||
|
|
||||||
# format is `ECDC` magic code, followed by the header size as uint32.
|
|
||||||
# Then an uint8 indicates the protocol version (0.)
|
|
||||||
# The header is then provided as json and should contain all required
|
|
||||||
# informations for decoding. A raw stream of bytes is then provided
|
|
||||||
# and should be interpretable using the json header.
|
|
||||||
# ECDC layout: magic `ECDC`, one uint8 protocol version (0), a uint32 header
# size, then the JSON header itself, followed by the raw byte stream.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'


def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Serialize *metadata* as the JSON ECDC header and write it to *fo*."""
    payload = json.dumps(metadata).encode('utf-8')
    fo.write(_encodec_header_struct.pack(_ENCODEC_MAGIC, 0, len(payload)))
    fo.write(payload)
    fo.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
|
|
||||||
buf = b""
|
|
||||||
while len(buf) < size:
|
|
||||||
new_buf = fo.read(size)
|
|
||||||
if not new_buf:
|
|
||||||
raise EOFError("Impossible to read enough data from the stream, "
|
|
||||||
f"{size} bytes remaining.")
|
|
||||||
buf += new_buf
|
|
||||||
size -= len(new_buf)
|
|
||||||
return buf
|
|
||||||
|
|
||||||
|
|
||||||
def read_ecdc_header(fo: tp.IO[bytes]):
    """Parse an ECDC stream header from *fo* and return its JSON metadata.

    Raises:
        ValueError: on a bad magic code or an unsupported protocol version.
    """
    raw_header = _read_exactly(fo, _encodec_header_struct.size)
    magic, version, meta_size = _encodec_header_struct.unpack(raw_header)
    if magic != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if version != 0:
        raise ValueError("Version not supported.")
    return json.loads(_read_exactly(fo, meta_size).decode('utf-8'))
|
|
||||||
|
|
||||||
|
|
||||||
class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        # Accumulator holding not-yet-written bits, lowest bits first.
        self._current_value = 0
        self._current_bits = 0
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object."""
        # BUGFIX: the previous implementation reinterpreted `value` through a
        # float32 bit pattern via ctypes before packing (e.g. 1 -> 1065353216),
        # corrupting every pushed value and breaking the BitPacker/BitUnpacker
        # round trip (this module's own test() could not pass). Pack the
        # integer value directly.
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        while self._current_bits >= 8:
            lower_8bits = self._current_value & 0xff
            self._current_bits -= 8
            self._current_value >>= 8
            self.fo.write(bytes([lower_8bits]))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode."""
        if self._current_bits:
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class BitUnpacker:
    """BitUnpacker does the opposite of `BitPacker`.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        self._mask = (1 << bits) - 1
        # Bit buffer: pending bits live in _current_value, lowest bits first.
        self._current_value = 0
        self._current_bits = 0

    def pull(self) -> tp.Optional[int]:
        """
        Pull a single value from the stream, potentially reading some
        extra bytes from the underlying file-object.
        Returns `None` when reaching the end of the stream.
        """
        # Refill the bit buffer until a full value is available.
        while self._current_bits < self.bits:
            raw = self.fo.read(1)
            if not raw:
                return None
            self._current_value += raw[0] << self._current_bits
            self._current_bits += 8

        out = self._current_value & self._mask
        self._current_value >>= self.bits
        self._current_bits -= self.bits
        return out
|
|
||||||
|
|
||||||
|
|
||||||
def test():
    """Round-trip random token streams through BitPacker/BitUnpacker."""
    import torch
    torch.manual_seed(1234)
    for _rep in range(4):
        length: int = torch.randint(10, 2_000, (1,)).item()
        bits: int = torch.randint(1, 16, (1,)).item()
        tokens: tp.List[int] = torch.randint(2 ** bits, (length,)).tolist()
        decoded: tp.List[int] = []
        stream = io.BytesIO()
        packer = BitPacker(bits, stream)
        for token in tokens:
            packer.push(token)
        packer.flush()
        stream.seek(0)
        unpacker = BitUnpacker(bits, stream)
        value = unpacker.pull()
        while value is not None:
            decoded.append(value)
            value = unpacker.pull()
        assert len(decoded) >= len(tokens), (len(decoded), len(tokens))
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert len(decoded) <= len(tokens) + 8 // bits, (len(decoded), len(tokens), bits)
        for idx, (a, b) in enumerate(zip(tokens, decoded)):
            assert a == b, (idx, a, b)
|
|
||||||
|
|
||||||
|
|
||||||
# Run the bit pack/unpack round-trip self-test when executed as a script.
if __name__ == '__main__':
    test()
|
|
||||||
@ -1,674 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
#
|
|
||||||
# This implementation is inspired from
|
|
||||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
|
||||||
# which is released under MIT License. Hereafter, the original license:
|
|
||||||
# MIT License
|
|
||||||
#
|
|
||||||
# Copyright (c) 2020 Phil Wang
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
# of this software and associated documentation files (the "Software"), to deal
|
|
||||||
# in the Software without restriction, including without limitation the rights
|
|
||||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
# copies of the Software, and to permit persons to whom the Software is
|
|
||||||
# furnished to do so, subject to the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be included in all
|
|
||||||
# copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
# SOFTWARE.
|
|
||||||
|
|
||||||
"""Core vector quantization implementation."""
|
|
||||||
import logging
|
|
||||||
import typing as tp
|
|
||||||
from random import randrange
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from math import ceil
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import random
|
|
||||||
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization import distrib
|
|
||||||
|
|
||||||
def round_up_multiple(num, mult):
    """Round *num* up to the nearest multiple of *mult*."""
    quotient = ceil(num / mult)
    return quotient * mult
|
|
||||||
|
|
||||||
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
|
|
||||||
|
|
||||||
|
|
||||||
def ema_inplace(moving_avg, new, decay: float):
    """In-place EMA update: moving_avg = decay * moving_avg + (1 - decay) * new."""
    blend = 1 - decay
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=blend)
|
|
||||||
|
|
||||||
|
|
||||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth *x* with *epsilon* and renormalize so the result sums to one."""
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
|
|
||||||
|
|
||||||
|
|
||||||
def uniform_init(*shape: int):
    """Allocate a tensor of the given shape filled with Kaiming-uniform noise."""
    out = torch.empty(shape)
    nn.init.kaiming_uniform_(out)
    return out
|
|
||||||
|
|
||||||
|
|
||||||
def sample_vectors(samples, num: int):
    """Pick *num* rows from *samples*: without replacement when there are enough
    rows, with replacement otherwise."""
    pool_size = samples.shape[0]
    device = samples.device

    if pool_size >= num:
        chosen = torch.randperm(pool_size, device=device)[:num]
    else:
        chosen = torch.randint(0, pool_size, (num,), device=device)

    return samples[chosen]
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run Lloyd's k-means on `samples` (shape `[n, d]`).

    Returns:
        (means, bins): the `[num_clusters, d]` centroids and the per-cluster
        assignment counts from the last iteration.

    NOTE(review): `bins` is only assigned inside the loop, so calling with
    `num_iters=0` would raise NameError — confirm callers always pass >= 1.
    """
    # device = samples.device
    # samples = samples.cpu()
    dim, dtype = samples.shape[-1], samples.dtype

    # Initialize centroids by sampling rows from the data.
    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # diffs = rearrange(samples, "n d -> n () d") - rearrange(
        #     means, "c d -> () c d"
        # )
        # dists = -(diffs ** 2).sum(dim=-1)
        # Expanded negated squared Euclidean distance (avoids the large
        # broadcasted intermediate of the commented-out version above).
        dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        # Assign each sample to its nearest centroid.
        buckets = dists.max(dim=-1).indices
        del dists
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty clusters to 1 to avoid division by zero below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Keep the old centroid for clusters that received no samples.
        means = torch.where(zero_mask[..., None], means, new_means)

    # means = means.to(device)
    return means, bins
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(x):
    """Flatten all leading dimensions of *x* into one, keeping the last (feature) dim."""
    return x.reshape(-1, x.shape[-1])
|
|
||||||
|
|
||||||
|
|
||||||
def postprocess_emb(embed_ind, shape):
    """Restore a flat index tensor to the leading dimensions of the original *shape*."""
    leading_dims = shape[:-1]
    return embed_ind.view(*leading_dims)
|
|
||||||
|
|
||||||
|
|
||||||
class EuclideanCodebook(nn.Module):
|
|
||||||
"""Codebook with Euclidean distance.
|
|
||||||
Args:
|
|
||||||
dim (int): Dimension.
|
|
||||||
codebook_size (int): Codebook size.
|
|
||||||
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
|
||||||
If set to true, run the k-means algorithm on the first training batch and use
|
|
||||||
the learned centroids as initialization.
|
|
||||||
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
|
||||||
decay (float): Decay for exponential moving average over the codebooks.
|
|
||||||
epsilon (float): Epsilon value for numerical stability.
|
|
||||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
|
||||||
that have an exponential moving average cluster size less than the specified threshold with
|
|
||||||
randomly selected vector from the current batch.
|
|
||||||
"""
|
|
||||||
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        """See the class docstring for the meaning of each argument."""
        super().__init__()
        self.decay = decay
        # With kmeans_init, start from zeros and lazily run k-means on the
        # first batch (see init_embed_); otherwise use a Kaiming-uniform init.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` is False iff k-means initialization is still pending.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())
        # NOTE(review): nn.Module already sets `training = True`; this explicit
        # assignment looks redundant — confirm it is intentional.
        self.training = True
|
|
||||||
|
|
||||||
def init_embed_(self, data):
|
|
||||||
if self.inited:
|
|
||||||
return
|
|
||||||
|
|
||||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
|
||||||
self.embed.data.copy_(embed)
|
|
||||||
self.embed_avg.data.copy_(embed.clone())
|
|
||||||
self.cluster_size.data.copy_(cluster_size)
|
|
||||||
self.inited.data.copy_(torch.Tensor([True]))
|
|
||||||
# Make sure all buffers across workers are in sync after initialization
|
|
||||||
distrib.broadcast_tensors(self.buffers())
|
|
||||||
|
|
||||||
def replace_(self, samples, mask):
|
|
||||||
modified_codebook = torch.where(
|
|
||||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
|
||||||
)
|
|
||||||
self.embed.data.copy_(modified_codebook)
|
|
||||||
|
|
||||||
def expire_codes_(self, batch_samples):
|
|
||||||
if self.threshold_ema_dead_code == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
|
||||||
if not torch.any(expired_codes):
|
|
||||||
return
|
|
||||||
|
|
||||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
|
||||||
self.replace_(batch_samples, mask=expired_codes)
|
|
||||||
distrib.broadcast_tensors(self.buffers())
|
|
||||||
|
|
||||||
def quantize(self, x):
|
|
||||||
embed = self.embed.t()
|
|
||||||
dist = -(
|
|
||||||
x.pow(2).sum(1, keepdim=True)
|
|
||||||
- 2 * x @ embed
|
|
||||||
+ embed.pow(2).sum(0, keepdim=True)
|
|
||||||
)
|
|
||||||
embed_ind = dist.max(dim=-1).indices
|
|
||||||
return embed_ind
|
|
||||||
|
|
||||||
def dequantize(self, embed_ind):
|
|
||||||
quantize = F.embedding(embed_ind, self.embed)
|
|
||||||
return quantize
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
shape = x.shape
|
|
||||||
# pre-process
|
|
||||||
x = preprocess(x)
|
|
||||||
# quantize
|
|
||||||
embed_ind = self.quantize(x)
|
|
||||||
# post-process
|
|
||||||
embed_ind = postprocess_emb(embed_ind, shape)
|
|
||||||
return embed_ind
|
|
||||||
|
|
||||||
def decode(self, embed_ind):
|
|
||||||
quantize = self.dequantize(embed_ind)
|
|
||||||
return quantize
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
shape, dtype = x.shape, x.dtype
|
|
||||||
x = preprocess(x)
|
|
||||||
|
|
||||||
self.init_embed_(x)
|
|
||||||
|
|
||||||
embed_ind = self.quantize(x)
|
|
||||||
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
|
||||||
embed_ind = postprocess_emb(embed_ind, shape)
|
|
||||||
quantize = self.dequantize(embed_ind)
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
# We do the expiry of code at that point as buffers are in sync
|
|
||||||
# and all the workers will take the same decision.
|
|
||||||
self.expire_codes_(x)
|
|
||||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
|
||||||
embed_sum = x.t() @ embed_onehot
|
|
||||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
|
||||||
cluster_size = (
|
|
||||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
|
||||||
* self.cluster_size.sum()
|
|
||||||
)
|
|
||||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
|
||||||
self.embed.data.copy_(embed_normalized)
|
|
||||||
# Note: after ema update, there is a very small difference between codebooks on GPUs.
|
|
||||||
# The impact can be very small, ignore it.
|
|
||||||
|
|
||||||
return quantize, embed_ind
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleEuclideanCodebook(nn.Module):
    """Simple Codebook with Euclidean distance.

    Using gradient to update code embeddings instead of EMA (the table is an
    ``nn.Embedding`` parameter, so lookups are differentiable w.r.t. the codes).

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool or Tensor): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization. A pre-computed
            ``(codebook_size, dim)`` tensor may be passed to initialize directly.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: tp.Union[bool, torch.Tensor] = False,
        kmeans_iters: int = 10,
        **kwargs
    ):
        super().__init__()
        if isinstance(kmeans_init, bool):
            if kmeans_init:
                # Placeholder table; real init happens lazily in init_embed_.
                embed = torch.zeros(codebook_size, dim)
                inited = False
            else:
                embed = uniform_init(codebook_size, dim)
                inited = True
        else:
            # Pre-computed codebook tensor supplied by the caller.
            embed = kmeans_init
            inited = True
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters

        self.embed = nn.Embedding(codebook_size, dim)
        self.embed.weight.data.copy_(embed)
        # "inited" buffer gates the lazy k-means initialization.
        self.register_buffer("inited", torch.Tensor([inited]))
        # NOTE(review): redundant — nn.Module.__init__ already sets training=True.
        self.training = True

    def init_embed_(self, data):
        """Lazily initialize the codebook from the first batch via k-means (no-op afterwards)."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        # BUG FIX: self.embed is an nn.Embedding module, which has no ``.data``
        # attribute — copying must go through its ``weight`` parameter.
        self.embed.weight.data.copy_(embed)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return the index of the nearest code (Euclidean) for each row of ``x``."""
        embed = self.embed.weight.t()
        # Negated squared distance so that argmax picks the nearest code.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up code vectors for the given indices (differentiable w.r.t. the table)."""
        quantize = self.embed(embed_ind)
        return quantize

    def encode(self, x):
        """Flatten ``x`` to (N, dim), quantize, and restore the leading shape on the indices."""
        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Inverse of :meth:`encode`: indices back to code vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize ``x``; returns ``(quantize, embed_ind)`` with leading dims preserved."""
        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        return quantize, embed_ind
|
|
||||||
|
|
||||||
|
|
||||||
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear in/out projections only when the codebook lives in a different dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size
        # NOTE(review): redundant — nn.Module.__init__ already sets training=True.
        self.training = True

    @property
    def codebook(self):
        # Raw (codebook_size, codebook_dim) embedding table of the inner codebook.
        return self._codebook.embed

    def encode(self, x):
        """Map channel-first ``x`` (B, D, N) to code indices (B, N)."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Map code indices back to channel-first vectors (B, D, N)."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        """Quantize channel-first ``x``; returns ``(quantize, embed_ind, loss)``."""
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            # Straight-through estimator: forward value is the code,
            # gradient flows back to x unchanged.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                # Pull the encoder output towards its (detached) assigned code.
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleVectorQuantization(nn.Module):
    """Vector quantization implementation with SimpleEuclideanCodebook.
    Currently, supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        commitment_weight: float = 0.25,
        codebook_weight: float = 1.0,
        **kwargs
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear in/out projections only when the codebook lives in a different dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self.codebook_weight = codebook_weight
        logging.info(f"commitment_weight: {commitment_weight}, codebook_weight: {codebook_weight}.")

        self._codebook = SimpleEuclideanCodebook(
            dim=_codebook_dim, codebook_size=codebook_size,
            kmeans_init=kmeans_init, kmeans_iters=kmeans_iters
        )
        self.codebook_size = codebook_size
        # NOTE(review): redundant — nn.Module.__init__ already sets training=True.
        self.training = True

    @property
    def codebook(self):
        # Inner codebook's nn.Embedding (gradient-trained table).
        return self._codebook.embed

    def encode(self, x):
        """Map channel-first ``x`` (B, D, N) to code indices (B, N)."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Map code indices back to channel-first vectors (B, D, N)."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        """Quantize channel-first ``x``; returns ``(quantize, embed_ind, loss)``.

        Unlike the EMA variant there is no straight-through estimator here:
        the codebook is trained by the VQ-VAE-style loss terms below.
        """
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            # commit loss for codebook
            if self.codebook_weight > 0:
                codebook_loss = F.mse_loss(quantize, x.detach())
                loss = loss + codebook_loss * self.codebook_weight

            # commit loss for encoder
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        # One VQ layer per residual stage; remaining kwargs configure each layer.
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )
        # Quantizer dropout (training only): keep a random prefix of layers
        # (count drawn from rand_num_quant) and emit null placeholders for the rest.
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with up to ``n_q`` residual stages.

        Returns ``(quantized_out, indices, losses, sub_quants)``; the last three
        are stacked along a leading per-layer axis. Dropped layers contribute
        -1 indices, zero loss and -1 sub-quants.
        """
        quantized_out = 0.0
        residual = x
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            # Layers with index beyond this bound are skipped for the whole batch.
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

            null_indices_shape = (x.shape[0], x.shape[2])
            # NOTE(review): float fill value -1. with dtype=torch.long — prefer the int literal -1.
            null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            # Each stage quantizes what the previous stages failed to explain.
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
        return quantized_out, out_indices, out_losses, out_sub_quants

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Inference path: return stacked code indices of shape (n_q, ...), no losses."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the per-layer decodes of stacked indices back into one tensor."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleResidualVectorQuantization(nn.Module):
    """Simple Residual vector quantization with gradient to
    update codebook instead of EMA
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        kmeans_init = raw_kmeans_init = kwargs.pop('kmeans_init', True)
        if isinstance(kmeans_init, str):
            # use prepared kmeans init
            # kmeans_init is a path to a .npy file holding the codebook(s);
            # a 2-D array is promoted to a single-layer (1, size, dim) stack.
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if embed.dim() == 2:
                embed = embed.unsqueeze(0)
            kmeans_init = embed

        # When kmeans_init is a tensor stack, layer i gets its own slice;
        # otherwise the same bool flag is shared by every layer.
        self.layers = nn.ModuleList([
            SimpleVectorQuantization(
                kmeans_init=kmeans_init[i] if isinstance(kmeans_init, torch.Tensor) else kmeans_init,
                **kwargs
            ) for i in range(num_quantizers)
        ])
        # Restore the popped key so the caller-visible kwargs dict is unchanged.
        kwargs["kmeans_init"] = raw_kmeans_init
        # Quantizer dropout (training only): keep a random prefix of layers
        # (count drawn from rand_num_quant) and emit null placeholders for the rest.
        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with up to ``n_q`` residual stages.

        Returns ``(quantized_out, indices, losses, sub_quants)``; the last three
        are stacked along a leading per-layer axis. Dropped layers contribute
        -1 indices, zero loss and -1 sub-quants.
        """
        quantized_out = 0.0
        residual = x
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            # Layers with index beyond this bound are skipped for the whole batch.
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

            null_indices_shape = (x.shape[0], x.shape[2])
            # NOTE(review): float fill value -1. with dtype=torch.long — prefer the int literal -1.
            null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index > 0 and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            # Each stage quantizes what the previous stages failed to explain.
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
        return quantized_out, out_indices, out_losses, out_sub_quants

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Inference path: return stacked code indices of shape (n_q, ...), no losses."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the per-layer decodes of stacked indices back into one tensor."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
|
|
||||||
@ -1,513 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
#
|
|
||||||
# This implementation is inspired from
|
|
||||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
|
||||||
# which is released under MIT License. Hereafter, the original license:
|
|
||||||
# MIT License
|
|
||||||
#
|
|
||||||
# Copyright (c) 2020 Phil Wang
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
# of this software and associated documentation files (the "Software"), to deal
|
|
||||||
# in the Software without restriction, including without limitation the rights
|
|
||||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
# copies of the Software, and to permit persons to whom the Software is
|
|
||||||
# furnished to do so, subject to the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be included in all
|
|
||||||
# copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
# SOFTWARE.
|
|
||||||
|
|
||||||
"""Core vector quantization implementation."""
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
import typing as tp
|
|
||||||
from random import randrange
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from math import ceil
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from funasr.utils.hinter import hint_once
|
|
||||||
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization import distrib
|
|
||||||
|
|
||||||
def round_up_multiple(num, mult):
    """Round *num* up to the nearest multiple of *mult*."""
    n_chunks = ceil(num / mult)
    return n_chunks * mult
|
|
||||||
|
|
||||||
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return *val* unless it is None, in which case fall back to *d*."""
    if val is None:
        return d
    return val
|
|
||||||
|
|
||||||
|
|
||||||
def ema_inplace(moving_avg, new, decay: float, mask=None):
    """In-place exponential moving average update of *moving_avg*.

    With no *mask*, every entry is updated as decay * old + (1 - decay) * new.
    With a *mask*, only entries where the mask is set take the EMA value;
    the rest keep their previous value.
    """
    alpha = 1 - decay
    if mask is None:
        moving_avg.data.mul_(decay).add_(new, alpha=alpha)
        return
    keep = mask.float()
    blended = moving_avg * decay + new * alpha
    selected = keep * blended + (1 - keep) * moving_avg
    moving_avg.data.copy_(selected.data)
|
|
||||||
|
|
||||||
|
|
||||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth the counts *x* into strictly positive proportions."""
    total = x.sum() + n_categories * epsilon
    return (x + epsilon) / total
|
|
||||||
|
|
||||||
|
|
||||||
def uniform_init(*shape: int):
    """Allocate a tensor of the given *shape* filled with Kaiming-uniform noise."""
    out = torch.empty(shape)
    nn.init.kaiming_uniform_(out)
    return out
|
|
||||||
|
|
||||||
|
|
||||||
def sample_vectors(samples, num: int):
    """Draw *num* rows from *samples*.

    Sampling is without replacement when enough rows exist, otherwise with
    replacement (uniform random indices).
    """
    total = samples.shape[0]
    device = samples.device

    if total >= num:
        chosen = torch.randperm(total, device=device)[:num]
    else:
        chosen = torch.randint(0, total, (num,), device=device)

    return samples[chosen]
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run Lloyd's k-means on the rows of *samples*.

    Returns ``(means, counts)``: the cluster centroids of shape
    (num_clusters, dim) and the final per-cluster assignment counts.
    Centroids of clusters that received no points keep their previous value.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    # Seed centroids with random rows of the data.
    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Negated squared Euclidean distance so argmax picks the nearest mean.
        neg_sq_dist = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        assign = neg_sq_dist.max(dim=-1).indices
        del neg_sq_dist
        counts = torch.bincount(assign, minlength=num_clusters)
        empty = counts == 0
        safe_counts = counts.masked_fill(empty, 1)

        # Sum the members of each cluster, then divide by the (clamped) counts.
        sums = assign.new_zeros(num_clusters, dim, dtype=dtype)
        sums.scatter_add_(0, assign.unsqueeze(-1).expand(-1, dim), samples)
        fresh_means = sums / safe_counts[..., None]

        # Keep the old centroid wherever a cluster went empty.
        means = torch.where(empty[..., None], means, fresh_means)

    return means, counts
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(x):
    """Flatten every leading dimension: (..., d) -> (N, d)."""
    flat = rearrange(x, "... d -> (...) d")
    return flat
|
|
||||||
|
|
||||||
|
|
||||||
def postprocess_emb(embed_ind, shape):
    """Fold flat code indices back to the leading dims of the original input shape."""
    leading_dims = shape[:-1]
    return embed_ind.view(*leading_dims)
|
|
||||||
|
|
||||||
|
|
||||||
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        sparse_update: bool = False,
        normalized_input: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Unlike the registered-buffer variant, this codebook does NOT own its
        # state: (inited, cluster_size, embed, embed_avg) are supplied per call
        # through the ``buffers`` argument of encode/decode/forward.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        # NOTE(review): redundant — nn.Module.__init__ already sets training=True.
        self.training = True
        # sparse_update: only EMA-update codes that were actually hit this batch.
        self.sparse_update = sparse_update
        self.normalized_input = normalized_input

    def init_embed_(self, data):
        """Lazily initialize the externally-owned codebook via k-means (no-op afterwards)."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])

    def replace_(self, samples, mask):
        """Replace the codes selected by ``mask`` with random vectors drawn from ``samples``."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Re-seed codes whose EMA usage fell below ``threshold_ema_dead_code``."""
        hint_once(f"threshold_ema_dead_code: {self.threshold_ema_dead_code}.", "threshold_ema_dead_code", rank=0)
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return the index of the nearest code (Euclidean) for each row of ``x``."""
        embed = self.embed.t()
        # Negated squared distance so that argmax picks the nearest code.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up code vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x, buffers):
        """Quantize ``x`` using the externally supplied state ``buffers``.

        ``buffers`` is the tuple (inited, cluster_size, embed, embed_avg).
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind, buffers):
        """Inverse of :meth:`encode` with the same externally supplied state."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, buffers):
        """Quantize ``x`` and, when training, EMA-update the supplied buffers in place.

        ``buffers`` is the tuple (inited, cluster_size, embed, embed_avg);
        returns ``(quantize, embed_ind)`` with the leading dims of ``x`` preserved.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            if not self.sparse_update:
                mask = None
            else:
                # Only touch codes that received at least one vector this batch.
                mask = embed_onehot.sum(0) > 0
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay, mask=mask)
            embed_sum = x.t() @ embed_onehot
            # if self.normalized_input:
            #     embed_sum = F.normalize(embed_sum, dim=0)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay,
                        mask=mask.unsqueeze(-1) if self.sparse_update else None)
            if not self.sparse_update:
                # Laplace smoothing keeps the per-code divisor strictly positive.
                cluster_size = (
                    laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                    * self.cluster_size.sum()
                )
                embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            else:
                embed_normalized = self.embed_avg
            if self.normalized_input:
                embed_normalized = F.normalize(embed_normalized, dim=-1)
            self.embed.data.copy_(embed_normalized)
            # Note: after ema update, there is a very small difference between codebooks on GPUs.
            # The impact can be very small, ignore it.

        return quantize, embed_ind
|
|
||||||
|
|
||||||
|
|
||||||
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.,
        **kwargs,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Only project in/out when the codebook lives in a different dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code,
                                           **kwargs)
        self.codebook_size = codebook_size
        # NOTE(review): nn.Module already manages ``training`` via train()/eval();
        # this manual assignment looks redundant — confirm it is intentional.
        self.training = True

    @property
    def codebook(self):
        # Expose the raw codebook embeddings of the wrapped EuclideanCodebook.
        return self._codebook.embed

    def encode(self, x, buffers):
        """Map ``x`` (B, D, N) to code indices using the given state buffers."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x, buffers)
        return embed_in

    def decode(self, embed_ind, buffers):
        """Map code indices back to a (B, D, N) tensor using the state buffers."""
        quantize = self._codebook.decode(embed_ind, buffers)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x, buffers):
        """Quantize ``x`` (B, D, N); returns (quantize, embed_ind, loss).

        ``loss`` is the weighted commitment loss in training mode, zero otherwise.
        """
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x, buffers)

        if self.training:
            # Straight-through estimator: gradients flow to x, values come
            # from the quantized tensor.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                # Pull encoder outputs toward their (frozen) codebook entries.
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedResidualVectorQuantization(nn.Module):
    """Efficient distributed residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    The codebook state (``inited``, ``cluster_size``, ``embed``, ``embed_avg``)
    for all quantizers is owned here as registered buffers and handed to each
    ``VectorQuantization`` layer per call, so a single ``broadcast_tensors``
    keeps every worker in sync after each forward step.
    """

    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        """
        Args:
            num_quantizers: number of residual quantizer stages.
            quantize_dropout: randomly drop trailing stages during training.
            rand_num_quant: candidate counts of active stages for dropout.
            **kwargs: forwarded to ``VectorQuantization``; must contain
                ``codebook_size``, ``dim`` and ``kmeans_init`` (bool, or a
                path to a prepared .npy codebook).
        """
        super().__init__()
        codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["dim"]
        kmeans_init = kwargs["kmeans_init"]
        if isinstance(kmeans_init, bool):
            if not kmeans_init:
                # Uniform init: codebooks are usable immediately.
                embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
                inited = True
                cluster_size = 1
            else:
                # Zero placeholder; kmeans init runs on the first batch.
                embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
                inited = False
                cluster_size = 0
        elif isinstance(kmeans_init, str):
            # Load a prepared kmeans codebook from file.
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if kwargs.get("normalized_input", False):
                logging.info("normalize the code embeddings since the input is normalized.")
                embed = F.normalize(embed, dim=-1)
            if embed.dim() == 2:
                # A single codebook was provided: share it across all stages.
                embed = embed.repeat(num_quantizers, 1, 1)
            inited = True
            cluster_size = 1
        else:
            raise TypeError("kmeans_init should be either a bool or string path to init weights.")

        self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
        self.register_buffer("cluster_size", torch.ones(num_quantizers, codebook_size) * cluster_size)
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

        # Optional extra time-downsampling applied to the first quantizer only.
        self.q0_ds_ratio = 1
        if "q0_ds_ratio" in kwargs:
            self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")

        self.layers = nn.ModuleList()
        for i in range(num_quantizers):
            vq_args = dict(**kwargs)
            vq = VectorQuantization(**vq_args)
            self.layers.append(vq)

        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Residual-quantize ``x`` (B, C, T) through up to ``n_q`` stages.

        Returns:
            tuple: ``(quantized_out, out_indices, out_losses, out_sub_quants)``
            where dropped-out stages contribute -1 indices, zero loss and
            -1-filled sub-quants so the stacked outputs keep a fixed shape.
        """
        quantized_out = torch.zeros_like(x)
        residual = x
        bb, cc, tt = x.shape
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        # Quantizer dropout: train with a random number of active stages.
        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

            null_indices_shape = (x.shape[0], x.shape[2])
            # Fix: use an integer fill value — a float ``-1.`` with
            # dtype=torch.long is rejected by recent PyTorch versions.
            null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
            null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            quant_in = residual
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Fix: honor the configured ratio; the previous code
                # hard-coded ``tt // 2`` regardless of q0_ds_ratio.
                quant_in = F.interpolate(quant_in, size=[tt // self.q0_ds_ratio])
            quantized, indices, loss = layer(quant_in, [
                self.inited[quantizer_index],
                self.cluster_size[quantizer_index],
                self.embed[quantizer_index],
                self.embed_avg[quantizer_index]
            ])
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Upsample back to the original time length.
                quantized = F.interpolate(quantized, size=[tt])
                indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        # sync buffers after one forward step
        distrib.broadcast_tensors(self.buffers())
        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))

        return quantized_out, out_indices, out_losses, out_sub_quants

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode ``x`` into stacked per-stage code indices (no state update)."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            indices = layer.encode(residual, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            # Each stage quantizes what the previous stages left over.
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode stacked per-stage indices by summing stage reconstructions."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized_out = quantized_out + quantized
        return quantized_out
|
|
||||||
@ -1,126 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
"""Torch distributed utilities."""
|
|
||||||
|
|
||||||
import typing as tp
|
|
||||||
import time, logging
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def rank():
    """Rank of this worker in the default process group (0 when not distributed)."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
|
|
||||||
|
|
||||||
|
|
||||||
def world_size():
    """Number of workers in the default process group (1 when not distributed)."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
|
|
||||||
|
|
||||||
|
|
||||||
def is_distributed():
    """True when the job runs with more than one worker."""
    return torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
|
|
||||||
|
|
||||||
|
|
||||||
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce ``tensor`` in place across workers; no-op outside distributed mode."""
    if not is_distributed():
        return
    return torch.distributed.all_reduce(tensor, op)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_complex_or_float(tensor):
|
|
||||||
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify every worker holds the same number of params.

    A mismatch would deadlock the subsequent distributed all-reduce, so we
    detect it up front and raise instead.
    """
    if not is_distributed() or not params:
        return
    count = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(count)
    # If not all the workers have the same number, for at least one of them,
    # this inequality will be verified.
    if count.item() != len(params) * world_size():
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from worker ``src`` to all workers.

    Useful to guarantee every worker starts from the same model state.
    No-op outside distributed mode.
    """
    if not is_distributed():
        return
    # Only float/complex tensors are broadcast; others are skipped.
    eligible = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(eligible)
    # Launch every broadcast asynchronously, then wait for all of them.
    pending = [torch.distributed.broadcast(t.data, src=src, async_op=True)
               for t in eligible]
    for op in pending:
        op.wait()
|
|
||||||
|
|
||||||
|
|
||||||
def sync_buffer(buffers, average=True):
    """Synchronize floating-point buffers across workers.

    If ``average`` is True the buffers are summed via all-reduce and divided
    by the world size; otherwise rank 0's values are broadcast to everyone.
    No-op outside distributed mode.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        # Only floating-point buffers participate (reduce ops need floats).
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # Fix: world_size is a function — the original divided by the
            # function object itself, which raises TypeError at runtime.
            buffer.data /= world_size()
|
|
||||||
|
|
||||||
|
|
||||||
def sync_grad(params):
    """Average gradients across all workers after ``backward()``.

    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    pending = []
    for param in params:
        if param.grad is None:
            continue
        # Kick off every all-reduce asynchronously so they overlap.
        handle = torch.distributed.all_reduce(
            param.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
        pending.append((param, handle))
    for param, handle in pending:
        handle.wait()
        param.grad.data /= world_size()
|
|
||||||
|
|
||||||
|
|
||||||
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pack values plus a trailing 1 as the normalization weight, scaled by count.
    packed = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) * count
    all_reduce(packed)
    # Divide the summed values by the summed weights, then rebuild the dict.
    averaged = (packed[:-1] / packed[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
|
|
||||||
@ -1,134 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
"""Residual vector quantizer implementation."""
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import math
|
|
||||||
import typing as tp
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization import distrib
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization.core_vq import SimpleResidualVectorQuantization
|
|
||||||
from funasr.models.sense_voice.quantizer.quantization.ddp_core_vq import DistributedResidualVectorQuantization
|
|
||||||
|
|
||||||
@dataclass
class QuantizedResult:
    """Bundle of outputs returned by ``ResidualVectorQuantizer.forward``."""
    # Quantized representation (sum of the per-stage quantized outputs).
    quantized: torch.Tensor
    # Codebook indices selected by each quantizer stage.
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    # Optional training penalty (mean commitment loss).
    penalty: tp.Optional[torch.Tensor] = None
    # Free-form metrics dictionary.
    metrics: dict = field(default_factory=dict)
    # Per-stage quantized tensors; dropped-out stages are filled with -1.
    sub_quants: torch.Tensor = None
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer front-end.

    Selects either the distributed or the simple RVQ backend and exposes
    bandwidth-aware forward/encode/decode.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for the exponential moving average over the codebooks.
        kmeans_init (bool or str): Whether to use kmeans to initialize the
            codebooks, or a path to prepared init weights.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Codes whose EMA cluster size falls below
            this threshold are replaced with vectors from the current batch.
        quantize_dropout (bool): Randomly drop trailing quantizers in training.
        rand_num_quant (list, optional): Candidate active-quantizer counts for
            quantizer dropout.
        encoder_hop_length (int): Encoder hop length, used for bandwidth math.
        use_ddp (bool): Use the distributed backend instead of the simple one.
        q0_ds_ratio (int): Extra time-downsampling ratio for the first quantizer.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: tp.Union[bool, str] = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        quantize_dropout: bool = False,
        rand_num_quant: tp.Optional[tp.List] = None,
        encoder_hop_length: int = 320,
        use_ddp: bool = True,
        q0_ds_ratio: int = 1,
        **kwargs,
    ):
        super().__init__()
        # Keep the raw configuration around for bandwidth computations below.
        self.dimension = dimension
        self.n_q = n_q
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.encoder_hop_length = encoder_hop_length
        self.training = True
        if use_ddp:
            logging.info("Using distributed residual vector quantization.")
            rvq_class = DistributedResidualVectorQuantization
        else:
            logging.warning("Using simple residual vector quantization")
            rvq_class = SimpleResidualVectorQuantization
        self.model = rvq_class(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            quantize_dropout=quantize_dropout,
            rand_num_quant=rand_num_quant,
            q0_ds_ratio=q0_ds_ratio,
            **kwargs
        )

    def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor in the shape of (B, C, T).
            sample_rate (int): Sample rate of the input tensor.
            bandwidth (float): Target bandwidth; caps the number of quantizers.

        Returns:
            QuantizedResult: quantized representation with the consumed
            bandwidth, the mean commitment loss as penalty, and the
            per-stage sub-quantized tensors.
        """
        per_quantizer_bw = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        quantized, codes, commit_loss, sub_quants = self.model(x, n_q=n_q)
        bw = torch.tensor(n_q * per_quantizer_bw).to(x)
        return QuantizedResult(
            quantized, codes, bw,
            penalty=torch.mean(commit_loss),
            sub_quants=sub_quants,
        )

    def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
        """Return how many quantizers fit within the target bandwidth."""
        per_quantizer_bw = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.n_q
        if bandwidth and bandwidth > 0.:
            # At least one quantizer is always used.
            n_q = int(max(1, math.floor(bandwidth / per_quantizer_bw)))
        return n_q

    def get_bandwidth_per_quantizer(self, sample_rate: int):
        """Bandwidth consumed by one quantizer for a given input sample rate."""
        # bits per code times frames per second.
        return math.log2(self.bins) * sample_rate / self.encoder_hop_length

    def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
        """Encode ``x`` at the given bandwidth.

        Picks the appropriate number of quantizers and returns the indices
        produced by each of them.
        """
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        return self.model.encode(x, n_q=n_q)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes back to the quantized representation."""
        return self.model.decode(codes)
|
|
||||||
Loading…
Reference in New Issue
Block a user