Ojimi committed on
Commit
f472eef
1 Parent(s): 9613a1a

add readme and training guide

Files changed (2)
  1. readme.md +61 -0
  2. training_guide.md +243 -0
readme.md ADDED
@@ -0,0 +1,61 @@
---
license: apache-2.0
pipeline_tag: image-classification
tags:
- pytorch
- vision
---

This model is the product of curiosity: a classifier that lets you tag anime images!

**Disclaimer**: The model has been trained on an entirely new dataset, so *predictions on content from before 2023 might be off*. It's advisable to fine-tune the model for your specific use case.

# Quick setup guide:

```python
from transformers.modeling_outputs import ImageClassifierOutput
from transformers import ViTImageProcessor, ViTForImageClassification
import torch
from PIL import Image

model_name_or_path = "vit-anime-base/"
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
model = ViTForImageClassification.from_pretrained(model_name_or_path)
threshold = 0.3  # minimum probability for a tag to be reported

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU if no GPU

image = Image.open(YOUR_IMAGE_PATH).convert('RGB')  # force 3 channels; anime images are often RGBA

inputs = processor(image, return_tensors='pt')

model.to(device=device)
model.eval()

with torch.no_grad():
    pixel_values = inputs['pixel_values'].to(device=device)

    outputs: ImageClassifierOutput = model(pixel_values=pixel_values)

    logits = outputs.logits  # the raw scores, before any activation
    sigmoid = torch.nn.Sigmoid()  # converts logits to independent per-tag probabilities
    probs: torch.FloatTensor = sigmoid(logits)

predictions = []  # (label, probability) pairs that clear the threshold

for idx, p in enumerate(probs[0]):
    if p > threshold:
        predictions.append((model.config.id2label[idx], p.item()))

for tag in predictions:
    print(tag)
```

Why the `Sigmoid`?
- Sigmoid turns each raw score into an independent per-tag probability, so you can apply a threshold and collect every tag that clears it (see the quick illustration below); softmax, by contrast, would make the tags compete for a single probability mass.
- It's like a wizard turning regular stuff into magic potions!
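
A quick illustration of the difference (a standalone sketch, separate from the model code):

```python
import torch

logits = torch.tensor([[2.0, 1.5, -3.0]])  # made-up scores for three tags

# softmax makes the tags compete: the probabilities sum to 1
print(torch.softmax(logits, dim=-1))  # ~[0.62, 0.38, 0.004]

# sigmoid scores each tag on its own, so several tags can clear
# the threshold at once, which is exactly what multi-label tagging needs
print(torch.sigmoid(logits))  # ~[0.88, 0.82, 0.05]
```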

[Training guide](/training_guide.md)
training_guide.md ADDED
@@ -0,0 +1,243 @@
# Quick training guide
Use it together with this guide; it's really helpful!

[Fine-Tune ViT for Image Classification with Hugging Face Transformers](https://huggingface.co/blog/fine-tune-vit)
## Start
```bash
pip install transformers datasets
```
## Preparing the data:

Your data shouldn't look like this:
```json
{
    "file_name": "train/aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg",
    "labels": ["general", "furina (genshin impact)", "1girl", "ahoge", "bangs", "bare shoulders", ...]
}
```
But it should look more like this:
```json
{
    "file_name": "train/aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg",
    "labels": ["0", "3028", "4", "702", "8", "9", "382", ...]
}
```

The labels should be a list of label IDs: integers (or anything you can treat as a number) that identify the tags you want to train with.

Loading labels and their IDs:

```python
import csv

# labels.csv columns: tag_id,name,category
with open("labels.csv", "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    rows = list(reader)[1:]  # skip the header row

id2labels = {}
labels2id = {}

for row in rows:
    id2labels[str(row[0])] = row[1]  # "0" -> "general"
    labels2id[row[1]] = str(row[0])  # "general" -> "0"
```

Where `labels.csv` is a file containing the labels and their respective IDs.
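
For example, converting a raw tag list into the ID form shown above (a small sketch; it assumes every tag name appears in `labels.csv`):

```python
raw_tags = ["1girl", "ahoge", "bangs"]  # hypothetical example tags
label_ids = [labels2id[name] for name in raw_tags]
print(label_ids)  # e.g. ["4", "8", "9"]; the actual IDs depend on your labels.csv
```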

Load dataset:
```python
from datasets import load_dataset

# assumes ./vit_dataset is a dataset folder whose metadata carries
# the "labels" field shown above
dataset = load_dataset("./vit_dataset")
```
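
A quick sanity check after loading (this assumes the folder produced `train` and `test` splits, which the dataloaders below also rely on):

```python
print(dataset)                # splits and row counts
example = dataset["train"][0]
print(example["image"].size)  # PIL image size, e.g. (512, 768)
print(example["labels"][:5])  # the first few label IDs
```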
Congratulations! You've completed the toughest challenge. Why, you ask? It took me a whole week just to gather and label the data for this model.

## Preprocess:

```python
from transformers import ViTImageProcessor
import torch

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

def transform(example_batch):
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    inputs['labels'] = []
    inputs['label_names'] = [[id2labels[tagid] for tagid in x] for x in example_batch['labels']]

    # turn each list of label IDs into a multi-hot vector
    for x in example_batch['labels']:
        one_hot = [0] * len(labels2id)
        for index in x:
            one_hot[int(index)] = 1

        inputs['labels'] += [one_hot]

    return inputs

# apply the transform lazily; this is the `prepared_dataset` used below
prepared_dataset = dataset.with_transform(transform)
```
Well, this code might not look pretty, but it gets the job done! The images (inputs) are resized to 224x224 and normalized by the processor. As for the labels (targets), they're turned into a multi-hot format. Why, you ask? Because it's simple, and it's exactly what multi-label classification expects.
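
A tiny worked example of the multi-hot conversion (using a hypothetical 6-tag vocabulary):

```python
num_labels = 6        # hypothetical vocabulary size
tag_ids = ["0", "4"]  # one example's label IDs

one_hot = [0] * num_labels
for index in tag_ids:
    one_hot[int(index)] = 1

print(one_hot)  # [1, 0, 0, 0, 1, 0]
```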

## Training
These parts are relatively simple, so I'll go through them quickly.

- Load dataset:

```python
from torch.utils.data import DataLoader

batch_size = 16

def collate_fn(batch):
    # stack the per-example tensors into batch tensors
    data = {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([torch.tensor(x['labels']) for x in batch]),
        'label_names': [x['label_names'] for x in batch]
    }

    return data

train_dataloader = DataLoader(prepared_dataset['train'], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)  # shuffle added for training

eval_dataloader = DataLoader(prepared_dataset['test'], collate_fn=collate_fn, batch_size=1)
```
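
It's worth checking one batch before training (the shapes assume `batch_size = 16` and 224x224 inputs):

```python
batch = next(iter(train_dataloader))
print(batch['pixel_values'].shape)  # torch.Size([16, 3, 224, 224])
print(batch['labels'].shape)        # torch.Size([16, num_labels])
```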
- Initialize the model:

```python
from transformers import ViTForImageClassification, ViTConfig

configuration = ViTConfig(
    num_labels=len(id2labels),
    id2label=id2labels,
    label2id=labels2id,
    problem_type="multi_label_classification")  # BCE-with-logits loss for multi-hot labels
model = ViTForImageClassification(config=configuration)  # note: this trains from scratch
```
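
The linked blog fine-tunes from the pretrained checkpoint instead of training from scratch; that variant would look roughly like this (same label mappings, different starting weights):

```python
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,  # 'google/vit-base-patch16-224-in21k'
    num_labels=len(id2labels),
    id2label=id2labels,
    label2id=labels2id,
    problem_type="multi_label_classification",
)
```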

Setup train:

```python
device = torch.device('cuda')
test_steps = 5000
epochs = 50
mixed_precision = torch.float16
global_steps = 0

model.to(device=device)  # the model must live on the same device as the batches

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=optimizer)
scaler = torch.cuda.amp.GradScaler()  # loss scaling keeps fp16 gradients from underflowing
```

Test and Evaluation:
```python
import torch
from transformers.modeling_outputs import ImageClassifierOutput

def test(eval_dataloader: DataLoader, model: ViTForImageClassification, device, t=0.7):
    batch = next(iter(eval_dataloader))  # peek at a single batch

    with torch.no_grad():
        pixel_values = batch['pixel_values'].to(device=device)

        outputs: ImageClassifierOutput = model(pixel_values=pixel_values)

        logits = outputs.logits
        sigmoid = torch.nn.Sigmoid()
        probs: torch.FloatTensor = sigmoid(logits)
        predictions = []

        for idx, p in enumerate(probs[0]):
            if p > t:
                predictions.append((model.config.id2label[idx], p.item()))

        print(f"label_names : {batch['label_names'][0]}")
        print(f"predictions : {predictions}")

def evaluate(eval_dataloader: DataLoader, model: ViTForImageClassification, device, t=0.7):

    result = {
        "eval_predictions": 0,
        "eval_loss": 0,
        "total_predictions": 0,
        "total_loss": 0
    }

    for batch in eval_dataloader:
        pixel_values = batch['pixel_values'].to(device=device)
        labels = batch['labels'].to(device=device, dtype=torch.float)
        label_names = batch['label_names'][0]

        prediction = 0
        with torch.no_grad():

            outputs: ImageClassifierOutput = model(pixel_values=pixel_values, labels=labels)

            # apply sigmoid before thresholding, as in test() above
            probs = torch.sigmoid(outputs.logits)
            loss = outputs.loss
            predictions = []

            for idx, p in enumerate(probs[0]):
                if p > t:
                    predictions.append(model.config.id2label[idx])

            # fraction of the true tags that were recovered (a recall-style score)
            for p in predictions:
                if p in label_names:
                    prediction += 1 / len(label_names)

        result['total_predictions'] += prediction
        result['total_loss'] += loss.item()

    result['eval_predictions'] = result['total_predictions'] / len(eval_dataloader)
    result['eval_loss'] = result['total_loss'] / len(eval_dataloader)
    print(result)
```
Train:
```python
import tqdm
from transformers.modeling_outputs import ImageClassifierOutput

process_bar = tqdm.tqdm(total=epochs * len(train_dataloader))

for e in range(1, epochs + 1):
    model.train()

    total_loss = 0

    for batch in train_dataloader:

        pixel_values = batch['pixel_values'].to(device=device)
        labels = batch['labels'].to(device=device, dtype=torch.float)

        # forward pass under autocast for mixed-precision training
        with torch.autocast(device_type=str(device), dtype=mixed_precision):
            outputs: ImageClassifierOutput = model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss
        total_loss += loss.detach().float()

        if torch.isnan(loss):
            assert False, "NaN detected."

        # scale the loss, then step through the scaler so fp16 gradients stay stable
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        optimizer.zero_grad()

        process_bar.update(1)
        process_bar.set_description(f"{model.config.problem_type} - Epoch: {e}/{epochs}")
        process_bar.set_postfix({'loss': f'{loss.item():.5f}', "train_loss": total_loss.item() / len(train_dataloader)})

        if global_steps % test_steps == 0 and global_steps > 1:
            model.eval()
            process_bar.set_description(f"Evaluate - Epoch: {e}/{epochs}")
            evaluate(eval_dataloader=eval_dataloader, model=model, device=device, t=0.3)
            test(eval_dataloader, model, device, 0.3)
            model.train()

        global_steps += 1
```
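
The guide stops at the training loop, so here's a minimal final step to persist the result (the folder name is just an example, chosen to match the readme's quick setup guide):

```python
save_dir = "vit-anime-base"  # example path; the readme loads the model from this folder
model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)
```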

Thank you for reading through all this verbose stuff. Of course, all the code above was written off the cuff, so there may be some inconsistencies. Your contributions are highly appreciated.