Commit 2ef2a13 (1 parent: 29affe2) by jdrechsel

Create README.md

Files changed (1): README.md (+79 −0)
Pretrained model based on [microsoft/deberta-v3-base](https://huggingface.co/microsoft/deberta-v3-base) with further mathematical pre-training.

Compared to deberta-v3-base, 300 additional mathematical LaTeX tokens were added to the vocabulary before the mathematical pre-training.
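You can check the extended vocabulary by comparing tokenizer sizes (a minimal sketch; the repository id below is a placeholder for this model's actual id on the Hub):

```python
from transformers import AutoTokenizer

# Placeholder id - substitute this model's actual repository id
math_tokenizer = AutoTokenizer.from_pretrained("path/to/this-model")
base_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

# The difference should reflect the 300 added LaTeX tokens
print(len(math_tokenizer) - len(base_tokenizer))
```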
Because this additional pre-training used NSP-like tasks, a pooling layer (`weight` and `bias`) has been added to the model. If you don't need this pooling layer, just use the standard transformers DeBERTa model. If you want to use the additional pooling layer the way BERT's is used, a wrapper class like the following can be used:
```python
from typing import Mapping, Any

import torch
from torch import nn
from transformers import DebertaV2Model, DebertaV2Tokenizer, AutoConfig, AutoTokenizer


class DebertaV2ModelWithPoolingLayer(nn.Module):

    def __init__(self, pretrained_model_name):
        super().__init__()

        # Load the DeBERTa model and tokenizer
        self.deberta = DebertaV2Model.from_pretrained(pretrained_model_name)
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name)

        # Add a pooling layer (Linear + tanh activation) for the CLS token
        self.pooling_layer = nn.Sequential(
            nn.Linear(self.deberta.config.hidden_size, self.deberta.config.hidden_size),
            nn.Tanh(),
        )

        self.config = self.deberta.config
        self.embeddings = self.deberta.embeddings

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Forward pass through the DeBERTa model
        outputs = self.deberta(input_ids, attention_mask=attention_mask, **kwargs)

        # Extract the hidden states from the output
        hidden_states = outputs.last_hidden_state

        # Get the CLS token representation (first token)
        cls_token = hidden_states[:, 0, :]

        # Apply the pooling layer to the CLS token representation
        pooled_output = self.pooling_layer(cls_token)

        # Include the pooled output in the output as 'pooler_output'
        outputs["pooler_output"] = pooled_output

        return outputs

    def save_pretrained(self, path):
        # Save the model's state_dict, configuration, and tokenizer.
        # The pooler parameters are stored under the top-level keys
        # 'weight' and 'bias'.
        state_dict = self.deberta.state_dict()
        state_dict.update(self.pooling_layer[0].state_dict())

        torch.save(state_dict, f"{path}/pytorch_model.bin")
        self.deberta.config.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        # Split the pooler parameters (top-level 'weight' and 'bias')
        # from the base model parameters
        pooler_keys = ["bias", "weight"]
        deberta_state_dict = {k: v for k, v in state_dict.items() if k not in pooler_keys}
        pooler_state_dict = {k: v for k, v in state_dict.items() if k in pooler_keys}
        self.deberta.load_state_dict(deberta_state_dict, strict=strict)
        self.pooling_layer[0].load_state_dict(pooler_state_dict)

    @classmethod
    def from_pretrained(cls, name):
        # Initialize the instance
        instance = cls(name)

        try:
            # Load the saved state_dict, including the pooler weights
            instance.load_state_dict(torch.load(f"{name}/pytorch_model.bin"))
        except FileNotFoundError:
            print("Could not find saved pooling layer weights; initializing new values.")

        # Load the configuration and tokenizer
        instance.deberta.config = AutoConfig.from_pretrained(name)
        instance.tokenizer = AutoTokenizer.from_pretrained(name)

        return instance
```
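
A minimal usage sketch (assuming the model files, including the `pytorch_model.bin` with the extra pooler weights, have been downloaded to a local directory; `./math_deberta` below is a placeholder). Note that because `from_pretrained` and `save_pretrained` build file paths directly, the wrapper expects a local directory rather than a Hub id:

```python
# "./math_deberta" is a placeholder for a local directory containing
# this model's files (config, tokenizer, pytorch_model.bin)
model = DebertaV2ModelWithPoolingLayer.from_pretrained("./math_deberta")
model.eval()

inputs = model.tokenizer("The derivative of $x^2$ is $2x$.", return_tensors="pt")

with torch.no_grad():
    outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Pooled CLS representation produced by the additional pooling layer
print(outputs["pooler_output"].shape)  # (1, hidden_size)
```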