
PyTorch collate_fn and Python zip function

Part of this article was generated by ChatGPT.

collate_fn

This post covers the collate_fn argument of torch.utils.data.DataLoader and the Python built-in function zip.

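For each batch, the DataLoader fetches batch_size items from the dataset one by one and collects them in a list. At this point the batch looks like [(data1, target1), (data2, target2), ..., (dataN, targetN)].

The default collate_fn turns this into roughly [torch.stack([data1, data2, ..., dataN]), torch.tensor([target1, target2, ..., targetN])], assuming each data item is a tensor and each target is a number: the data tensors are stacked along a new first (batch) dimension, and the targets are gathered into a single tensor.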
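To see the default behaviour concretely, here is a small sketch, assuming PyTorch is installed and every data item has the same length. The toy dataset values are made up for illustration; a plain Python list works as a map-style dataset because it supports indexing and len().

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy dataset: a list of (data, target) pairs,
# where every data tensor has the same length (2).
dataset = [(torch.tensor([1.0, 2.0]), 0),
           (torch.tensor([3.0, 4.0]), 1),
           (torch.tensor([5.0, 6.0]), 0)]

# No collate_fn given, so the DataLoader uses its default one.
loader = DataLoader(dataset, batch_size=3)
data, target = next(iter(loader))

print(data.shape)  # torch.Size([3, 2]): the three data tensors stacked
print(target)      # tensor([0, 1, 0]): the int targets gathered into one tensor
```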

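However, in many NLP tasks the sequences in a batch have different lengths, and the default collate_fn fails because torch.stack requires all tensors to have the same size. The fix is to pass a custom collate_fn that uses torch.nn.utils.rnn.pad_sequence to pad every sequence to the same length (the length of the longest sequence in the batch). A typical implementation is: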

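import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_train(batch):
    # batch is a list of (data, target) pairs
    x, y = zip(*batch)
    # pad every sequence to the length of the longest one in the batch
    x_pad = pad_sequence(x, batch_first=True)
    # y = torch.tensor(y)  # optional: turn the tuple of targets into a tensor
    return x_pad, y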
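A usage sketch of the function above with variable-length sequences, assuming PyTorch is installed. The function is repeated so the block runs on its own; the dataset values are hypothetical, and this version enables the optional torch.tensor(y) conversion, which is a choice made here rather than something the original requires.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn_train(batch):
    x, y = zip(*batch)                         # split pairs into data and targets
    x_pad = pad_sequence(x, batch_first=True)  # pad with 0 up to the longest sequence
    return x_pad, torch.tensor(y)              # optional conversion enabled here

# Hypothetical token-id sequences of lengths 3, 1 and 2.
dataset = [(torch.tensor([4, 5, 6]), 1),
           (torch.tensor([7]), 0),
           (torch.tensor([8, 9]), 1)]

loader = DataLoader(dataset, batch_size=3, collate_fn=collate_fn_train)
x_pad, y = next(iter(loader))

print(x_pad.tolist())  # [[4, 5, 6], [7, 0, 0], [8, 9, 0]]
print(y.tolist())      # [1, 0, 1]
```

pad_sequence fills with 0 by default; its padding_value parameter selects a different fill value, which matters if 0 is a real token id in your vocabulary.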

zip and *

  • What does zip do in the function above?

The zip() function in Python is a built-in function that takes one or more iterables (such as lists, tuples, or strings) and “zips” them together, returning an iterator of tuples where the i-th tuple contains the i-th element from each of the input iterables.

a = [1, 2, 3]
b = ['a', 'b', 'c']
c = [True, False, True]

zipped = zip(a, b, c)

print(list(zipped))
# OUTPUT
# [(1, 'a', True), (2, 'b', False), (3, 'c', True)]
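Two properties of zip are worth keeping in mind, shown in this small sketch: the iterator it returns can only be consumed once, and it stops as soon as the shortest input runs out.

```python
a = [1, 2, 3]
b = ['a', 'b']          # one element shorter than a

zipped = zip(a, b)
first = list(zipped)    # consumes the iterator
second = list(zipped)   # the iterator is already exhausted

print(first)   # [(1, 'a'), (2, 'b')]  (3 is dropped: zip stops at the shortest input)
print(second)  # []
```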
  • What is the *?

In Python, the asterisk (*) can be used to unpack an iterable such as a list or tuple. Used this way, it is sometimes called the “splat” operator or the “unpacking” operator. It extracts the individual elements of the iterable; in a function call such as print(*my_list), each element is passed as a separate positional argument.

my_list = [1, 2, 3]
print(*my_list)
# equals
print(1, 2, 3)
# OUTPUT (each call prints the same line)
# 1 2 3
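A sketch of the call form used in zip(*batch), plus the related assignment form, where a starred name collects the leftover elements into a list. add3 is a hypothetical helper defined only for this example.

```python
def add3(x, y, z):
    # hypothetical helper that takes exactly three positional arguments
    return x + y + z

nums = [1, 2, 3]
print(add3(*nums))   # same as add3(1, 2, 3), prints 6

# Assignment form: the starred target gathers the remaining elements.
first, *rest = nums
print(first, rest)   # 1 [2, 3]
```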
  • So, what does x, y = zip(*batch) do?

First, batch is [(data1, target1), (data2, target2), ..., (dataN, targetN)].

*batch unpacks the outermost list, so the call becomes zip((data1, target1), (data2, target2), ..., (dataN, targetN)). zip then groups the first elements of all these tuples together and the second elements together, producing two tuples: (data1, data2, ..., dataN) and (target1, target2, ..., targetN). The first is assigned to x and the second to y.

In this way, we get two separate sequences: one containing all the data and one containing all the targets.
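A worked version of these steps, with strings standing in for the data and target tensors (an assumption made so the example runs without PyTorch):

```python
# Stand-ins for a batch of (data, target) pairs; strings replace tensors here.
batch = [("data1", "target1"), ("data2", "target2"), ("data3", "target3")]

# zip(*batch) is the same call as zip(batch[0], batch[1], batch[2])
x, y = zip(*batch)

print(x)  # ('data1', 'data2', 'data3')
print(y)  # ('target1', 'target2', 'target3')

# Zipping the two tuples again restores the original pairs.
restored = list(zip(x, y))
print(restored == batch)  # True
```

Because zip(*batch) yields one tuple per field, the number of names on the left must match the number of elements in each sample; a dataset returning hypothetical (data, target, length) triples would need x, y, lengths = zip(*batch).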