This commit is contained in:
游雁 2024-06-09 02:05:49 +08:00
parent 8bb9971753
commit 56986acaa7

View File

@ -392,7 +392,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
)
batch = []
max_len_in_batch = 0
count = 0
count = 1
for idx in buffer:
original_sample_length = self.dataset.get_source_len(idx)
if original_sample_length > self.max_token_length:
@ -410,7 +410,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
buffer_batches.append(batch)
batch = [idx]
max_len_in_batch = sample_length
count = 0
count = 1
if batch:
buffer_batches.append(batch)