File size: 4,461 Bytes
d848f34
54ad464
 
d848f34
54ad464
 
 
 
 
5b65ee9
5e396e0
 
 
54ad464
 
5b65ee9
 
54ad464
 
 
 
5e396e0
5b65ee9
 
 
 
249878a
54ad464
 
 
5b65ee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54ad464
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
---
library_name: peft
base_model: Qwen/Qwen-VL-Chat
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

- LoRA that converts WD-style tags into a long natural-language caption.
  
LICENSE: Tongyi Qianwen LICENSE  
  
## Model Details

- LoRA adapter finetuned from Qwen-VL-Chat.

### Model Description

<!-- Provide a longer summary of what this model is. -->

- **Developed by:** cella
- **Model type:** LoRA
- **Language(s) (NLP):** English
- **License:** Tongyi Qianwen LICENSE
- **Finetuned from model [optional]:** Qwen-VL-Chat

## Uses


### Model Load
```
LoRA_DIR = "/path-to-LoRA-dir"

if OPTION_VLM_METHOD == 'qwen_chat_LoRA':
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig
    import torch

    # Fix the RNG seed so generations are reproducible across runs.
    torch.manual_seed(1234)

    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)

    # Load the LoRA adapter on top of the base model; device_map="auto"
    # places the weights on the available CUDA device(s).
    model = AutoPeftModelForCausalLM.from_pretrained(
        LoRA_DIR,  # path to the LoRA training output directory
        device_map="auto",
        trust_remote_code=True,
    ).eval()

    # Specify hyperparameters for generation
    # (no need to do this if you are using transformers>=4.32.0).
    model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
else:
    print("skipped.")
```

### Captioning
```
if OPTION_VLM_METHOD == 'qwen_chat':
    # NOTE(review): the model-loading cell guards on 'qwen_chat_LoRA', so with
    # that option selected this captioning cell is skipped — confirm the
    # intended option value matches between the two cells.
    from PIL import Image
    from langdetect import detect
    import string
    import re

    COMMON_QUERY = 'What is in the image? Briefly describe the overall, in English'
    MORE_QUERY = 'What is in the image? Describe the overall in detail, in English'
    LESS_QUERY = 'What is in the image? Briefly summarize the description, in English'

    for image in dataset.images:
        img_name = os.path.basename(image.path)
        img_name = os.path.splitext(img_name)[0]

        # Skip when a .txt file with the same name already exists in the output folder.
        if OPTION_SKIP_EXISTING and os.path.exists(os.path.join(output_dir_VLM, img_name + '.txt')):
            clear_output(True)
            print("skipped: ", image.path)
            continue

        # Seed the chat with the image plus its existing tags
        # (underscores replaced by spaces). The ': ' separator keeps the
        # instruction from running into the first tag.
        query = tokenizer.from_list_format([
            {'image': image.path},
            {'text': 'Make description using following words: ' + ', '.join(image.captions).replace('_', ' ')},
        ])
        response, history = model.chat(tokenizer, query=query, history=None)

        # Retry until the response passes the ASCII, language, and length checks.
        retry_count = 0
        while not is_ascii(response) or not is_english(response) or not is_sufficient_length(response) or not is_over_length(response):
            clear_output(True)
            retry_count += 1
            print("Retry count:", retry_count)
            # Give up after 25 attempts as long as the text is at least ASCII.
            if retry_count >= 25 and is_ascii(response):
                break
            if not is_sufficient_length(response):
                print("Too short. Retry...")
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': MORE_QUERY},
                ])
            if not is_over_length(response):
                print("Too long. Retry...")
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': LESS_QUERY},
                ])
            # Every 5th retry, drop the history and restart from the generic query.
            if retry_count % 5 == 0:
                history = None
                query = tokenizer.from_list_format([
                    {'image': image.path},
                    {'text': COMMON_QUERY},
                ])
            response, history = model.chat(tokenizer, query=query, history=history)

        response = remove_fixed_patterns(response)

        if OPTION_SAVE_TAGS:
            # Save the caption as a .txt file named after the image.
            with open(os.path.join(output_dir_VLM, img_name + '.txt'), 'w') as file:
                file.write(response)

        image.captions = response

        clear_output(True)

        print("Saved for ", image.path, ": ", response)

        # Display the captioned image inline.
        img = Image.open(image.path)
        plt.imshow(np.asarray(img))
        plt.show()
else:
    print("skipped.")
```

### Framework versions

- PEFT 0.7.1