---
metrics:
- accuracy
pipeline_tag: image-classification
base_model: google/vit-base-patch16-384
model-index:
- name: AdamCodd/vit-base-nsfw-detector
  results:
  - task:
      type: image-classification
      name: Image Classification
    metrics:
    - type: accuracy
      value: 0.9654
      name: Accuracy
    - type: AUC
      value: 0.9948
    - type: loss
      value: 0.0937
      name: Loss
license: apache-2.0
tags:
- transformers.js
- transformers
- nlp
---

# vit-base-nsfw-detector

This model is a fine-tuned version of [vit-base-patch16-384](https://huggingface.co/google/vit-base-patch16-384) trained on approximately 25,000 images (drawings, photos, etc.).
It achieves the following results on the evaluation set:
- Loss: 0.0937
- Accuracy: 0.9654

**<u>New [07/30]</u>**: I created a new ViT model specifically to detect NSFW/SFW images for Stable Diffusion usage (read the disclaimer below for the reason): [**AdamCodd/vit-nsfw-stable-diffusion**](https://huggingface.co/AdamCodd/vit-nsfw-stable-diffusion).

**Disclaimer**: This model wasn't made with AI-generated images in mind! The dataset used here contains no generated images, and the model performs significantly worse on them; handling them properly requires another ViT model trained specifically on generated images. Here are the model's actual scores on AI-generated images, with changes relative to the evaluation-set results above, to give you an idea:
- Loss: 0.3682 (↑ 292.95%)
- Accuracy: 0.8600 (↓ 10.91%)
- F1: 0.8654
- AUC: 0.9376 (↓ 5.75%)
- Precision: 0.8350
- Recall: 0.8980
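
As a quick consistency check, the F1 score above is the harmonic mean of the listed precision and recall, and the ↑/↓ percentages appear to be relative changes against the evaluation-set scores. The sketch below recomputes both from the rounded values in this card; the helper names are illustrative only, and because the inputs are rounded, the last digit may differ slightly from the listed percentages.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


def relative_change_pct(baseline: float, value: float) -> float:
    """Percentage change of `value` relative to `baseline`."""
    return (value - baseline) / baseline * 100


# Scores on AI-generated images, as listed above.
precision, recall = 0.8350, 0.8980
print(f"F1: {f1_score(precision, recall):.4f}")

# Evaluation-set score vs. score on AI-generated images.
for name, base, generated in [("Loss", 0.0937, 0.3682),
                              ("Accuracy", 0.9654, 0.8600),
                              ("AUC", 0.9948, 0.9376)]:
    print(f"{name}: {relative_change_pct(base, generated):+.2f}%")
```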

## Model description

The base model is a Vision Transformer (ViT), a BERT-like transformer encoder pretrained in a supervised fashion on a large collection of images, namely ImageNet-21k, at a resolution of 224x224 pixels. It was then fine-tuned on ImageNet (also referred to as ILSVRC2012), a dataset of 1 million images and 1,000 classes, at a higher resolution of 384x384. This checkpoint fine-tunes it further to classify images as SFW or NSFW.
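
In the `patch16-384` variant, each 384x384 input is split into non-overlapping 16x16 patches, and a [CLS] token is prepended, as in the standard ViT design. The short sketch below works out the resulting sequence length; `vit_sequence_length` is a hypothetical helper for illustration, not part of any library.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens a ViT encoder sees: one per patch plus the [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


# Pretraining resolution (ImageNet-21k) vs. fine-tuning resolution (384x384).
print(vit_sequence_length(224, 16))  # 197 (14 * 14 patches + 1)
print(vit_sequence_length(384, 16))  # 577 (24 * 24 patches + 1)
```

Because self-attention cost grows roughly quadratically with sequence length, inference at 384x384 (577 tokens) is noticeably heavier than at 224x224 (197 tokens).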

## Intended uses & limitations

There are two classes: SFW and NSFW. The model has been trained to be restrictive and therefore classifies "sexy" images as NSFW: if an image shows cleavage or too much skin, it will be classified as NSFW. This is intended behavior.

Usage for a local image:
```python
from transformers import pipeline
from PIL import Image

img = Image.open("<path_to_image_file>")
predict = pipeline("image-classification", model="AdamCodd/vit-base-nsfw-detector")
# Returns a list of {"label": ..., "score": ...} dicts, highest score first
print(predict(img))
```
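
If you need a yes/no decision with a stricter or looser cut-off than simply taking the top label, a minimal sketch like the one below can be applied to the pipeline output. It assumes the labels are `"nsfw"` and `"sfw"` (check `model.config.id2label`); the `is_nsfw` helper and its 0.5 default threshold are illustrative choices, not values recommended by this card.

```python
from typing import Any, Dict, List


def is_nsfw(results: List[Dict[str, Any]], threshold: float = 0.5) -> bool:
    """Return True if the NSFW probability in a pipeline result meets `threshold`.

    Assumes `results` is the list returned by the image-classification
    pipeline and that the classes are labelled "nsfw" and "sfw"
    (an assumption: verify against model.config.id2label).
    """
    scores = {r["label"].lower(): r["score"] for r in results}
    if "nsfw" in scores:
        nsfw_score = scores["nsfw"]
    else:
        # With two softmax classes, P(nsfw) = 1 - P(sfw) (e.g. when top_k=1).
        nsfw_score = 1.0 - scores["sfw"]
    return nsfw_score >= threshold


# Illustrative result in the pipeline's output format (scores are made up).
example = [{"label": "sfw", "score": 0.93}, {"label": "nsfw", "score": 0.07}]
print(is_nsfw(example))                  # False
print(is_nsfw(example, threshold=0.05))  # True
```

With a real prediction, call `is_nsfw(predict(img))`; lowering `threshold` makes the filter stricter.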

Usage for a remote image:
```python
import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, AutoModelForImageClassification

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
inputs = processor(images=image, return_tensors="pt")

# Inference only: disable gradient tracking to save memory and time
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
# Predicted class: sfw
```
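
`argmax` gives only the top label. To see how confident the model is, the logits can be turned into probabilities with a softmax (in PyTorch, `logits.softmax(-1)`). The dependency-free sketch below shows the same computation on a plain list of logits; `logits_to_probs`, the example logits, and the label mapping are hypothetical, so read the real mapping from `model.config.id2label`.

```python
import math
from typing import Dict, List


def logits_to_probs(logits: List[float], id2label: Dict[int, str]) -> Dict[str, float]:
    """Numerically stable softmax over one image's logits, keyed by label name."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return {id2label[i]: e / total for i, e in enumerate(exps)}


# Hypothetical logits and label mapping, for illustration only.
probs = logits_to_probs([2.0, -1.0], {0: "sfw", 1: "nsfw"})
print(probs)
```

With the model above, this would be `logits_to_probs(logits[0].tolist(), model.config.id2label)`.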

Usage with Transformers.js (Vanilla JS):
```js
/* Instructions:
 * - Place this script in an HTML file using the <script type="module"> tag.
 * - Ensure the HTML file is served over a local or remote server (e.g., using Python's http.server, Node.js server, or similar).
 * - Replace 'https://example.com/path/to/image.jpg' in the classifyImage function call with the URL of the image you want to classify.
 *
 * Example of how to include this script in HTML:
 * <script type="module" src="path/to/this_script.js"></script>
 *
 * Serving the page over HTTP (rather than opening it via file://) lets the module script use imports and perform network requests.
 * The image host must still allow cross-origin requests (CORS) from your page's origin.
 */
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.1';

// Since we will download the model from the Hugging Face Hub, we can skip the local model check
env.allowLocalModels = false;

// Load the image classification model
const classifier = await pipeline('image-classification', 'AdamCodd/vit-base-nsfw-detector');

// Function to fetch and classify an image from a URL
async function classifyImage(url) {
  try {
    const response = await fetch(url);
    if (!response.ok) throw new Error('Failed to load image');

    const blob = await response.blob();
    const image = new Image();
    const imagePromise = new Promise((resolve, reject) => {
      image.onload = () => resolve(image);
      image.onerror = reject;
      image.src = URL.createObjectURL(blob);
    });

    const img = await imagePromise; // Ensure the image is loaded
    const classificationResults = await classifier([img.src]); // Classify the image
    URL.revokeObjectURL(img.src); // Release the object URL once it is no longer needed
    console.log('Predicted class: ', classificationResults[0].label);
  } catch (error) {
    console.error('Error classifying image:', error);
  }
}

// Example usage
classifyImage('https://example.com/path/to/image.jpg');
// Predicted class: sfw
```

The model has been trained on a variety of images (realistic, 3D, drawings), yet it is not perfect, and some images may be wrongly classified as NSFW when they are not. Additionally, note that using the quantized ONNX model within the Transformers.js pipeline will slightly reduce the model's accuracy.
You can find a toy implementation of this model with Transformers.js [here](https://github.com/AdamCodd/media-random-generator).

## Training and evaluation data

More information needed

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 32
- eval_batch_size: 32
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- num_epochs: 1
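
If you want to try a comparable setup with the Transformers `Trainer`, the hyperparameters above can be expressed with `TrainingArguments` parameter names, as in the sketch below. This mapping is an assumption (a single device, so the per-device batch size equals the listed batch size), not the author's actual training script; anything not listed in this card, such as the learning-rate scheduler or weight decay, is left at library defaults.

```python
# Hyperparameters from this card, keyed by the corresponding
# transformers.TrainingArguments parameter names (mapping is an assumption).
training_kwargs = {
    "learning_rate": 3e-05,
    "per_device_train_batch_size": 32,  # assumes training on a single device
    "per_device_eval_batch_size": 32,
    "seed": 42,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-08,
    "num_train_epochs": 1,
}

for name, value in training_kwargs.items():
    print(f"{name} = {value}")
```

These could then be passed as `TrainingArguments(output_dir="<your_output_dir>", **training_kwargs)`. Note that `Trainer` uses AdamW as its default optimizer.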

### Training results

- Validation Loss: 0.0937
- Accuracy: 0.9654
- AUC: 0.9948

[Confusion matrix](https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/confusion_matrix.png) (eval):

[1076 37]

[ 60 1627]
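
The reported accuracy can be recovered from the confusion matrix, since correct predictions lie on its diagonal. The sketch below checks this. The card does not state the row/column orientation or the class order, so the per-class recall is computed under the assumption that rows are true labels, and the classes are left unnamed; the helper names are illustrative.

```python
from typing import List


def accuracy_from_confusion(matrix: List[List[int]]) -> float:
    """Fraction of samples on the diagonal of a square confusion matrix."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


def per_class_recall(matrix: List[List[int]]) -> List[float]:
    """Recall per class, assuming rows are true labels and columns predictions."""
    return [row[i] / sum(row) for i, row in enumerate(matrix)]


# Evaluation confusion matrix from this card (class order not stated).
cm = [[1076, 37],
      [60, 1627]]

print(f"samples:  {sum(map(sum, cm))}")
print(f"accuracy: {accuracy_from_confusion(cm):.4f}")
print("recall per row class:", [round(r, 4) for r in per_class_recall(cm)])
```

This gives 2,800 evaluation images and an accuracy of 0.9654, matching the value reported above.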

### Framework versions

- Transformers 4.36.2
- Evaluate 0.4.1

If you want to support me, you can do so [here](https://ko-fi.com/adamcodd).