diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 4648fb333..7015eb8e2 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1605,7 +1605,6 @@ class Speech2TextTransducer:
         feats_lengths = to_device(feats_lengths, device=self.device)
 
         enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-
        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps
diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
index 8a33520e9..40151bf43 100644
--- a/funasr/models/encoder/rwkv_encoder.py
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -113,11 +113,12 @@ class RWKVEncoder(AbsEncoder):
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        for block in self.rwkv_blocks:
-            x, _ = block(x)
-        # for streaming inference
-        # xs_pad = self.rwkv_infer(xs_pad)
+        # for training
+        # for block in self.rwkv_blocks:
+        #     x, _ = block(x)
+        # for streaming inference
+        x = self.rwkv_infer(x)
 
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1:
@@ -136,9 +137,9 @@ class RWKVEncoder(AbsEncoder):
 
         state = [
             torch.zeros(
-                (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
+                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                 dtype=torch.float32,
-                device=self.device,
+                device=xs_pad.device,
             )
             for i in range(5)
         ]
@@ -151,5 +152,5 @@ class RWKVEncoder(AbsEncoder):
             for idx, block in enumerate(self.rwkv_blocks):
                 x_t, state = block(x_t, state=state)
             xs_out.append(x_t)
-        xs_out = torch.stack(xs_out, dim=1)
+        xs_out = torch.cat(xs_out, dim=1)
         return xs_out
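A quick note on the `stack` to `cat` fix in `rwkv_infer`: each streaming step appends one frame `x_t`, which (judging from the state shapes above) is `(batch, 1, hidden)`. Stacking on `dim=1` would leave a stray singleton axis, while concatenation reproduces the `(batch, time, hidden)` layout the rest of the encoder expects. A minimal check, with the frame shape assumed as described:

```python
import torch

B, T, C = 2, 4, 8
# one output frame per streaming step, as collected in rwkv_infer
xs_out = [torch.randn(B, 1, C) for _ in range(T)]

print(torch.stack(xs_out, dim=1).shape)  # torch.Size([2, 4, 1, 8]) -- stray axis
print(torch.cat(xs_out, dim=1).shape)    # torch.Size([2, 4, 8])    -- (B, T, C)
```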
diff --git a/funasr/modules/cuda_decoder/wkv_cuda.cu b/funasr/modules/cuda_decoder/wkv_cuda.cu
new file mode 100644
index 000000000..1dcbc7141
--- /dev/null
+++ b/funasr/modules/cuda_decoder/wkv_cuda.cu
@@ -0,0 +1,135 @@
+// Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu
+
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F q[Tmax], r[Tmax];
+
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
+
+    aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
+    }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
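For readers who do not want to step through the CUDA, the recurrence `kernel_forward` computes can be restated as a naive PyTorch loop: one batch-parallel pass over time, with the same running `aa`/`bb`/`pp` trick keeping the exponentials from overflowing. This is only a reference sketch for sanity-checking the kernel, not part of the patch:

```python
import torch

def wkv_forward_reference(w, u, k, v):
    """Mirror of kernel_forward. w, u are per-channel (C,) parameters
    (w already transformed by w -> -exp(w)); k, v are (B, T, C)."""
    B, T, C = k.shape
    y = torch.empty_like(v)
    aa = torch.zeros(B, C, dtype=k.dtype, device=k.device)
    bb = torch.zeros(B, C, dtype=k.dtype, device=k.device)
    pp = torch.full((B, C), -1e38, dtype=k.dtype, device=k.device)  # MIN_VALUE
    for t in range(T):
        kk, vv = k[:, t], v[:, t]
        ww = u + kk
        p = torch.maximum(pp, ww)
        e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
        y[:, t] = (e1 * aa + e2 * vv) / (e1 * bb + e2)
        ww = w + pp
        p = torch.maximum(ww, kk)
        e1, e2 = torch.exp(ww - p), torch.exp(kk - p)
        aa = e1 * aa + e2 * vv
        bb = e1 * bb + e2
        pp = p
    return y
```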
diff --git a/funasr/modules/cuda_decoder/wkv_op.cpp b/funasr/modules/cuda_decoder/wkv_op.cpp
new file mode 100644
index 000000000..102421999
--- /dev/null
+++ b/funasr/modules/cuda_decoder/wkv_op.cpp
@@ -0,0 +1,37 @@
+/*
+ * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp
+ * Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp
+ */
+
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+
+void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv_decoder, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
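The `load_decoder_wkv_kernel(context_size)` call that this patch comments out (see the `rwkv_attention.py` hunks below) presumably JIT-compiles the two files above. Below is a loader sketch in the style of the BlinkDL/RWKV-LM training recipe; the function name and source paths are assumptions, not FunASR's actual API. Note that `context_size` must be baked in at compile time, because `kernel_backward` declares `q[Tmax]` and `r[Tmax]` as fixed-size arrays:

```python
from torch.utils.cpp_extension import load

def load_wkv_kernel(context_size: int, prefix: str = "cuda_decoder"):
    # hypothetical loader; flags follow the upstream RWKV-LM script
    return load(
        name=f"wkv_{prefix}",
        sources=[
            f"funasr/modules/{prefix}/wkv_op.cpp",
            f"funasr/modules/{prefix}/wkv_cuda.cu",
        ],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--maxrregcount 60",   # see the comment in cuda_forward
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            f"-DTmax={context_size}",  # bounds q[Tmax], r[Tmax] in kernel_backward
        ],
    )
```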
diff --git a/funasr/modules/cuda_encoder/wkv_cuda.cu b/funasr/modules/cuda_encoder/wkv_cuda.cu
new file mode 100644
index 000000000..1dcbc7141
--- /dev/null
+++ b/funasr/modules/cuda_encoder/wkv_cuda.cu
@@ -0,0 +1,135 @@
+// Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu
+
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F q[Tmax], r[Tmax];
+
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
+
+    aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
+    }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/funasr/modules/cuda_encoder/wkv_op.cpp b/funasr/modules/cuda_encoder/wkv_op.cpp
new file mode 100644
index 000000000..16f364637
--- /dev/null
+++ b/funasr/modules/cuda_encoder/wkv_op.cpp
@@ -0,0 +1,37 @@
+/*
+ * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp
+ * Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp
+ */
+
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+
+void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv_encoder, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
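The two op files (identical except for the `wkv_decoder` / `wkv_encoder` library names) only expose raw forward/backward entry points; gradients are wired up on the Python side, typically through a `torch.autograd.Function`. The sketch below shows how the bindings above would fit together, including the `w -> -exp(w)` transform that the `gw * _w[_c]` line in `kernel_backward` compensates for. Names and exact plumbing are assumptions modeled on the upstream RWKV usage, not FunASR's actual code:

```python
import torch

class WKVFunction(torch.autograd.Function):
    """Hypothetical wrapper around the compiled wkv extension (`wkv_ext`)."""

    @staticmethod
    def forward(ctx, wkv_ext, time_decay, time_first, key, value):
        # the kernels consume w = -exp(time_decay); kernel_backward's
        # `gw * _w[_c]` already applies the chain rule for this transform
        w = -torch.exp(time_decay.float().contiguous())
        u = time_first.float().contiguous()
        k = key.float().contiguous()
        v = value.float().contiguous()
        y = torch.empty_like(v)
        wkv_ext.forward(w, u, k, v, y)  # fills y in place
        ctx.wkv_ext = wkv_ext
        ctx.save_for_backward(w, u, k, v, y)
        return y

    @staticmethod
    def backward(ctx, gy):
        w, u, k, v, y = ctx.saved_tensors
        B, _, C = k.shape
        # kernel_backward writes per-(batch, channel) grads at _b * C + _c
        gw = torch.empty(B, C, device=gy.device)
        gu = torch.empty(B, C, device=gy.device)
        gk = torch.empty_like(k)
        gv = torch.empty_like(v)
        ctx.wkv_ext.backward(w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
        # time_decay / time_first are shared across the batch
        return None, gw.sum(0), gu.sum(0), gk, gv
```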
diff --git a/funasr/modules/rwkv_attention.py b/funasr/modules/rwkv_attention.py
index f0c7da39e..5384fb9ca 100644
--- a/funasr/modules/rwkv_attention.py
+++ b/funasr/modules/rwkv_attention.py
@@ -445,7 +445,7 @@ class SelfAttention(torch.nn.Module):
         """
         num_state, den_state, max_state = state
-
+        time_decay = -torch.exp(time_decay)
         max_for_output = torch.maximum(max_state, (time_first + key))
         e1 = torch.exp(max_state - max_for_output)
@@ -495,7 +495,7 @@ class DecoderSelfAttention(SelfAttention):
             dropout_rate,
             num_blocks
         )
-        load_decoder_wkv_kernel(context_size)
+        # load_decoder_wkv_kernel(context_size)
 
     def forward(
         self,
@@ -577,7 +577,7 @@ class EncoderSelfAttention(SelfAttention):
             dropout_rate,
             num_blocks
         )
-        load_encoder_wkv_kernel(context_size)
+        # load_encoder_wkv_kernel(context_size)
 
     def forward(
         self,
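Context for the first hunk: with the kernel loading commented out, this pure-PyTorch state-update path in `SelfAttention` is the one that runs, and the added line makes it apply the `w = -exp(time_decay)` transform itself (the same transform noted in `kernel_backward` above). A sketch of the complete single-step update around the patched lines, reconstructed to mirror `kernel_forward`; only the first two statements of the function body are verbatim from the patch:

```python
import torch

def wkv_step(time_decay, time_first, key, value, state):
    """One streaming WKV step over a (num, den, max) state triple."""
    num_state, den_state, max_state = state
    time_decay = -torch.exp(time_decay)  # the line this patch adds

    # output for the current frame
    max_for_output = torch.maximum(max_state, time_first + key)
    e1 = torch.exp(max_state - max_for_output)
    e2 = torch.exp(time_first + key - max_for_output)
    wkv = (e1 * num_state + e2 * value) / (e1 * den_state + e2)

    # decay the running state for the next frame
    max_for_state = torch.maximum(max_state + time_decay, key)
    e1 = torch.exp(max_state + time_decay - max_for_state)
    e2 = torch.exp(key - max_for_state)
    next_state = (e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state)
    return wkv, next_state
```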