mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
modify statistic pooling layer
This commit is contained in:
parent
fee3fd32fd
commit
f191f4c868
@ -82,7 +82,7 @@ def windowed_statistic_pooling(
|
||||
tt = xs_pad.shape[2]
|
||||
num_chunk = int(math.ceil(tt / pooling_stride))
|
||||
pad = pooling_size // 2
|
||||
if xs_pad.shape == 4:
|
||||
if len(xs_pad.shape) == 4:
|
||||
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
|
||||
else:
|
||||
features = F.pad(xs_pad, (pad, pad), "reflect")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user