configuration_MERT.py
5.2 KB · 142 lines · python Raw
1 """
2 MERT model configuration
3 """
4
5 import functools
6 import operator
7
8 # from ...configuration_utils import PretrainedConfig
9 # from ...utils import logging
10 from transformers.configuration_utils import PretrainedConfig
11 from transformers.utils import logging
12
# Module-level logger obtained through transformers' logging utilities
# (``transformers.utils.logging``, imported above).
logger = logging.get_logger(__name__)
14
# TODO: use an archive map like this one when uploading to the Hugging Face Hub
#       (the entry below is still the HuBERT template and needs MERT checkpoints).
# HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
#     "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json",
#     # See all Hubert models at https://huggingface.co/models?filter=hubert
# }
20
21
class MERTConfig(PretrainedConfig):
    r"""
    Configuration class for the MERT model.

    Holds the hyper-parameters of the convolutional feature encoder, the
    Transformer encoder, SpecAugment-style masking, the CTC loss and a few
    MERT-specific switches (CQT feature extractor, DeepNorm, attention
    relaxation). Parameter names appear to mirror the HuBERT / wav2vec 2.0
    configurations in ``transformers``. The short descriptions below follow
    that convention; how each value is actually consumed is decided by the
    MERT modeling code, which is not part of this module.

    Args:
        vocab_size: Vocabulary size.
        hidden_size: Hidden dimension of the encoder.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Number of attention heads per layer.
        intermediate_size: Width of the feed-forward sub-layer.
        hidden_act: Name of the activation used in the encoder.
        hidden_dropout: Dropout probability for hidden states.
        activation_dropout: Dropout probability after the feed-forward activation.
        attention_dropout: Dropout probability on attention weights.
        feat_proj_layer_norm: Whether the feature projection applies layer norm.
        feat_proj_dropout: Dropout probability in the feature projection.
        final_dropout: Dropout probability before the output head.
        layerdrop: LayerDrop probability.
        initializer_range: Standard deviation used for weight initialisation.
        layer_norm_eps: Epsilon of the layer-norm layers.
        feat_extract_norm: Normalisation mode of the convolutional feature
            encoder (default ``"group"``).
        feat_extract_activation: Activation of the convolutional feature encoder.
        conv_dim: Output channels of each convolutional layer (any sequence;
            stored as a list).
        conv_stride: Stride of each convolutional layer (stored as a list).
        conv_kernel: Kernel size of each convolutional layer (stored as a list).
        conv_bias: Whether the convolutional layers use a bias.
        num_conv_pos_embeddings: Size of the convolutional positional embedding.
        num_conv_pos_embedding_groups: Groups of the convolutional positional
            embedding.
        do_stable_layer_norm: Selects the "stable layer norm" encoder variant.
        apply_spec_augment: Whether SpecAugment-style masking is applied.
        mask_time_prob, mask_time_length, mask_time_min_masks: Time-axis
            masking settings.
        mask_feature_prob, mask_feature_length, mask_feature_min_masks:
            Feature-axis masking settings.
        ctc_loss_reduction: Reduction applied to the CTC loss (default ``"sum"``).
        ctc_zero_infinity: Whether infinite CTC losses are zeroed.
        use_weighted_layer_sum: Whether a weighted sum of layer outputs is used.
        classifier_proj_size: Projection size before a classification head.
        pad_token_id, bos_token_id, eos_token_id: Special token ids, forwarded
            to ``PretrainedConfig``.
        feature_extractor_cqt: Enables the CQT feature extractor.
        feature_extractor_cqt_bins: Number of CQT bins.
        deepnorm: Enables DeepNorm (up-scaled weighted residual connections
            and down-scaled initial values in the Transformer encoder).
        attention_relax: Attention-relaxation coefficient, stored as-is
            (the default ``-1.0`` presumably means "disabled" -- TODO confirm
            in the modeling code).
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.

    Raises:
        ValueError: If ``conv_dim``, ``conv_stride`` and ``conv_kernel`` do
            not all have the same length.
    """

    model_type = "mert_model"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_layer_norm=True,
        feat_proj_dropout=0.0,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        apply_spec_augment=True,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_time_min_masks=2,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        mask_feature_min_masks=0,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        feature_extractor_cqt=False,
        feature_extractor_cqt_bins=336,
        deepnorm=False,
        attention_relax=-1.0,
        **kwargs
    ):
        # The base class owns the special-token ids and any extra keyword
        # arguments; the attributes below are set after it has run.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # --- encoder width and convolutional feature encoder ---------------
        # (assignment order is kept stable so the attribute dict keeps the
        # same key order)
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_activation = feat_extract_activation
        # Copy into lists so callers may pass tuples or any other sequence.
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        # The number of conv layers is derived from conv_dim.
        self.num_feat_extract_layers = len(self.conv_dim)

        # --- Transformer encoder ------------------------------------------
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_layer_norm = feat_proj_layer_norm
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.use_weighted_layer_sum = use_weighted_layer_sum
        self.classifier_proj_size = classifier_proj_size

        # All per-layer conv settings must describe the same number of layers.
        distinct_layer_counts = {
            len(self.conv_dim),
            len(self.conv_stride),
            len(self.conv_kernel),
        }
        if len(distinct_layer_counts) != 1:
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )

        # --- SpecAugment masking (https://arxiv.org/abs/1904.08779) --------
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # --- CTC loss -----------------------------------------------------
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # --- MERT-specific switches ---------------------------------------
        # CQT feature extractor.
        self.feature_extractor_cqt = feature_extractor_cqt
        self.feature_extractor_cqt_bins = feature_extractor_cqt_bins
        # DeepNorm: up-scaled weighted residual connections plus down-scaled
        # initial values in the Transformer encoder.
        self.deepnorm = deepnorm
        self.attention_relax = attention_relax

        # Always pinned to False, overriding any value that arrived via
        # **kwargs. The original note calls this a fix for transformers > 4.42
        # (presumably newer positional-conv code reads this attribute -- TODO
        # confirm).
        self.conv_pos_batch_norm = False

    @property
    def inputs_to_logits_ratio(self):
        """Return the product of ``conv_stride``.

        This is the overall stride (downsampling factor) of the convolutional
        feature encoder; it is 1 when ``conv_stride`` is empty.
        """
        return functools.reduce(operator.mul, self.conv_stride, 1)
142