mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

fix rwkv infer bugs

commit 4e0404e04e (parent 8a84ed6a4a)
@@ -1605,7 +1605,6 @@ class Speech2TextTransducer:
         feats_lengths = to_device(feats_lengths, device=self.device)
 
         enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
 
         nbest_hyps = self.beam_search(enc_out[0])
 
         return nbest_hyps
@@ -113,11 +113,12 @@ class RWKVEncoder(AbsEncoder):
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        for block in self.rwkv_blocks:
-            x, _ = block(x)
-        # for streaming inference
-        # xs_pad = self.rwkv_infer(xs_pad)
+        # for training
+        # for block in self.rwkv_blocks:
+        #     x, _ = block(x)
+        # for streaming inference
+        x = self.rwkv_infer(x)
 
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1:
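Note that with this change the block-by-block training path is left commented out: RWKVEncoder.forward now always routes through the stateful rwkv_infer path (fixed in the next two hunks), so full-sequence training would need the loop re-enabled or a mode switch added.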
@@ -136,9 +137,9 @@ class RWKVEncoder(AbsEncoder):
 
         state = [
             torch.zeros(
-                (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
+                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                 dtype=torch.float32,
-                device=self.device,
+                device=xs_pad.device,
             )
             for i in range(5)
         ]
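Two distinct bugs are fixed in this hunk: self.num_rwkv_blocks referenced a non-existent attribute (the module apparently stores self.num_blocks), and self.device is not an attribute a plain torch.nn.Module provides, so the zero state must follow the input tensor's device instead. A minimal sketch of the resulting state initialization, where batch_size, hidden_sizes, num_blocks and xs_pad are stand-ins for the encoder's actual attributes and arguments:

    import torch

    # Hedged sketch: values below are placeholders, not FunASR's config.
    batch_size, num_blocks = 2, 6
    hidden_sizes = [512] * 5
    xs_pad = torch.randn(batch_size, 100, 512)

    state = [
        torch.zeros(
            (batch_size, 1, hidden_sizes[i], num_blocks),
            dtype=torch.float32,
            device=xs_pad.device,  # follow the input, not a (missing) self.device
        )
        for i in range(5)  # five recurrent statistics (e.g. num/den/max plus token-shift caches)
    ]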
@@ -151,5 +152,5 @@ class RWKVEncoder(AbsEncoder):
         for idx, block in enumerate(self.rwkv_blocks):
             x_t, state = block(x_t, state=state)
             xs_out.append(x_t)
-        xs_out = torch.stack(xs_out, dim=1)
+        xs_out = torch.cat(xs_out, dim=1)
         return xs_out
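Each streaming step emits a single frame, so the per-step outputs are concatenated rather than stacked. A shape check illustrating the fix (the (B, 1, C) per-frame shape is an assumption consistent with the state layout above):

    import torch

    B, T, C = 2, 4, 8
    xs_out = [torch.randn(B, 1, C) for _ in range(T)]  # one frame per step

    print(torch.stack(xs_out, dim=1).shape)  # torch.Size([2, 4, 1, 8]) -- extra axis (the bug)
    print(torch.cat(xs_out, dim=1).shape)    # torch.Size([2, 4, 8])    -- intended (B, T, C)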

funasr/modules/cuda_decoder/wkv_cuda.cu (new file, 135 lines)
@@ -0,0 +1,135 @@
// Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu

#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F q[Tmax], r[Tmax];

    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
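For readers who do not want to parse the CUDA, here is a slow pure-PyTorch sketch of the same numerically stable forward recurrence; it mirrors kernel_forward one-to-one. It is an illustration, not part of the commit; w is assumed to already be the negated decay, i.e. -exp(w_param), matching the "w -> -exp(w) in python forward()" comment in kernel_backward.

    import torch

    def wkv_forward_reference(w, u, k, v):
        # w, u: (C,) per-channel decay and first-token bonus; k, v: (B, T, C).
        B, T, C = k.shape
        y = torch.empty_like(v)
        aa = torch.zeros(B, C, dtype=k.dtype, device=k.device)  # numerator / exp(pp)
        bb = torch.zeros(B, C, dtype=k.dtype, device=k.device)  # denominator / exp(pp)
        pp = torch.full((B, C), -1e38, dtype=k.dtype, device=k.device)  # MIN_VALUE
        for t in range(T):
            kk, vv = k[:, t], v[:, t]
            ww = u + kk
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            y[:, t] = (e1 * aa + e2 * vv) / (e1 * bb + e2)
            ww = w + pp
            p = torch.maximum(ww, kk)
            e1 = torch.exp(ww - p)
            e2 = torch.exp(kk - p)
            aa = e1 * aa + e2 * vv
            bb = e1 * bb + e2
            pp = p
        return y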

funasr/modules/cuda_decoder/wkv_op.cpp (new file, 37 lines)
@@ -0,0 +1,37 @@
/*
 * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp
 * Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp
 */

#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);

    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}

void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);

    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv_decoder, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
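wkv_op.cpp compiles together with wkv_cuda.cu as a torch extension. FunASR's actual loader is load_decoder_wkv_kernel (referenced later in this diff), whose exact flags are not shown here, so the following is a hypothetical sketch using torch.utils.cpp_extension.load. The -DTmax flag matters: kernel_backward allocates q[Tmax] and r[Tmax] on the stack, so Tmax must bound the context size at compile time.

    # Hypothetical loader sketch -- names, paths and flags are assumptions,
    # not FunASR's actual implementation.
    from torch.utils.cpp_extension import load

    def load_decoder_wkv_kernel(context_size: int):
        return load(
            name="wkv_decoder",
            sources=[
                "funasr/modules/cuda_decoder/wkv_op.cpp",
                "funasr/modules/cuda_decoder/wkv_cuda.cu",
            ],
            extra_cuda_cflags=[
                "--maxrregcount 60",       # noted in wkv_cuda.cu for best performance
                "--use_fast_math",
                "-O3",
                f"-DTmax={context_size}",  # bounds the q[]/r[] scratch arrays
            ],
            verbose=True,
        )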

funasr/modules/cuda_encoder/wkv_cuda.cu (new file, 135 lines)
@@ -0,0 +1,135 @@
(contents byte-for-byte identical to funasr/modules/cuda_decoder/wkv_cuda.cu above)

funasr/modules/cuda_encoder/wkv_op.cpp (new file, 37 lines)
@@ -0,0 +1,37 @@
(identical to funasr/modules/cuda_decoder/wkv_op.cpp above, except for the library registration name:)

TORCH_LIBRARY(wkv_encoder, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
@@ -445,7 +445,7 @@ class SelfAttention(torch.nn.Module):
 
         """
         num_state, den_state, max_state = state
-
+        time_decay = -torch.exp(time_decay)
         max_for_output = torch.maximum(max_state, (time_first + key))
 
         e1 = torch.exp(max_state - max_for_output)
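This added line applies the w -> -exp(w) reparameterization on the pure-PyTorch path, matching what kernel_backward's comment assumes for the CUDA path; presumably the streaming state update previously used the raw parameter, which is one of the inference bugs the commit title refers to. A minimal illustration:

    import torch

    # time_decay is stored as a free (unconstrained) parameter; the
    # recurrence needs a strictly negative log-decay so that the
    # per-channel forgetting factor exp(time_decay) stays in (0, 1).
    time_decay = torch.zeros(8)         # stand-in for the learned parameter
    effective = -torch.exp(time_decay)  # always < 0
    assert (effective < 0).all()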
@@ -495,7 +495,7 @@ class DecoderSelfAttention(SelfAttention):
             dropout_rate,
             num_blocks
         )
-        load_decoder_wkv_kernel(context_size)
+        # load_decoder_wkv_kernel(context_size)
 
     def forward(
         self,
@@ -577,7 +577,7 @@ class EncoderSelfAttention(SelfAttention):
             dropout_rate,
             num_blocks
         )
-        load_encoder_wkv_kernel(context_size)
+        # load_encoder_wkv_kernel(context_size)
 
     def forward(
         self,
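With both loader calls commented out, neither WKV CUDA extension is compiled when these attention modules are constructed, so inference no longer depends on a working nvcc toolchain; presumably the pure-PyTorch stateful recurrence (see the time_decay fix above) serves as the execution path instead.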