From 4c053ccc39ef3bdc6d131482274e9b1cd9ceee67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=8C=E5=8C=80?= Date: Mon, 27 Feb 2023 19:10:15 +0800 Subject: [PATCH] gpu bug fix --- funasr/models/encoder/fsmn_encoder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/funasr/models/encoder/fsmn_encoder.py b/funasr/models/encoder/fsmn_encoder.py index c749dc438..38d164dfe 100755 --- a/funasr/models/encoder/fsmn_encoder.py +++ b/funasr/models/encoder/fsmn_encoder.py @@ -82,7 +82,8 @@ class FSMNBlock(nn.Module): def forward(self, input: torch.Tensor, cache: torch.Tensor): x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) # B D T C - + + cache = cache.to(x_per.device) y_left = torch.cat((cache, x_per), dim=2) cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] y_left = self.conv_left(y_left) @@ -297,4 +298,4 @@ if __name__ == '__main__': print('input shape: {}'.format(x.shape)) print('output shape: {}'.format(y.shape)) - print(fsmn.to_kaldi_net()) \ No newline at end of file + print(fsmn.to_kaldi_net())