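"""Collect attention weights from every attention layer of a model via forward hooks."""
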
from collections import defaultdict
from typing import Dict
from typing import List

import torch

from funasr.modules.rnn.attentions import AttAdd
from funasr.modules.rnn.attentions import AttCov
from funasr.modules.rnn.attentions import AttCovLoc
from funasr.modules.rnn.attentions import AttDot
from funasr.modules.rnn.attentions import AttForward
from funasr.modules.rnn.attentions import AttForwardTA
from funasr.modules.rnn.attentions import AttLoc
from funasr.modules.rnn.attentions import AttLoc2D
from funasr.modules.rnn.attentions import AttLocRec
from funasr.modules.rnn.attentions import AttMultiHeadAdd
from funasr.modules.rnn.attentions import AttMultiHeadDot
from funasr.modules.rnn.attentions import AttMultiHeadLoc
from funasr.modules.rnn.attentions import AttMultiHeadMultiResLoc
from funasr.modules.rnn.attentions import NoAtt
from funasr.modules.attention import MultiHeadedAttention

from funasr.models.base_model import FunASRModel


@torch.no_grad()
def calculate_all_attentions(
    model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from all the attention layers.

    Args:
        model: The model whose attention layers are inspected.
        batch: Same arguments as the model's forward().
    Returns:
        return_dict: A dict mapping layer names to a list of tensors,
            one entry per sample: key_names x batch x (D1, D2, ...)

    """
    bs = len(next(iter(batch.values())))
    assert all(len(v) == bs for v in batch.values()), {
        k: v.shape for k, v in batch.items()
    }

    # 1. Register a forward hook on every module to save the outputs of the
    #    attention layers
    outputs = {}
    handles = {}
    for name, modu in model.named_modules():

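        # Bind ``name`` as a default argument so that each hook closure keeps
        # the module name it was created for (late binding would otherwise
        # make every hook see the last value of the loop variable).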
        def hook(module, input, output, name=name):
            if isinstance(module, MultiHeadedAttention):
                # NOTE(kamo): MultiHeadedAttention doesn't return the attention
                # weight, so read it from the module attribute instead.
                # attn: (B, Head, Tout, Tin)
                outputs[name] = module.attn.detach().cpu()
            elif isinstance(module, AttLoc2D):
                c, w = output
                # w: previously concatenated attentions
                # w: (B, nprev, Tin)
                att_w = w[:, -1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, (AttCov, AttCovLoc)):
                c, w = output
                assert isinstance(w, list), type(w)
                # w: list of previous attentions
                # w: nprev x (B, Tin)
                att_w = w[-1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, AttLocRec):
                # w: (B, Tin)
                c, (w, (att_h, att_c)) = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttMultiHeadDot,
                    AttMultiHeadAdd,
                    AttMultiHeadLoc,
                    AttMultiHeadMultiResLoc,
                ),
            ):
                c, w = output
                # w: nhead x (B, Tin)
                assert isinstance(w, list), type(w)
                att_w = [_w.detach().cpu() for _w in w]
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttAdd,
                    AttDot,
                    AttForward,
                    AttForwardTA,
                    AttLoc,
                    NoAtt,
                ),
            ):
                c, w = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)

        handle = modu.register_forward_hook(hook)
        handles[name] = handle

    # 2. Forward the samples one by one.
    # Batch mode is not used, to keep the requirements small for each model.
    keys = []
    for k in batch:
        if not k.endswith("_lengths"):
            keys.append(k)

    return_dict = defaultdict(list)
    for ibatch in range(bs):
        # *: (B, L, ...) -> (1, L2, ...)
        _sample = {
            k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]]
            if k + "_lengths" in batch
            else batch[k][ibatch, None]
            for k in keys
        }

        # *_lengths: (B,) -> (1,)
        _sample.update(
            {
                k + "_lengths": batch[k + "_lengths"][ibatch, None]
                for k in keys
                if k + "_lengths" in batch
            }
        )
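        # Running the forward pass triggers the hooks registered above,
        # which populate ``outputs`` for this single-sample mini-batch.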
        model(**_sample)

        # Derive the attention results
        for name, output in outputs.items():
            if isinstance(output, list):
                if isinstance(output[0], list):
                    # output: nhead x (Tout, Tin)
                    output = torch.stack(
                        [
                            # Tout x (1, Tin) -> (Tout, Tin)
                            torch.cat([o[idx] for o in output], dim=0)
                            for idx in range(len(output[0]))
                        ],
                        dim=0,
                    )
                else:
                    # Tout x (1, Tin) -> (Tout, Tin)
                    output = torch.cat(output, dim=0)
            else:
                # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin)
                output = output.squeeze(0)
            # output: (Tout, Tin) or (NHead, Tout, Tin)
            return_dict[name].append(output)
        outputs.clear()

    # 3. Remove all hooks
    for handle in handles.values():
        handle.remove()

    return dict(return_dict)
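

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only).  A real caller would pass a
    # trained FunASRModel and a batch produced by the data loader; here a toy
    # module stands in for the model, and the constructor signature
    # MultiHeadedAttention(n_head, n_feat, dropout_rate) is assumed from the
    # ESPnet-derived implementation in funasr.modules.attention.
    import torch.nn as nn

    class _ToySelfAttentionModel(nn.Module):
        def __init__(self, n_feat: int = 8, n_head: int = 2):
            super().__init__()
            self.self_attn = MultiHeadedAttention(n_head, n_feat, 0.0)

        def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor):
            # Self-attention over the input; the forward hook registered by
            # calculate_all_attentions reads self.self_attn.attn afterwards.
            return self.self_attn(speech, speech, speech, None)

    toy_model = _ToySelfAttentionModel()
    toy_batch = {
        "speech": torch.randn(2, 5, 8),
        "speech_lengths": torch.tensor([5, 3]),
    }
    attentions = calculate_all_attentions(toy_model, toy_batch)
    for layer_name, per_sample_weights in attentions.items():
        # One tensor per sample, e.g. (NHead, Tout, Tin) for multi-head layers.
        print(layer_name, [w.shape for w in per_sample_weights])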