from transformers import PretrainedConfig


class GEBConfig(PretrainedConfig):
    """Hugging Face configuration class for GEB models (``model_type = "geblm"``).

    Every named constructor argument is stored verbatim as an attribute of the
    same name, for consumption by the GEB modeling code. Any additional
    keyword arguments are forwarded to :class:`transformers.PretrainedConfig`.

    ``torch_dtype`` and ``tie_word_embeddings`` are handled by
    ``PretrainedConfig.__init__`` itself, so they are passed through to the
    base class rather than assigned here. This ensures the values given to
    this constructor are the ones that end up on the instance.
    """

    model_type = "geblm"

    def __init__(
        self,
        num_layers=24,
        padded_vocab_size=64896,
        hidden_size=2048,
        ffn_hidden_size=5632,
        kv_channels=128,
        num_attention_heads=16,
        torch_dtype='bfloat16',
        seq_length=4096,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        max_position_embeddings=4096,
        bias_dropout_fusion=True,
        use_cache=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        use_flash_attn=True,
        num_key_value_heads=4,
        apply_query_key_layer_scaling=False,
        attention_softmax_in_fp32=False,
        fp32_residual_connection=False,
        pre_seq_len=None,
        prefix_projection=False,
        tie_word_embeddings=False,
        **kwargs
    ):
        self.num_layers = num_layers
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # Plain scalar: a stray trailing comma here previously stored the
        # tuple ``(hidden_dropout,)`` instead of the value itself.
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.max_position_embeddings = max_position_embeddings
        self.bias_dropout_fusion = bias_dropout_fusion
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.use_flash_attn = use_flash_attn
        self.num_key_value_heads = num_key_value_heads
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        # PretrainedConfig.__init__ (re)assigns ``torch_dtype`` and
        # ``tie_word_embeddings`` from its kwargs, defaulting
        # tie_word_embeddings to True. Assigning them on ``self`` before this
        # call would be silently overwritten, so forward them instead.
        super().__init__(
            torch_dtype=torch_dtype,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )