brivangl commited on
Commit
fb138fe
1 Parent(s): a2327bf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +136 -6
README.md CHANGED
@@ -1,9 +1,139 @@
1
  ---
2
- tags:
3
- - pytorch_model_hub_mixin
4
- - model_hub_mixin
 
 
5
  ---
 
6
 
7
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
- - Library: [More Information Needed]
9
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ license: mit
3
+ datasets:
4
+ - imagenet1k
5
+ metrics:
6
+ - accuracy
7
  ---
8
+ # VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials
9
 
10
+ This model is a Convolutional version of Kolmogorov-Arnold Network with VGG-11 like architecture, pretrained on Imagenet1k dataset. KANs were originally presented in [1, 2]. Gram version of KAN originally presented in [3]. For more details visit our [torch-conv-kan](https://github.com/IvanDrokin/torch-conv-kan) repository on GitHub.
11
+
12
+ ## Model description
13
+
14
+ The model consists of consecutive 10 Gram ConvKAN Layers with InstanceNorm2d, polynomial degree equal to 5, GlobalAveragePooling and Linear classification head:
15
+
16
+ 1. KAGN Convolution, 32 filters, 3x3
17
+ 2. Max pooling, 2x2
18
+ 3. KAGN Convolution, 64 filters, 3x3
19
+ 4. Max pooling, 2x2
20
+ 5. KAGN Convolution, 128 filters, 3x3
21
+ 6. KAGN Convolution, 128 filters, 3x3
22
+ 7. Max pooling, 2x2
23
+ 8. KAGN Convolution, 256 filters, 3x3
24
+ 9. KAGN Convolution, 256 filters, 3x3
25
+ 10 Max pooling, 2x2
26
+ 11. KAGN Convolution, 256 filters, 3x3
27
+ 12. KAGN Convolution, 256 filters, 3x3
28
+ 13. Max pooling, 2x2
29
+ 14. KAGN Convolution, 512 filters, 3x3
30
+ 15. KAGN Convolution, 512 filters, 3x3
31
+ 16. Global Average pooling
32
+ 17. Output layer, 1000 nodes.
33
+
34
+ ![model image](https://github.com/IvanDrokin/torch-conv-kan/blob/main/assets/vgg_kagn_11_v4.png?raw=true)
35
+
36
+
37
+ ## Intended uses & limitations
38
+
39
+ You can use the raw model for image classification or use it as pretrained model for further finetuning.
40
+
41
+ ### How to use
42
+
43
+ First, clone the repository:
44
+
45
+ ```
46
+ git clone https://github.com/IvanDrokin/torch-conv-kan.git
47
+ cd torch-conv-kan
48
+ pip install -r requirements.txt
49
+ ```
50
+ Then you can initialize the model and load weights.
51
+
52
+ ```python
53
+ import torch
54
+ from models import vggkagn
55
+ model = vggkagn(3,
56
+ 1000,
57
+ groups=1,
58
+ degree=5,
59
+ dropout=0.15,
60
+ l1_decay=0,
61
+ dropout_linear=0.25,
62
+ width_scale=2,
63
+ vgg_type='VGG11v4',
64
+ expected_feature_shape=(1, 1),
65
+ affine=True
66
+ )
67
+ model.from_pretrained('brivangl/vgg_kagn11_v4')
68
+ ```
69
+
70
+ Transforms, used for validation on Imagenet1k:
71
+
72
+ ```python
73
+ from torchvision.transforms import v2
74
+ transforms_val = v2.Compose([
75
+ v2.ToImage(),
76
+ v2.Resize(256, antialias=True),
77
+ v2.CenterCrop(224),
78
+ v2.ToDtype(torch.float32, scale=True),
79
+ v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
80
+ ])
81
+ ```
82
+
83
+
84
+
85
+ ## Training data
86
+ This model trained on Imagenet1k dataset (1281167 images in train set)
87
+
88
+ ## Training procedure
89
+
90
+ Model was trained during 200 full epochs with AdamW optimizer, with following parameters:
91
+ ```python
92
+ {'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
93
+ 'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}
94
+ ```
95
+ And this augmnetations:
96
+ ```python
97
+ transforms_train = v2.Compose([
98
+ v2.ToImage(),
99
+ v2.RandomHorizontalFlip(p=0.5),
100
+ v2.RandomResizedCrop(224, antialias=True),
101
+ v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
102
+ v2.AutoAugment(AutoAugmentPolicy.IMAGENET)
103
+ ]),
104
+ v2.ToDtype(torch.float32, scale=True),
105
+ v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
106
+ ])
107
+ ```
108
+
109
+ ## Evaluation results
110
+
111
+ On Imagenet1k Validation:
112
+
113
+ | Accuracy, top1 | Accuracy, top5 | AUC (ovo) | AUC (ovr) |
114
+ |:--------------:|:--------------:|:---------:|:---------:|
115
+ | 61.17 | 83.26 | 99.42 | 99.43 |
116
+
117
+ On Imagenet1k Test:
118
+ Coming soon
119
+
120
+ ### BibTeX entry and citation info
121
+
122
+ If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.
123
+
124
+ ```bibtex
125
+ @misc{torch-conv-kan,
126
+ author = {Ivan Drokin},
127
+ title = {Torch Conv KAN},
128
+ year = {2024},
129
+ publisher = {GitHub},
130
+ journal = {GitHub repository},
131
+ howpublished = {\url{https://github.com/IvanDrokin/torch-conv-kan}}
132
+ }
133
+ ```
134
+
135
+ ## References
136
+
137
+ - [1] Ziming Liu et al., "KAN: Kolmogorov-Arnold Networks", 2024, arXiv. https://arxiv.org/abs/2404.19756
138
+ - [2] https://github.com/KindXiaoming/pykan
139
+ - [3] https://github.com/Khochawongwat/GRAMKAN