A Recurrent Neural Network (RNN) usually takes ordered sequences as input. Real-world sequences have different lengths, especially in Natural Language Processing (NLP), because words don't all have the same number of characters and sentences don't all have the same number of words. In PyTorch, the inputs of a neural network are often managed by a DataLoader, which groups the inputs into batches. This is better for training because it's faster and more efficient than feeding inputs to the network one by one. The issue with this approach is that it assumes every input has the same shape. As stated before, sequences don't have a consistent shape, so how can one train an RNN in PyTorch with variable-length sequences and still benefit from the DataLoader class?
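
To see the problem concretely, here is a small standalone sketch (not part of the tutorial; the toy data and batch size are arbitrary) where the default collate function fails on tensors of different lengths:

import torch
from torch.utils.data import DataLoader

# Three toy "sequences" of different lengths.
toy_data = [torch.randint(0, 26, (n,)) for n in (5, 7, 3)]

loader = DataLoader(toy_data, batch_size=3)
try:
    next(iter(loader))  # the default collate tries to stack the tensors
except RuntimeError as err:
    print("Default collate failed:", err)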

Here I will show a complete training example based on an official PyTorch RNN tutorial, whose goal is to classify names according to their origin.

Constructing an IterableDataset #

To create a DataLoader, we first need to create an IterableDataset that represents how to generate training examples.

# Imports for all the snippets in this post
import os
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, IterableDataset


class Dataset_Loader_Iterable(IterableDataset):
    def __init__(self, iters):
        self.length = iters
        self.all_letters = string.ascii_letters + " .,;'"
        self.LETTERS = {
            letter: idx for idx, letter in enumerate(self.all_letters, start=0)
        }
        self.n_letters = len(self.all_letters)
        self.category_lines = {}
        self.all_categories = []
        for filename in self.find_files("data/names/*.txt"):
            category = os.path.splitext(os.path.basename(filename))[0]
            self.all_categories.append(category)
            lines = self.read_lines(filename)
            self.category_lines[category] = lines

        self.n_categories = len(self.all_categories)

    # Helper methods -- bodies omitted here for brevity, see the PyTorch tutorial
    def find_files(self, path):
        # List the files matching the glob pattern in `path`
        pass

    def read_lines(self, filename):
        # Read a file and return its lines as ASCII-normalized names
        pass

    def random_choice(self, l):
        # Return a random element of `l`
        pass

    def alphabet_position(self, text):
        # Map each character of `text` to its position in the alphabet
        pass

    def unicode_to_ascii(self, s):
        # Turn a Unicode string into plain ASCII
        pass

    def prepare_sequence(self, seq):
        # Integer-encode a name as a 1-D LongTensor, one entry per character
        pass

    def random_training_example(self):
        category = self.random_choice(self.all_categories)
        line = self.random_choice(self.category_lines[category])
        category_tensor = torch.tensor(
            [self.all_categories.index(category)], dtype=torch.long
        )
        line_tensor = self.prepare_sequence(line)

        return category, line, category_tensor, line_tensor

    def gen_examples(self, group):
        return self.random_training_example()

    def __len__(self):
        return self.length

    def __iter__(self):
        return map(self.gen_examples, range(self.length))

n_iters = 100000
data = Dataset_Loader_Iterable(n_iters)

Here we copy the code and helper functions from the PyTorch tutorial and define an __iter__() method that calls random_training_example(). Each generated example consists of:

  1. The origin of the name (the country)
  2. The name itself
  3. The integer-encoded category tensor
  4. The integer-encoded name tensor (of variable length)
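
As a quick sanity check (assuming the helper method bodies have been filled in from the tutorial, since only stubs are shown above), we can pull a single example from the iterator and inspect its parts:

category, line, category_tensor, line_tensor = next(iter(data))
print(category)          # a country name, e.g. "Russian"
print(line)              # a name drawn from that country's file
print(category_tensor)   # tensor of shape (1,) holding the category index
print(line_tensor.shape) # 1-D tensor whose length depends on the name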

Constructing the DataLoader #

This part highlights the problem with variable-length sequences. random_training_example() generates data of variable length because names don't all have the same number of characters. If we create a DataLoader directly from our IterableDataset, PyTorch will complain that it cannot build batches from examples that have different shapes. We must therefore write an intermediate collate function that equalizes the sequence lengths, and tell the DataLoader to use it through its collate_fn argument.

class PadSequence:
    def __call__(self, batch):
        data = [item[3] for item in batch]
        data_text = [item[1] for item in batch]
        labels = [x for item in batch for x in item[2]]
        labels_text = [item[0] for item in batch]
        batch = list(zip(data, labels, labels_text, data_text))
        sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
        sequences = [x[0] for x in sorted_batch]
        sequences_padded = pad_sequence(sequences)
        lengths = torch.LongTensor([len(x) for x in sequences])
        labels = torch.LongTensor(list(map(lambda x: x[1], sorted_batch)))
        labels_text = list(map(lambda x: x[2], sorted_batch))
        data_text = list(map(lambda x: x[3], sorted_batch))

        return sequences_padded, lengths, labels, labels_text, data_text

data_loader = DataLoader(data, batch_size=8, collate_fn=PadSequence())

I have taken this function from here and adapted it to our custom IterableDataset. This also keeps the code compatible with the rest of the tutorial. It works as follows. When the DataLoader builds a batch, __call__() receives N items coming from the __iter__() method defined earlier, where N is the batch size. Lines 3-6 retrieve the four parts returned by random_training_example(). Line 7 then zips them so that we have a Python list of N elements, each composed of data (the integer-encoded name), labels (the integer-encoded label), labels_text and data_text. On line 8, the list is sorted by the length of the first element (the integer-encoded name) so that the longest names come first and the shortest ones last. On lines 9 and 10, the sequences representing the names are retrieved and padded with the pad_sequence() function, meaning that they are filled with a padding value. To understand what pad_sequence() does, let's look at an example:

If we have a sequences variable (a list of integer-encoded names) of:

tensor([20, 18, 19, 24, 20, 25,  7,  0, 13,  8, 13])
tensor([ 0, 22,  4, 17, 24,  0, 13, 14,  5,  5])
tensor([15,  4, 19, 19,  8,  6, 17,  4, 22])
tensor([ 1,  4, 11, 17, 14, 18,  4])
tensor([10,  0, 11, 20, 25,  0])
tensor([15, 14, 20, 11,  8, 13])
tensor([21,  8,  2, 19, 14, 17])
tensor([3,  0, 13, 10, 18])

Then pad_sequence() will output a PyTorch Tensor:

tensor([[20,  0, 15,  1, 10, 15, 21,  3],
        [18, 22,  4,  4,  0, 14,  8,  0],
        [19,  4, 19, 11, 11, 20,  2, 13],
        [24, 17, 19, 17, 20, 11, 19, 10],
        [20, 24,  8, 14, 25,  8, 14, 18],
        [25,  0,  6, 18,  0, 13, 17,  0],
        [ 7, 13, 17,  4,  0,  0,  0,  0],
        [ 0, 14,  4,  0,  0,  0,  0,  0],
        [13,  5, 22,  0,  0,  0,  0,  0],
        [ 8,  5,  0,  0,  0,  0,  0,  0],
        [13,  0,  0,  0,  0,  0,  0,  0]])

The first row of the Tensor, [20, 0, 15, 1, 10, 15, 21, 3], is composed of the first integer-encoded character of each name (the first column of the sequences). That is, the first name in the batch starts with 20, the second name starts with 0, and so on. The 6th row (representing the 6th character of each name), [25, 0, 6, 18, 0, 13, 17, 0], ends with a 0 because the last name has only 5 characters, so a 0 is used as a placeholder. The same logic applies to all subsequent rows. Note that using 0 as the padding value is not a problem even though 0 is also the code for the letter 'a': as we'll see, the padded positions are never actually seen by the RNN.
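
By default pad_sequence() pads with 0, which happens to also be the code for 'a' in our encoding. If you prefer a dedicated padding index, pad_sequence() accepts a padding_value argument and nn.Embedding accepts a padding_idx. Here is a small sketch of that alternative (not used in this post; PAD_IDX would be data.n_letters, i.e. 57, in our setup):

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([20, 18, 19]), torch.tensor([3, 0])]

PAD_IDX = 57  # one index past the last letter code
padded = pad_sequence(seqs, padding_value=PAD_IDX)

# The embedding table gets one extra row; the padding row stays at zero
# and receives no gradient updates.
embed = torch.nn.Embedding(PAD_IDX + 1, 8, padding_idx=PAD_IDX)
vectors = embed(padded)  # shape (max_len, batch, 8)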

After the padding, line 11 gets the length of each name in the sorted batch, and lines 12-14 retrieve the labels and the textual representations of the inputs in the order of the sorted batch (so they are in the same order as the padded sequences).
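
To check the shapes coming out of the DataLoader, we can inspect one batch (again assuming the helper methods are implemented; the values below are the ones the example batch above would give):

seqs, lengths, labels, labels_text, data_text = next(iter(data_loader))
print(seqs.shape)   # (max_len, batch_size), e.g. torch.Size([11, 8])
print(lengths)      # e.g. tensor([11, 10,  9,  7,  6,  6,  6,  5])
print(labels.shape) # torch.Size([8]), one category index per name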

Constructing the RNN model #

Here we define the RNN model, composed of 4 steps:

  1. Map the padded, integer-encoded characters to embedding vectors
  2. Unpad the sequences and feed them to a Gated Recurrent Unit (GRU)
  3. Get the GRU output and feed it to a linear layer
  4. Apply the softmax function to interpret the output as probabilities
class GRU(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRU, self).__init__()

        self.char_embed = nn.Embedding(input_size, hidden_size, sparse=False)
        torch.nn.init.xavier_normal_(self.char_embed.weight.data)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.hidden2country = nn.Linear(hidden_size, output_size)
        torch.nn.init.xavier_normal_(self.hidden2country.weight.data)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor, x_lengths) -> torch.Tensor:
        x = self.char_embed(x)
        x = pack_padded_sequence(x, x_lengths)
        gru_out, _ = self.gru(x)
        output, _ = pad_packed_sequence(gru_out)

        idx = (
            (torch.LongTensor(x_lengths) - 1)
            .view(-1, 1)
            .expand(len(x_lengths), output.size(2))
        )

        idx = idx.unsqueeze(0)
        last_output = F.selu(output.gather(0, idx).squeeze(0))
        country = self.hidden2country(last_output)

        return self.softmax(country)

n_hidden = 128  # hidden size; 128 as in the PyTorch tutorial
rnn = GRU(data.n_letters, n_hidden, data.n_categories)

Lines 5-10 define the three layers (embedding, GRU and linear) and their initialization. The forward() function takes the padded, integer-encoded sequences and their lengths as input. Line 13 maps each integer-encoded character to an embedding vector; at this step the padding value (0) is also mapped to an embedding vector and is indistinguishable from the letter 'a', which is encoded as 0 as well. This is not an issue because, on line 14, the sequences are packed according to their original lengths so that the GRU never sees the padded positions. On line 16, we call pad_packed_sequence() on the GRU output to reverse the previous packing. Lines 18-25 retrieve the last output of each sequence (the output at position length - 1) and apply a SELU activation to it. Line 26 applies the linear layer and line 28 the log-softmax.
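
The last-output selection on lines 18-25 can be hard to visualize, so here is a small standalone sketch with toy numbers (unrelated to the model) showing that the gather picks output[lengths[b] - 1, b, :] for every sequence b in the batch:

import torch

seq_len, batch, hidden = 4, 3, 2
output = torch.arange(seq_len * batch * hidden, dtype=torch.float).view(seq_len, batch, hidden)
lengths = [4, 2, 3]  # toy lengths

idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(batch, hidden).unsqueeze(0)
last = output.gather(0, idx).squeeze(0)  # shape (batch, hidden)

# last[b] is the output at the last real (non-padded) time step of sequence b
assert torch.equal(last[0], output[3, 0])
assert torch.equal(last[1], output[1, 1])
assert torch.equal(last[2], output[2, 2])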

The training loop #

The training loop simply gets batches of training examples from the DataLoader, feeds them to the model, computes the loss and updates the model's weights. Please refer to the PyTorch tutorial for a more complete training loop (with loss tracking and evaluation).

# criterion and optimizer are not defined in the original snippet:
# NLLLoss pairs with the model's LogSoftmax output, and plain SGD with the
# tutorial's learning rate is one reasonable choice.
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(rnn.parameters(), lr=0.005)

def train(seqs, lengths, labels):
    optimizer.zero_grad()

    output = rnn(seqs, lengths)

    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()

    return output, loss.item()

for seqs, lengths, labels, category, line in data_loader:
    output, loss = train(seqs, lengths, labels)
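
Once the model is trained, we can use it for prediction in the spirit of the PyTorch tutorial. The small helper below is hypothetical (it is not part of the original code) and assumes prepare_sequence() is implemented:

def predict(name):
    """Return the most likely origin for a single name."""
    with torch.no_grad():
        name_tensor = data.prepare_sequence(name)
        lengths = torch.LongTensor([name_tensor.size(0)])
        seqs = name_tensor.unsqueeze(1)  # shape (seq_len, batch=1)
        output = rnn(seqs, lengths)      # log-probabilities, shape (1, n_categories)
        _, top_index = output.topk(1)
        return data.all_categories[top_index.item()]

print(predict("Satoshi"))  # e.g. "Japanese"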