modeling_vit_classifier.py
| 1 | from transformers import PreTrainedModel |
| 2 | import torch |
| 3 | import torch.nn as nn |
| 4 | import torchvision.transforms as transforms |
| 5 | import timm |
| 6 | import PIL.Image as Image |
| 7 | |
class ViTClassifier(nn.Module):
    """Vision Transformer classifier configured from a nested config mapping.

    The backbone comes from ``timm.create_model`` (no pretrained weights).
    Its ``head`` attribute is replaced by a configurable ``nn.Linear``.
    Optionally the backbone is frozen so that only the head is trained.

    Config keys read:
        model.variant, model.num_classes, model.hidden_dropout_prob,
        model.attention_probs_dropout_prob,
        model.head.in_features, model.head.out_features, model.head.bias,
        model.freeze_backbone,
        preprocessing.norm_mean, preprocessing.norm_std,
        preprocessing.resize_size, preprocessing.crop_size
    """

    def __init__(self, config, device='cuda', dtype=torch.float32):
        """Build the backbone and the custom head on ``device`` in ``dtype``.

        Args:
            config: Nested mapping with ``'model'`` and ``'preprocessing'``
                sections (see class docstring for the keys used).
            device: Device the model parameters and inputs are moved to.
            dtype: Floating-point dtype for the model parameters and the
                preprocessed input.

        Raises:
            KeyError: If a required config key is missing.
        """
        super().__init__()
        self.config = config
        self.device = device
        self.dtype = dtype

        model_cfg = config['model']

        # Create the ViT model without unsupported arguments. Cast the whole
        # backbone to ``dtype`` (not only the head), because the input is
        # converted to ``self.dtype`` in preprocess_input. With a backbone
        # left in float32, any other dtype fails with a dtype mismatch.
        self.vit = timm.create_model(
            model_cfg['variant'],
            pretrained=False,
            num_classes=model_cfg['num_classes'],
            drop_rate=model_cfg['hidden_dropout_prob'],
            attn_drop_rate=model_cfg['attention_probs_dropout_prob'],
        ).to(device=device, dtype=dtype)

        # Replace the head with a custom head.
        # NOTE(review): head.in_features is assumed to match the backbone's
        # feature width for this variant. Nothing here checks it.
        head_cfg = model_cfg['head']
        self.vit.head = nn.Linear(
            in_features=head_cfg['in_features'],
            out_features=head_cfg['out_features'],
            bias=head_cfg['bias'],
            device=device,
            dtype=dtype,
        )

        if model_cfg['freeze_backbone']:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Disable gradients for all backbone parameters but keep the head trainable."""
        # self.vit.parameters() includes the head (it is an attribute of
        # self.vit), so freeze everything first and then re-enable the head.
        # Without the second call, nothing would be trainable.
        self.vit.requires_grad_(False)
        self.vit.head.requires_grad_(True)

    def preprocess_input(self, x):
        """Turn one image into a normalized batch of size 1.

        The image is resized, center-cropped, converted to a tensor,
        normalized, and cast to ``self.dtype``, using the values in
        ``config['preprocessing']``.

        Args:
            x: A PIL image. ``ToTensor`` runs after ``Resize``, so tensor
                input is not accepted.

        Returns:
            The preprocessed image tensor with a leading batch dimension of
            size 1. It stays on the CPU; ``forward`` moves it to
            ``self.device``.
        """
        prep_cfg = self.config['preprocessing']
        preprocess = transforms.Compose([
            transforms.Resize(prep_cfg['resize_size']),
            transforms.CenterCrop(prep_cfg['crop_size']),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=prep_cfg['norm_mean'], std=prep_cfg['norm_std']
            ),
            transforms.ConvertImageDtype(self.dtype),
        ])
        return preprocess(x).unsqueeze(0)

    def forward(self, x):
        """Classify a single image and return sigmoid-activated scores.

        Args:
            x: A PIL image. It is preprocessed by ``preprocess_input``.

        Returns:
            Element-wise sigmoid of the ViT output, with values in [0, 1].
            This is presumably of shape (1, head out_features); confirm
            against the timm variant in use.
        """
        batch = self.preprocess_input(x).to(self.device)
        logits = self.vit(batch)
        # Sigmoid (not softmax): each output is squashed independently.
        return torch.sigmoid(logits)