import os
import sys

import torch

from .vqvae import VQ, VQVAE, DiVAE, VQControlNet
from .scheduling import *


def get_image_tokenizer(tokenizer_id: str, 
                        tokenizers_root: str = './tokenizer_ckpts', 
                        encoder_only: bool = False, 
                        device: str = 'cuda', 
                        verbose: bool = True,
                        return_None_on_fail: bool = False,):
    """
    Load a pretrained image tokenizer from a checkpoint.

    Args:
        tokenizer_id (str): ID of the tokenizer to load (name of the checkpoint file without ".pth").
        tokenizers_root (str): Path to the directory containing the tokenizer checkpoints.
        encoder_only (bool): Set to True to load only the encoder part of the tokenizer.
        device (str): Device to load the tokenizer on.
        verbose (bool): Set to True to print load_state_dict warning/success messages.
        return_None_on_fail (bool): Set to True to return None if the tokenizer checkpoint
            does not exist, instead of raising an error.

    Returns:
        model (nn.Module): The loaded tokenizer, moved to `device` and set to eval mode.
        args (Namespace): The training arguments stored in the checkpoint.
    """
    if return_None_on_fail and not os.path.exists(os.path.join(tokenizers_root, f'{tokenizer_id}.pth')):
        return None
    
    if verbose:
        print(f'Loading tokenizer {tokenizer_id} ... ', end='')
    
    # Load the checkpoint on CPU; the constructed model is moved to `device` before returning
    ckpt = torch.load(os.path.join(tokenizers_root, f'{tokenizer_id}.pth'), map_location='cpu')

    # Handle renamed arguments
    if 'CLIP' in ckpt['args'].domain or 'DINO' in ckpt['args'].domain or 'ImageBind' in ckpt['args'].domain:
        ckpt['args'].patch_proj = False
    elif 'sam' in ckpt['args'].domain:
        ckpt['args'].input_size_min = ckpt['args'].mask_size
        ckpt['args'].input_size_max = ckpt['args'].mask_size
        ckpt['args'].input_size = ckpt['args'].mask_size
        
    # Map legacy argument names onto their current equivalents (falling back to None)
    ckpt['args'].quant_type = getattr(ckpt['args'], 'quantizer_type', None)
    ckpt['args'].enc_type = getattr(ckpt['args'], 'encoder_type', None)
    ckpt['args'].dec_type = getattr(ckpt['args'], 'decoder_type', None)
    ckpt['args'].image_size = getattr(ckpt['args'], 'input_size', None) or getattr(ckpt['args'], 'input_size_max', None)
    ckpt['args'].image_size_enc = getattr(ckpt['args'], 'input_size_enc', None)
    ckpt['args'].image_size_dec = getattr(ckpt['args'], 'input_size_dec', None)
    ckpt['args'].image_size_sd = getattr(ckpt['args'], 'input_size_sd', None)
    ckpt['args'].ema_decay = getattr(ckpt['args'], 'quantizer_ema_decay', None)
    ckpt['args'].enable_xformer = getattr(ckpt['args'], 'use_xformer', None)
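    # Infer label/channel counts from the stored weight shapes, since they may not be saved in args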
    if 'cls_emb.weight' in ckpt['model']:
        ckpt['args'].n_labels, ckpt['args'].n_channels = ckpt['model']['cls_emb.weight'].shape
    elif 'encoder.linear_in.weight' in ckpt['model']:
        ckpt['args'].n_channels = ckpt['model']['encoder.linear_in.weight'].shape[1]
    else:
        ckpt['args'].n_channels = ckpt['model']['encoder.proj.weight'].shape[1]
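    # Codebook syncing is a distributed-training concern; disable it when loading for inference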
    ckpt['args'].sync_codebook = False
    
    if encoder_only:
        model_type = VQ
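        # Strip decoder-side weights, since the encoder-only VQ model does not define them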
        ckpt['model'] = {k: v for k, v in ckpt['model'].items() if 'decoder' not in k and 'post_quant_proj' not in k}
    else:
        # TODO: Add the model type to the checkpoint when training so we can avoid this hackery
        if any('controlnet' in k for k in ckpt['model']):
            ckpt['args'].model_type = 'VQControlNet'
        elif hasattr(ckpt['args'], 'beta_schedule'):
            ckpt['args'].model_type = 'DiVAE'
        else:
            ckpt['args'].model_type = 'VQVAE'
        model_type = getattr(sys.modules[__name__], ckpt['args'].model_type)
    model = model_type(**vars(ckpt['args']))
    
    # strict=False: the encoder-only path deliberately drops decoder weights from the state dict
    msg = model.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(msg)

    return model.to(device).eval(), ckpt['args']
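

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): 'my_tokenizer' is a hypothetical checkpoint
    # name, i.e. it assumes ./tokenizer_ckpts/my_tokenizer.pth exists. Note that this module
    # uses relative imports, so run it as a module, e.g. `python -m <package>.<module>`.
    result = get_image_tokenizer(
        'my_tokenizer',
        encoder_only=True,  # only the encoder is needed to tokenize images
        device='cuda' if torch.cuda.is_available() else 'cpu',
        return_None_on_fail=True,
    )
    if result is None:
        print('Tokenizer checkpoint not found.')
    else:
        tokenizer, args = result
        print(f'Loaded {type(tokenizer).__name__} (domain: {args.domain})')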