perin / model /module /encoder.py
Larisa Kolesnichenko
Add the original perin code
1d5604f
raw
history blame
No virus
4.03 kB
#!/usr/bin/env python3
# coding=utf-8
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from model.module.char_embedding import CharEmbedding
class WordDropout(nn.Dropout):
    """Dropout that removes whole vectors along the last dimension.

    One keep/drop decision is sampled per position over all leading
    dimensions and broadcast across the last dimension, so every feature of a
    position is either zeroed together or kept (and rescaled by
    ``1 / (1 - p)``, as done by ``torch.nn.functional.dropout``).
    """

    def forward(self, input_tensor):
        """Apply position-level dropout to ``input_tensor``.

        Args:
            input_tensor: tensor whose last dimension is the feature axis.

        Returns:
            ``input_tensor`` itself when dropout is inactive (``p == 0`` or
            the module is in eval mode); otherwise a new tensor of the same
            shape with whole feature vectors zeroed or rescaled.
        """
        # Outside training the mask would be all ones, so skip allocating it
        # and multiplying by it; the values are identical either way.
        if self.p == 0 or not self.training:
            return input_tensor

        # One mask entry per position: shape is the input shape minus the
        # trailing feature dimension.
        ones = input_tensor.new_ones(input_tensor.shape[:-1])
        dropout_mask = F.dropout(ones, self.p, self.training, inplace=False)
        return dropout_mask.unsqueeze(-1) * input_tensor
class Encoder(nn.Module):
    """Word-level encoder on top of a pretrained transformer.

    The forward pass mixes the transformer's hidden layers with learned
    softmax weights, pools subword vectors into word vectors with a learned
    attention restricted by ``to_scatter``, optionally adds a character-level
    embedding, and finally expands every word vector into
    ``args.query_length`` decoder queries.
    """

    # Probability with which each transformer layer is masked out of the
    # layer mixture (independently per batch element) during training.
    LAYER_DROPOUT = 0.1

    def __init__(self, args, dataset):
        """Build the encoder.

        Args:
            args: configuration object; the attributes read here are
                ``hidden_size``, ``n_encoder_layers``, ``query_length``,
                ``encoder``, ``encoder_freeze_embedding``, ``freeze_bert``,
                ``char_embedding``, ``char_embedding_size`` and
                ``dropout_word``.
            dataset: only ``dataset.char_form_vocab_size`` is read, and only
                when character embeddings are enabled.
        """
        super().__init__()

        self.dim = args.hidden_size
        self.n_layers = args.n_encoder_layers
        self.width_factor = args.query_length

        self.bert = AutoModel.from_pretrained(args.encoder, add_pooling_layer=False)
        # self.bert._set_gradient_checkpointing(self.bert.encoder, value=True)
        if args.encoder_freeze_embedding:
            # Freeze the embedding tables but keep their LayerNorm trainable.
            self.bert.embeddings.requires_grad_(False)
            self.bert.embeddings.LayerNorm.requires_grad_(True)
        if args.freeze_bert:
            self.bert.requires_grad_(False)

        self.use_char_embedding = args.char_embedding
        if self.use_char_embedding:
            self.form_char_embedding = CharEmbedding(dataset.char_form_vocab_size, args.char_embedding_size, self.dim)
            self.word_dropout = WordDropout(args.dropout_word)

        self.post_layer_norm = nn.LayerNorm(self.dim)
        self.subword_attention = nn.Linear(self.dim, 1)

        if self.width_factor > 1:
            self.query_generator = nn.Linear(self.dim, self.dim * self.width_factor)
        else:
            self.query_generator = nn.Identity()

        self.encoded_layer_norm = nn.LayerNorm(self.dim)
        # One learnable mixing logit per transformer layer.
        self.scores = nn.Parameter(torch.zeros(self.n_layers, 1, 1, 1), requires_grad=True)

    def forward(self, bert_input, form_chars, to_scatter, n_words):
        """Encode a batch of subword sequences into word vectors and queries.

        Args:
            bert_input: pair ``(tokens, mask)`` passed to the transformer as
                input ids and attention mask.
            form_chars: indexable with three items forwarded to the character
                embedding; only used when character embeddings are enabled.
            to_scatter: subword-to-word alignment, shape
                ``(B, T_subword, T_word)``; zero entries mark "this subword
                does not belong to this word".
            n_words: unused here; kept for interface compatibility.

        Returns:
            Tuple ``(encoder_output, decoder_input)`` where ``encoder_output``
            has shape ``(B, T_word, D)`` and ``decoder_input`` has shape
            ``(B, T_word * query_length, D)``.
        """
        tokens, mask = bert_input
        batch_size = tokens.size(0)

        # Drop the first hidden state and keep one tensor per remaining layer.
        hidden_states = self.bert(tokens, attention_mask=mask, output_hidden_states=True).hidden_states[1:]
        encoded = torch.stack(hidden_states, dim=0)  # shape: (L, B, T, H)
        encoded = self.encoded_layer_norm(encoded)

        if self.training:
            time_len = encoded.size(2)
            scores = self.scores.expand(-1, batch_size, time_len, -1)
            dropped = torch.empty(self.n_layers, batch_size, 1, 1, dtype=torch.bool, device=self.scores.device)
            dropped.bernoulli_(self.LAYER_DROPOUT)
            # If every layer of a sample were dropped, the softmax below would
            # see only -inf and return NaN for that sample; in that case keep
            # all of its layers instead.
            dropped = dropped & ~dropped.all(dim=0, keepdim=True)
            scores = scores.masked_fill(dropped, float("-inf"))
        else:
            scores = self.scores

        scores = F.softmax(scores, dim=0)
        encoded = (scores * encoded).sum(0)  # shape: (B, T, H)
        encoded = encoded.masked_fill(mask.unsqueeze(-1) == 0, 0.0)  # shape: (B, T, H)

        subword_attention = self.subword_attention(encoded) / math.sqrt(self.dim)  # shape: (B, T, 1)
        subword_attention = subword_attention.expand_as(to_scatter)  # shape: (B, T_subword, T_word)
        subword_attention = subword_attention.masked_fill(to_scatter == 0, float("-inf"))  # shape: (B, T_subword, T_word)
        subword_attention = torch.softmax(subword_attention, dim=1)  # shape: (B, T_subword, T_word)
        # Words with no aligned subword have an all -inf column (NaN after the
        # softmax); overwrite those columns with zeros.
        subword_attention = subword_attention.masked_fill(to_scatter.sum(1, keepdim=True) == 0, value=0.0)  # shape: (B, T_subword, T_word)

        # Attention-weighted sum of subword vectors for every word.
        encoder_output = torch.einsum("bsd,bsw->bwd", encoded, subword_attention)
        encoder_output = self.post_layer_norm(encoder_output)

        if self.use_char_embedding:
            form_char_embedding = self.form_char_embedding(form_chars[0], form_chars[1], form_chars[2])
            encoder_output = self.word_dropout(encoder_output) + form_char_embedding

        decoder_input = self.query_generator(encoder_output)
        decoder_input = decoder_input.view(batch_size, -1, self.width_factor, self.dim).flatten(1, 2)  # shape: (B, T*Q, D)

        return encoder_output, decoder_input