configuration_xlm_roberta.py
import torch
from transformers import PretrainedConfig


class XLMRobertaFlashConfig(PretrainedConfig):
    """Configuration for an XLM-RoBERTa model with optional FlashAttention,
    LoRA adapters, and Matryoshka embedding truncation."""
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        lora_adaptations=None,
        lora_rank=4,
        lora_dropout_p=0.0,
        lora_alpha=1,
        lora_main_params_trainable=False,
        load_trained_adapters=False,
        use_flash_attn=True,
        torch_dtype=None,
        emb_pooler=None,
        matryoshka_dimensions=None,
        truncate_dim=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Standard XLM-RoBERTa architecture hyperparameters
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout

        # LoRA adapter settings
        self.load_trained_adapters = load_trained_adapters
        self.lora_adaptations = lora_adaptations
        self.lora_rank = lora_rank
        self.lora_dropout_p = lora_dropout_p
        self.lora_alpha = lora_alpha
        self.lora_main_params_trainable = lora_main_params_trainable

        # Attention implementation and embedding options
        self.use_flash_attn = use_flash_attn
        self.emb_pooler = emb_pooler
        self.matryoshka_dimensions = matryoshka_dimensions
        self.truncate_dim = truncate_dim
        # Coerce a dtype name such as "float16" into the corresponding
        # torch.dtype. Guarding on isinstance(torch_dtype, str) avoids a
        # TypeError from hasattr() when an actual torch.dtype is passed in;
        # such values (and None) are stored unchanged.
        if (
            isinstance(torch_dtype, str)
            and hasattr(torch, torch_dtype)
            and isinstance(getattr(torch, torch_dtype), torch.dtype)
        ):
            self.torch_dtype = getattr(torch, torch_dtype)
        else:
            self.torch_dtype = torch_dtype
| 70 | |