modify statistic pooling layer

This commit is contained in:
志浩 2023-03-14 14:31:27 +08:00
parent fee3fd32fd
commit f191f4c868

View File

@ -82,7 +82,7 @@ def windowed_statistic_pooling(
tt = xs_pad.shape[2]
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
if xs_pad.shape == 4:
if len(xs_pad.shape) == 4:
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
else:
features = F.pad(xs_pad, (pad, pad), "reflect")