scripts/convert_to_pytorch.py
1 """Convert ViT and non-distilled DeiT checkpoints from the timm library."""
2
3 import argparse
4 from pathlib import Path
5
6 import requests
7 import timm
8 import torch
9 from PIL import Image
10 from timm.data import ImageNetInfo, infer_imagenet_subset
11
12 from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
13 from transformers.utils import logging
14
15
16 logging.set_verbosity_info()
17 logger = logging.get_logger(__name__)
18
19
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("cls_token", "vit.embeddings.cls_token"),
            ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )

    if base_model:
        # layernorm
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
            ]
        )

        # if just the base model, we should remove "vit" from all keys that start with "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys
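
# For illustration: with the default 12-layer config, the loop above starts with tuples such as
# ("blocks.0.norm1.weight", "vit.encoder.layer.0.layernorm_before.weight"); each pair is later
# applied to the state dict one key at a time via `rename_key` below.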


# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
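
# Note on the slicing above: timm stores query, key and value as one projection matrix of shape
# (3 * hidden_size, hidden_size). Rows [0, hidden_size) hold the query weights,
# [hidden_size, 2 * hidden_size) the key weights, and the last hidden_size rows the value weights;
# the bias vector follows the same layout.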


def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im
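
# (000000039769.jpg is the familiar COCO val2017 picture of two cats lying on a couch; any RGB
# image would do for the sanity check below, this one is simply the test image used across the
# transformers conversion scripts)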


@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ViT structure.
    """

    # define default ViT configuration
    config = ViTConfig()
    base_model = False

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # detect unsupported ViT models in transformers
    # fc_norm is present
    if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
        raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")

    # use of global average pooling in combination with (or without) a class token
    if getattr(timm_model, "global_pool", None) == "avg":
        raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")

    # CLIP style ViT with norm_pre layer present
    if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
        raise ValueError(
            f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
        )

    # SigLIP style ViT with attn_pool layer present
    if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
        raise ValueError(
            f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
        )

    # use of layer scale in ViT model blocks
    if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
        getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
    ):
        raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")

    # Hybrid ResNet-ViTs
    if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
        raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")

    # get patch size and image size from the patch embedding submodule
    config.patch_size = timm_model.patch_embed.patch_size[0]
    config.image_size = timm_model.patch_embed.img_size[0]

    # retrieve architecture-specific parameters from the timm model
    config.hidden_size = timm_model.embed_dim
    config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
    config.num_hidden_layers = len(timm_model.blocks)
    config.num_attention_heads = timm_model.blocks[0].attn.num_heads
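    # for example, the default "vit_base_patch16_224" checkpoint resolves to patch_size=16,
    # image_size=224, hidden_size=768, intermediate_size=3072, num_hidden_layers=12 and
    # num_attention_heads=12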

    # check whether the model has a classification head or not
    if timm_model.num_classes != 0:
        config.num_labels = timm_model.num_classes
        # infer ImageNet subset from timm model
        imagenet_subset = infer_imagenet_subset(timm_model)
        dataset_info = ImageNetInfo(imagenet_subset)
        config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
        config.label2id = {v: k for k, v in config.id2label.items()}
    else:
        print(f"{vit_name} is going to be converted as a feature extractor only.")
        base_model = True

    # load state_dict of original model
    state_dict = timm_model.state_dict()

    # remove and rename some keys in the state dict
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    if base_model:
        model = ViTModel(config, add_pooling_layer=False).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)
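    # strict loading (the default) doubles as a sanity check: any timm key that was not renamed or
    # split correctly above would surface here as a missing/unexpected key error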

    # Check outputs on an image, prepared by ViTImageProcessor/DeiTImageProcessor
    if "deit" in vit_name:
        image_processor = DeiTImageProcessor(size=config.image_size)
    else:
        image_processor = ViTImageProcessor(size=config.image_size)
    encoding = image_processor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    if base_model:
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.last_hidden_state.shape
        assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
    else:
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--vit_name",
        default="vit_base_patch16_224",
        type=str,
        help="Name of the ViT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
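

# Example invocation (a sketch: it assumes network access to download the timm checkpoint and the
# test image, and the output directory name is only illustrative):
#   python scripts/convert_to_pytorch.py \
#       --vit_name vit_base_patch16_224 \
#       --pytorch_dump_folder_path ./vit-base-patch16-224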