from transformers import PretrainedConfig
from typing import List, Any, Optional, Union


class ScgptConfig(PretrainedConfig):
    """Hugging Face configuration for the scGPT model (``model_type = "scgpt"``).

    Every constructor argument is stored unchanged as an attribute of the same
    name; this class performs no validation or derivation. Remaining keyword
    arguments are forwarded to :class:`transformers.PretrainedConfig`.

    The argument descriptions below are inferred from parameter names and the
    upstream scGPT model; confirm them against the model that consumes this
    config.

    Args:
        ntoken: Size of the token vocabulary.
        d_model: Embedding / hidden dimension of the transformer.
        nhead: Number of attention heads.
        d_hid: Hidden dimension of the feed-forward sublayers.
        nlayers: Number of transformer layers.
        nlayers_cls: Number of layers in the classification head.
        n_cls: Number of output classes of the classification head.
        vocab: Vocabulary object; stored as-is.
        dropout: Dropout probability.
        pad_token: Padding token string.
        pad_value: Value used to pad expression inputs.
        pert_pad_id: Padding id for perturbation inputs.
        do_mvc: Enable the MVC objective / decoder.
        do_dab: Enable the domain-adversarial batch head.
        use_batch_labels: Condition the model on batch labels.
        num_batch_labels: Number of distinct batch labels, if used.
        domain_spec_batchnorm: Domain-specific batch-norm setting (bool or
            a mode string).
        input_emb_style: How input values are embedded (default
            ``"continuous"``).
        n_input_bins: Number of value bins, if a binned input style is used.
        cell_emb_style: How the cell embedding is produced (default ``"cls"``).
        mvc_decoder_style: Style of the MVC decoder (default
            ``"inner product"``).
        ecs_threshold: Threshold for the ECS objective.
        explicit_zero_prob: Model the probability of zero values explicitly.
        use_fast_transformer: Use a fast attention implementation.
        fast_transformer_backend: Backend name for the fast transformer
            (default ``"flash"``).
        pre_norm: Apply layer normalization before sublayers.
        use_mod: Enable modality tokens.
        ntokens_mod: Size of the modality vocabulary, if used.
        vocab_mod: Modality vocabulary object, if used; stored as-is.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "scgpt"

    def __init__(
        self,
        ntoken: int = 60697,
        d_model: int = 512,
        nhead: int = 8,
        d_hid: int = 512,
        nlayers: int = 12,
        nlayers_cls: int = 3,
        n_cls: int = 1,
        vocab: Any = None,
        dropout: float = 0.5,
        # NOTE(review): an empty pad token looks suspicious; upstream scGPT
        # uses "<pad>", which may have been stripped from this file. Confirm
        # before relying on the default.
        pad_token: str = "",
        pad_value: int = 0,
        pert_pad_id: int = 2,
        do_mvc: bool = False,
        do_dab: bool = False,
        use_batch_labels: bool = False,
        num_batch_labels: Optional[int] = None,
        domain_spec_batchnorm: Union[bool, str] = False,
        input_emb_style: str = "continuous",
        n_input_bins: Optional[int] = None,
        cell_emb_style: str = "cls",
        mvc_decoder_style: str = "inner product",
        ecs_threshold: float = 0.3,
        explicit_zero_prob: bool = False,
        use_fast_transformer: bool = False,
        fast_transformer_backend: str = "flash",
        pre_norm: bool = False,
        use_mod: bool = False,
        ntokens_mod: Optional[int] = None,
        vocab_mod: Optional[Any] = None,
        **kwargs,
    ):
        # Model architecture.
        self.ntoken = ntoken
        self.d_model = d_model
        self.nhead = nhead
        self.d_hid = d_hid
        self.nlayers = nlayers
        self.nlayers_cls = nlayers_cls
        self.n_cls = n_cls
        # NOTE(review): PretrainedConfig serializes attributes with json, so
        # `vocab` (and `vocab_mod` below) must be JSON-serializable for
        # save_pretrained / to_json_string to succeed. Confirm what callers pass.
        self.vocab = vocab
        self.dropout = dropout

        # Padding.
        self.pad_token = pad_token
        self.pad_value = pad_value
        self.pert_pad_id = pert_pad_id

        # Auxiliary objectives and batch handling.
        self.do_mvc = do_mvc
        self.do_dab = do_dab
        self.use_batch_labels = use_batch_labels
        self.num_batch_labels = num_batch_labels
        self.domain_spec_batchnorm = domain_spec_batchnorm

        # Input / output embedding styles.
        self.input_emb_style = input_emb_style
        self.n_input_bins = n_input_bins
        self.cell_emb_style = cell_emb_style
        self.mvc_decoder_style = mvc_decoder_style
        self.ecs_threshold = ecs_threshold
        self.explicit_zero_prob = explicit_zero_prob

        # Attention implementation and normalization placement.
        self.use_fast_transformer = use_fast_transformer
        self.fast_transformer_backend = fast_transformer_backend
        self.pre_norm = pre_norm

        # Modality tokens.
        self.use_mod = use_mod
        self.ntokens_mod = ntokens_mod
        self.vocab_mod = vocab_mod

        # Called last, as in the original, so the base class handles the
        # remaining generic options (e.g. pad_token_id) from kwargs.
        super().__init__(**kwargs)