# FunASR/funasr/export/models/language_models/seq_rnn.py
# Snapshot retrieved 2023-07-07 16:53:16 +08:00 (85 lines, 2.5 KiB, Python)

import os
import torch
import torch.nn as nn
class SequentialRNNLM(nn.Module):
    """ONNX-export wrapper around a trained sequential RNN language model.

    Borrows the trained submodules (embedding encoder, RNN, decoder) from
    *model* and exposes the helpers that ``torch.onnx.export`` needs:
    dummy inputs, input/output names, dynamic axes, and a runtime config.

    NOTE(review): the hidden-state layout (nlayers, batch, nhid) produced by
    ``get_dummy_inputs`` implies the wrapped RNN is batch-first — confirm
    against the training-side model definition.
    """

    def __init__(self, model, **kwargs):
        """Wrap *model*, re-using its submodules directly (no weight copy)."""
        super().__init__()
        self.encoder = model.encoder
        self.rnn = model.rnn
        self.rnn_type = model.rnn_type
        self.decoder = model.decoder
        self.nlayers = model.nlayers
        self.nhid = model.nhid
        self.model_name = "seq_rnnlm"

    def forward(self, y, hidden1, hidden2=None):
        # batch_score function.
        #
        # y:       LongTensor of token ids fed to the embedding encoder.
        # hidden1: recurrent hidden state; hidden2 is the LSTM cell state
        #          and is required (non-None) when rnn_type == "LSTM".
        embedded = self.encoder(y)
        if self.rnn_type == "LSTM":
            # An LSTM carries a (hidden, cell) pair; other RNNs a single state.
            rnn_out, (hidden1, hidden2) = self.rnn(embedded, (hidden1, hidden2))
        else:
            rnn_out, hidden1 = self.rnn(embedded, hidden1)
        # Flatten (batch, length) into one axis so the decoder projects a
        # 2-D matrix, then restore the leading dimensions on the logits.
        flat = rnn_out.contiguous().view(-1, rnn_out.size(2))
        logits = self.decoder(flat)
        logits = logits.view(rnn_out.size(0), rnn_out.size(1), logits.size(1))
        if self.rnn_type == "LSTM":
            return logits, hidden1, hidden2
        return logits, hidden1

    def get_dummy_inputs(self):
        """Return example inputs (tokens + initial state(s)) for export tracing."""
        token_ids = torch.LongTensor([0, 1]).unsqueeze(0)
        state = torch.randn(self.nlayers, 1, self.nhid)
        if self.rnn_type == "LSTM":
            return (token_ids, state, state)
        return (token_ids, state)

    def get_input_names(self):
        """ONNX graph input names, matching ``get_dummy_inputs`` order."""
        names = ["x", "in_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("in_hidden2")
        return names

    def get_output_names(self):
        """ONNX graph output names, matching ``forward``'s return order."""
        names = ["y", "out_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("out_hidden2")
        return names

    def get_dynamic_axes(self):
        """Map tensor names to the axes that may vary at inference time."""
        axes = {
            "x": {0: "x_batch", 1: "x_length"},
            "y": {0: "y_batch"},
            "in_hidden1": {1: "hidden1_batch"},
            "out_hidden1": {1: "out_hidden1_batch"},
        }
        if self.rnn_type == "LSTM":
            # The cell state shares the batch axis with the hidden state.
            axes["in_hidden2"] = {1: "hidden2_batch"}
            axes["out_hidden2"] = {1: "out_hidden2_batch"}
        return axes

    def get_model_config(self, path):
        """Return the runtime config dict pointing at the exported ONNX file."""
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "SequentialRNNLM",
            "rnn_type": self.rnn_type,
            "nhid": self.nhid,
            "nlayers": self.nlayers,
        }