#!/usr/bin/env python3
# coding=utf-8

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence


class CharEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int, output_size: int):
        super(CharEmbedding, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False)
        self.layer_norm = nn.LayerNorm(embedding_size)
        self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True)
        self.out_linear = nn.Linear(2*embedding_size, output_size)
        self.layer_norm_2 = nn.LayerNorm(output_size)

    def forward(self, words, sentence_lens, word_lens):
        # words: padded char ids, shape (B, W, C); sentence_lens: words per sentence, shape (B,);
        # word_lens: chars per word, shape (B, W). In the shape comments below, B*W stands for
        # the number of non-padding words that survive packing (and B*W*C for their characters).
        # The two pack_padded_sequence calls below keep the default enforce_sorted=True,
        # so sentences are expected to be sorted by decreasing length.
        n_words = words.size(1)
        sentence_lens = sentence_lens.cpu()
        sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True)  # shape: (B*W, C)
        lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True)  # shape: (B*W)
        word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False)  # shape: (B*W*C)

        embedded = self.embedding(word_packed.data)  # shape: (B*W*C, D)
        embedded = self.layer_norm(embedded)  # shape: (B*W*C, D)

        embedded_packed = PackedSequence(embedded, word_packed.batch_sizes, word_packed.sorted_indices, word_packed.unsorted_indices)
        _, embedded = self.gru(embedded_packed)  # shape: (layers * 2, B*W, D)

        embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2)  # shape: (B*W, 2*D)
        embedded = F.relu(embedded)
        embedded = self.out_linear(embedded)  # shape: (B*W, output_size)
        embedded = self.layer_norm_2(embedded)

        embedded, _ = pad_packed_sequence(
            PackedSequence(embedded, sentence_packed.batch_sizes, sentence_packed.sorted_indices, sentence_packed.unsorted_indices),
            batch_first=True, total_length=n_words,
        )  # shape: (B, W, output_size)

        return embedded  # shape: (B, W, output_size)
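

if __name__ == "__main__":
    # Illustrative smoke test; a minimal sketch, not part of the original module.
    # The dummy shapes, vocabulary size and length values below are assumptions
    # chosen only to exercise forward(); adapt them to the real data pipeline.
    import torch

    model = CharEmbedding(vocab_size=30, embedding_size=8, output_size=16)

    B, W, C = 2, 3, 4  # 2 sentences, up to 3 words each, up to 4 chars per word
    words = torch.randint(1, 30, (B, W, C))           # padded char ids, shape (B, W, C)
    sentence_lens = torch.tensor([3, 2])              # sorted in decreasing order (enforce_sorted=True)
    word_lens = torch.tensor([[4, 3, 2], [4, 1, 1]])  # chars per word; padding entries are dropped by packing

    out = model(words, sentence_lens, word_lens)
    print(out.shape)  # torch.Size([2, 3, 16]) -> (B, W, output_size)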