# Quick training guide
Read this alongside the official walkthrough below; it's really helpful!

[Fine-Tune ViT for Image Classification with Hugging Face Transformers](https://huggingface.co/blog/fine-tune-vit)
## Start
```bash
pip install transformers datasets torch
```
## Preparing the data:

Your data shouldn't look like this:
```json
{
    "file_name": "train/aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg", 
    "labels":  ["general", "furina (genshin impact)", "1girl", "ahoge", "bangs", "bare shoulders", ...]
}
```
But it should look more like this:
```json
{
    "file_name": "train/aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg", 
    "labels": ["0", "3028", "4", "702", "8", "9", "382", ...]
}
```

Here the labels are represented as a list of integers (or anything you define as a number) that correspond to the tags you want to train on – essentially, the IDs of the labels.

Loading labels and their IDs:

```python
import csv

with open("labels.csv", "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    header = next(reader)  # tag_id,name,category
    rows = list(reader)

# Map tag IDs (as strings) to tag names and back
id2labels = {}
labels2id = {}

for row in rows:
    id2labels[str(row[0])] = row[1]
    labels2id[row[1]] = str(row[0])
```

Where `labels.csv` is a file containing labels and their respective IDs.
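
For reference, `labels.csv` is assumed to look roughly like the snippet below. The tag names and IDs match the JSON examples above; the category values are made up for illustration:

```csv
tag_id,name,category
0,general,0
4,1girl,0
3028,furina (genshin impact),4
```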

Load dataset:
```python
from datasets import load_dataset
dataset = load_dataset("./vit_dataset")
```
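
For `load_dataset` to pick the images up, `./vit_dataset` is assumed to be an image-folder style dataset: images sit in split directories and a `metadata.jsonl` holds one record per image in the format shown above (which is why `file_name` starts with `train/`). Treat this as a sketch; the exact structure depends on how you build the dataset:

```
vit_dataset/
├── metadata.jsonl   # one JSON record per image, as shown above
├── train/
│   ├── aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg
│   └── ...
└── test/
    └── ...
```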
Congratulations! You've completed the toughest part. Why? Because gathering and labeling the data took me a whole week when I trained this model.

## Preprocess:

```python
from transformers import ViTImageProcessor
import torch
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

def transform(example_batch):
    # Resize and normalize the images with the ViT processor
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    # Keep the human-readable tag names around for inspection during eval
    inputs['label_names'] = [[id2labels[tag_id] for tag_id in x] for x in example_batch['labels']]

    # Turn each list of tag IDs into a multi-hot vector over all known tags
    inputs['labels'] = []
    for tag_ids in example_batch['labels']:
        one_hot = [0] * len(labels2id)
        for index in tag_ids:
            one_hot[int(index)] = 1
        inputs['labels'].append(one_hot)

    return inputs
```
Well, this code might not look pretty, but it gets the job done! As for the images (inputs), the processor resizes them to 224x224 and normalizes them. As for the labels (targets), we turn them into multi-hot vectors. Why? Because that's the format multi-label classification expects, and it's simple.
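
The DataLoaders in the next section expect a `prepared_dataset` with this transform attached. Following the `with_transform` pattern from the linked guide, that would look like:

```python
# Apply the transform lazily, so images are only processed when a batch is pulled
prepared_dataset = dataset.with_transform(transform)
```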

## Training
These parts are relatively simple, so I'll go through them quickly.

- Build the DataLoaders:

```python
from torch.utils.data import DataLoader

batch_size = 16

def collate_fn(batch):
    # Stack the processed images and multi-hot label vectors into batch tensors
    data = {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([torch.tensor(x['labels']) for x in batch]),
        'label_names': [x['label_names'] for x in batch]
    }

    return data

# Shuffle the training set each epoch; evaluate one image at a time
train_dataloader = DataLoader(prepared_dataset['train'], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

eval_dataloader = DataLoader(prepared_dataset['test'], collate_fn=collate_fn, batch_size=1)
```
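
Optionally, a quick sanity check to confirm the batches come out with the shapes you'd expect (assuming `batch_size = 16` and the 224x224 processor above):

```python
batch = next(iter(train_dataloader))
print(batch['pixel_values'].shape)  # torch.Size([16, 3, 224, 224])
print(batch['labels'].shape)        # torch.Size([16, <number of tags>])
```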
- Initialize the model:

```python
from transformers import ViTForImageClassification, ViTConfig

configuration = ViTConfig(
    num_labels=len(id2labels),
    id2label=id2labels,
    label2id=labels2id,
    # Explicitly mark the task so the model uses BCEWithLogitsLoss on multi-hot labels
    problem_type="multi_label_classification")
model = ViTForImageClassification(config=configuration)
```
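
Note that this builds a ViT from scratch with random weights. If you would rather fine-tune from the pretrained checkpoint (as the linked blog post does), something along these lines should work instead, passing the label config as `from_pretrained` overrides:

```python
# Fine-tune from the pretrained backbone instead of training from scratch
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(id2labels),
    id2label=id2labels,
    label2id=labels2id,
    problem_type="multi_label_classification",
)
```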

Set up training:

```python
device = torch.device('cuda')
model.to(device)  # the loop below moves batches to the GPU, so the model must be there too

test_steps = 5000
epochs = 50
mix_precision = torch.float16
global_steps = 0

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=optimizer)
```
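
One caveat: the loop below runs the forward pass under float16 autocast but calls `loss.backward()` directly, which can let small gradients underflow. If that becomes a problem, the usual fix is a gradient scaler; a minimal sketch:

```python
import torch

# Optional: scale the loss so float16 gradients don't underflow
scaler = torch.cuda.amp.GradScaler()

# then, inside the training loop, replace the backward/step calls with:
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
#   optimizer.zero_grad()
```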

Test and Evaluation:
```python
import torch
from transformers.modeling_outputs import ImageClassifierOutput

def test(eval_dataloader : DataLoader, model : ViTForImageClassification, device, t=0.7):
    # Run a single eval image through the model and print its predicted tags
    batch = next(iter(eval_dataloader))

    with torch.no_grad():
        pixel_values = batch['pixel_values'].to(device=device)

        outputs : ImageClassifierOutput = model(pixel_values=pixel_values)

        # Sigmoid turns the logits into independent per-tag probabilities
        probs = torch.sigmoid(outputs.logits)

        # Keep every tag whose probability clears the threshold t
        predictions = []
        for idx, p in enumerate(probs[0]):
            if p > t:
                predictions.append((model.config.id2label[idx], p.item()))

    print(f"label_names : {batch['label_names'][0]}")
    print(f"predictions : {predictions}")

def eval(eval_dataloader : DataLoader, model : ViTForImageClassification, device, t=0.7):
    # Track average loss and the fraction of true tags recovered over the eval set
    result = {
        "eval_predictions" : 0,
        "eval_loss" : 0,
        "total_predictions" : 0,
        "total_loss" : 0
    }

    for batch in eval_dataloader:
        pixel_values = batch['pixel_values'].to(device=device)
        labels = batch['labels'].to(device=device, dtype=torch.float)
        label_names = batch['label_names'][0]

        with torch.no_grad():
            outputs : ImageClassifierOutput = model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss
        # Apply sigmoid so the threshold t is a probability, as in test() above
        probs = torch.sigmoid(outputs.logits)

        predictions = []
        for idx, p in enumerate(probs[0]):
            if p > t:
                predictions.append(model.config.id2label[idx])

        # Score this image by the fraction of its true tags that were predicted
        prediction = 0
        for p in predictions:
            if p in label_names:
                prediction += 1 / len(label_names)

        result['total_predictions'] += prediction
        result['total_loss'] += loss.item()

    result['eval_predictions'] = result['total_predictions'] / len(eval_dataloader)
    result['eval_loss'] = result['total_loss'] / len(eval_dataloader)
    print(result)
```
Train:
```python
import tqdm
from transformers.modeling_outputs import ImageClassifierOutput

# One progress bar across all epochs
process_bar = tqdm.tqdm(total=epochs * len(train_dataloader))

for e in range(1, epochs + 1):
    model.train()

    total_loss = 0

    for idx, batch in enumerate(train_dataloader):

        pixel_values = batch['pixel_values'].to(device=device)
        labels = batch['labels'].to(device=device, dtype=torch.float)

        # Forward pass under mixed precision; the loss is BCE over the multi-hot labels
        with torch.autocast(device_type=str(device), dtype=mix_precision):
            outputs : ImageClassifierOutput = model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss
        total_loss += loss.detach().float()

        # Abort early if the loss blows up
        if torch.isnan(loss):
            assert False, "NaN detected."

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        process_bar.update(1)
        process_bar.desc = f"{model.config.problem_type} - Epoch: {e}/{epochs}"
        process_bar.set_postfix({'loss' : f'{loss.item():.5f}', "train_loss" : total_loss.item() / (idx + 1)})

        # Every test_steps steps, run a full evaluation pass and print a sample prediction
        if global_steps % test_steps == 0 and global_steps > 1:
            model.eval()
            process_bar.desc = f"Evaluate - Epoch: {e}/{epochs}"
            eval(eval_dataloader=eval_dataloader, model=model, device=device, t=0.3)
            test(eval_dataloader, model, device, 0.3)
            model.train()

        global_steps += 1
```
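
Once training finishes, you will probably want to save the result so it can be reloaded with `from_pretrained`. A minimal sketch (the output directory name is just an example):

```python
# Save the weights, config (including id2label) and the image processor together
model.save_pretrained("./vit-tagger")
processor.save_pretrained("./vit-tagger")
```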

Thank you for reading through all this verbose stuff. All the code above was written off the cuff, so there may be some inconsistencies. Contributions are highly appreciated.