---
license: mit
datasets:
- imagenet1k
metrics:
- accuracy
---

# VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials

This model is a convolutional version of a Kolmogorov-Arnold Network (KAN) with a VGG-11-like architecture, pretrained on the Imagenet1k dataset. KANs were originally presented in [1, 2]. The Gram version of KAN was originally presented in [3]. For more details, visit our [torch-conv-kan](https://github.com/IvanDrokin/torch-conv-kan) repository on GitHub.

## Model description

The model consists of 10 consecutive Gram ConvKAN (KAGN) layers with InstanceNorm2d and polynomial degree 5, interleaved with max pooling and followed by global average pooling and a linear classification head:

1. KAGN Convolution, 32 filters, 3x3
2. Max pooling, 2x2
3. KAGN Convolution, 64 filters, 3x3
4. Max pooling, 2x2
5. KAGN Convolution, 128 filters, 3x3
6. KAGN Convolution, 128 filters, 3x3
7. Max pooling, 2x2
8. KAGN Convolution, 256 filters, 3x3
9. KAGN Convolution, 256 filters, 3x3
10. Max pooling, 2x2
11. KAGN Convolution, 256 filters, 3x3
12. KAGN Convolution, 256 filters, 3x3
13. Max pooling, 2x2
14. KAGN Convolution, 512 filters, 3x3
15. KAGN Convolution, 512 filters, 3x3
16. Global average pooling
17. Output layer, 1000 nodes.

![model image](https://github.com/IvanDrokin/torch-conv-kan/blob/main/assets/vgg_kagn_11_v4.png?raw=true)

## Intended uses & limitations

You can use the raw model for image classification or as a pretrained model for further fine-tuning.

### How to use

First, clone the repository and install its dependencies:

```
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```

Then you can initialize the model and load the weights.
```python
import torch
from models import vggkagn  # run from the root of the cloned torch-conv-kan repository

model = vggkagn(
    3,
    1000,
    groups=1,
    degree=5,
    dropout=0.15,
    l1_decay=0,
    dropout_linear=0.25,
    width_scale=2,
    vgg_type='VGG11v4',
    expected_feature_shape=(1, 1),
    affine=True,
)

# Load the pretrained weights from the Hugging Face Hub repo 'brivangl/vgg_kagn11_v4'
model.from_pretrained('brivangl/vgg_kagn11_v4')
```

Transforms used for validation on Imagenet1k:

```python
import torch
from torchvision.transforms import v2

transforms_val = v2.Compose([
    v2.ToImage(),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

## Training data

This model was trained on the Imagenet1k dataset (1,281,167 images in the training set).

## Training procedure

The model was trained for 200 full epochs with the AdamW optimizer, using the following parameters:

```python
{'learning_rate': 0.0009,
 'adam_beta1': 0.9,
 'adam_beta2': 0.999,
 'adam_weight_decay': 5e-06,
 'adam_epsilon': 1e-08,
 'lr_warmup_steps': 7500,
 'lr_power': 0.3,
 'lr_end': 1e-07,
 'set_grads_to_none': False}
```

and the following augmentations:

```python
import torch
from torchvision.transforms import AutoAugmentPolicy, v2

transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    # Apply one of the two AutoAugment policies, chosen at random per sample
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

## Evaluation results

On Imagenet1k Validation:

| Accuracy, top1 | Accuracy, top5 | AUC (ovo) | AUC (ovr) |
|:--------------:|:--------------:|:---------:|:---------:|
| 61.17          | 83.26          | 99.42     | 99.43     |

On Imagenet1k Test: Coming soon

### BibTeX entry and citation info

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.
```bibtex
@misc{torch-conv-kan,
  author = {Ivan Drokin},
  title = {Torch Conv KAN},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/IvanDrokin/torch-conv-kan}}
}
```

## References

- [1] Ziming Liu et al., "KAN: Kolmogorov-Arnold Networks", 2024, arXiv. https://arxiv.org/abs/2404.19756
- [2] pykan: https://github.com/KindXiaoming/pykan
- [3] GRAMKAN: https://github.com/Khochawongwat/GRAMKAN
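## Appendix: learning-rate schedule (sketch)

The training parameters list `lr_warmup_steps`, `lr_power` and `lr_end`, which suggests a linear warmup followed by polynomial decay. The sketch below assumes the same rule as Hugging Face's `get_polynomial_decay_schedule_with_warmup`; whether torch-conv-kan uses exactly this rule is an assumption. The total number of optimizer steps is not stated in this card, so `total_steps` must be supplied by the caller, and the value used in the example is hypothetical. `polynomial_decay_lr` is an illustrative helper, not part of the repository.

```python
def polynomial_decay_lr(step: int,
                        total_steps: int,
                        lr_init: float = 9e-4,
                        warmup_steps: int = 7500,
                        power: float = 0.3,
                        lr_end: float = 1e-7) -> float:
    """Learning rate at a given optimizer step.

    Assumed rule (hypothetical helper): linear warmup from 0 to lr_init,
    then polynomial decay from lr_init to lr_end, then constant lr_end.
    """
    if total_steps <= warmup_steps:
        raise ValueError("total_steps must be larger than warmup_steps")
    if step < warmup_steps:
        # Linear warmup: 0 at step 0, lr_init at step == warmup_steps.
        return lr_init * step / warmup_steps
    if step > total_steps:
        # Past the end of the schedule: hold the final learning rate.
        return lr_end
    # Fraction of the decay phase that is still remaining (1.0 -> 0.0).
    decay_steps = total_steps - warmup_steps
    pct_remaining = 1 - (step - warmup_steps) / decay_steps
    return (lr_init - lr_end) * pct_remaining ** power + lr_end


if __name__ == "__main__":
    total = 200 * 5_000  # hypothetical: 200 epochs x 5,000 steps per epoch
    for s in (0, 3_750, 7_500, total // 2, total):
        print(f"step {s:>9}: lr = {polynomial_decay_lr(s, total_steps=total):.3e}")
```

With `power=0.3` the decay is concave: the learning rate stays close to its peak for most of training and drops sharply near the end. To drive a PyTorch optimizer, the returned value can be divided by `lr_init` and used as the multiplier in `torch.optim.lr_scheduler.LambdaLR`.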