Part of this article was generated by [ChatGPT].
collate_fn
This post is about the `collate_fn` argument of `torch.utils.data.DataLoader` and the Python built-in function `zip`.
For each batch, the dataloader collects `batch_size` items, picking them from the dataset one by one. At this point, the batch data is `[(data1, target1), (data2, target2), ..., (dataN, targetN)]`.
The default `collate_fn` would change it into `[torch.tensor([data1, data2, ..., dataN]), torch.tensor([target1, target2, ..., targetN])]`.
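As a minimal sketch of this default behavior (the toy dataset below is made up, assuming each sample is a fixed-size tensor):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 6 samples, each a length-4 feature vector, with integer targets.
data = torch.arange(24, dtype=torch.float32).reshape(6, 4)
targets = torch.arange(6)
dataset = TensorDataset(data, targets)

# No collate_fn is given, so the default one stacks samples and targets.
loader = DataLoader(dataset, batch_size=3)
x, y = next(iter(loader))
print(x.shape)  # torch.Size([3, 4])
print(y)        # tensor([0, 1, 2])
```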
However, in some NLP tasks the samples are not all the same length, so we need to apply `torch.nn.utils.rnn.pad_sequence` to pad each sample to the same length (usually the maximum length in the batch). A typical implementation is:
```python
def collate_fn_train(batch):
    # batch: [(data1, target1), (data2, target2), ..., (dataN, targetN)]
    x, y = zip(*batch)
    # Pad every sequence in x to the maximum length in this batch.
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    return x, torch.tensor(y)
```
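As a small illustration of the padding step (the toy tensors below are made up), `pad_sequence` pads every sequence to the longest one in the batch:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length sequences, as might appear in one NLP batch.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
padded = pad_sequence(seqs, batch_first=True)  # pads with 0 by default
print(padded)
# tensor([[1, 2, 3],
#         [4, 5, 0],
#         [6, 0, 0]])
```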
`zip` and `*`
- What does `zip` do in the above function?
The `zip()` function in Python is a built-in function that takes one or more iterables (such as lists, tuples, or strings) and “zips” them together, returning an iterator of tuples where the i-th tuple contains the i-th element from each of the input iterables.
```python
a = [1, 2, 3]
b = ['x', 'y', 'z']
print(list(zip(a, b)))  # [(1, 'x'), (2, 'y'), (3, 'z')]
```
- What is the `*`?
In Python, the asterisk (*) symbol can be used to unpack iterables like lists or tuples. When used in this way, the asterisk is sometimes called the “splat” operator or the “unpacking” operator. The unpacking operator is used to extract the individual elements from an iterable and assign them to separate variables.
```python
my_list = [1, 2, 3]
print(*my_list)  # same as print(1, 2, 3), so it prints: 1 2 3
```
- So, what is `x, y = zip(*batch)`?
First, `batch` is `[(data1, target1), (data2, target2), ..., (dataN, targetN)]`. `*batch` unpacks the outermost list, giving `zip((data1, target1), (data2, target2), ..., (dataN, targetN))`. The result is two tuples, `(data1, data2, ..., dataN)` and `(target1, target2, ..., targetN)`. The former is assigned to `x` and the latter to `y`.
In this way, we get separate structures that contain all the data and all the targets, respectively.
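To make the round trip concrete, here is a toy batch (placeholder strings standing in for real samples and targets):

```python
# A batch in the shape described above: [(data1, target1), ...].
batch = [("data1", "target1"), ("data2", "target2"), ("data3", "target3")]

# zip(*batch) transposes the batch: one tuple per position.
x, y = zip(*batch)
print(x)  # ('data1', 'data2', 'data3')
print(y)  # ('target1', 'target2', 'target3')
```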