This commit is contained in:
游雁 2024-06-09 02:19:29 +08:00
parent 7b34f8ffde
commit 1186cd96a5

View File

@ -399,10 +399,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
continue
sample_length = 1 if self.batch_type == "example" else original_sample_length
potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
if (
potential_batch_length <= self.batch_size
and count <= self.batch_size_sample_max
):
if potential_batch_length <= self.batch_size and count < self.batch_size_sample_max:
batch.append(idx)
max_len_in_batch = max(max_len_in_batch, sample_length)
count += 1