How do I get the register token output values?
Are the first 5 tokens [CLS] + 4x [REG] tokens?
Yes, num_prefix_tokens is 1 cls + 4 x reg for these models: https://github.com/huggingface/pytorch-image-models/blob/a6fe31b09670289dbc8e99a0cfae23de355534c9/timm/models/vision_transformer.py#L497-L498
The easiest way to get them is to call forward_features() and take tokens [1:5] along the sequence dimension of its output, or you can use forward_intermediates() to get the prefix tokens for all blocks:
```python
>>> oo = mm.forward_intermediates(torch.randn(2, 3, 518, 518), return_prefix_tokens=True)
>>> oo[1][-1][1].shape
torch.Size([2, 5, 768])
```
The output there is a tuple of the final features and the per-block output features. When return_prefix_tokens is set to True, each block output is itself a tuple of (spatial features, prefix tokens), so oo[1][-1][1] is the prefix tokens of the last block.
Hi! I would just like to ask whether what I did below is correct (please see the screenshot).
You said that the easiest way to get the prefix token embeddings is to use forward_features() and take the first 5 tokens in the sequence. I did that (top) and compared it with the output of forward_intermediates()... However, their outputs are different. Is there something I have missed? I would appreciate your help! Thank you so much :)
EDIT: I was able to show that they're the same. I had forgotten to pass the norm=True argument to forward_intermediates(). Hope this helps!
Hi! It's me again. Just one more question:
The screenshot below is taken from the DINOv2 model card (https://github.com/facebookresearch/dinov2/blob/main/MODEL_CARD.md).
As I understand it, there are 261 tokens in total (1 class + 4 register + 256 patch tokens). Now, going back to the timm version, the output shape is (1, 1374, 768). Is the 1374 semantically equivalent to the 261, i.e., is 1374 the length of the token sequence? How does it arrive at 1374 instead of 261? Thank you :-)
DINOv2 models in timm are 518x518 by default. With a patch size of 14 that gives 37*37 = 1369 spatial patch tokens, plus 1 cls token and 4 reg tokens = 1374. It would be 261 (16*16 = 256 patch tokens + 5) if you resized and used 224x224 images.
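The same arithmetic as a small helper. It assumes a square input, a patch size of 14 and 5 prefix tokens (1 cls + 4 reg), matching the numbers in this thread; the function name is hypothetical, and this sketch simply rejects sizes that are not a multiple of the patch size.

```python
def dinov2_reg4_seq_len(img_size: int, patch_size: int = 14, num_prefix_tokens: int = 5) -> int:
    """Token sequence length for a square image: (patches per side)^2 + prefix tokens."""
    if img_size % patch_size:
        raise ValueError("img_size must be a multiple of patch_size")
    grid = img_size // patch_size  # 518 // 14 = 37, 224 // 14 = 16
    return grid * grid + num_prefix_tokens


print(dinov2_reg4_seq_len(518))  # 37*37 + 5 = 1374
print(dinov2_reg4_seq_len(224))  # 16*16 + 5 = 261
```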
Your snippets above are correct if you want the cls and reg tokens together; if you want just the reg tokens, slice [1:5] to get the 4 of them.
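To make the slicing concrete, here is a sketch on a random tensor standing in for a forward_features() output of shape (1, 1374, 768); with a real model you would use its output instead. It assumes the token order [cls, reg x4, patches] described in this thread, with the patch tokens in row-major order so they can be laid back out on the 37x37 grid.

```python
import torch

feats = torch.randn(1, 1374, 768)  # stand-in for model.forward_features(x) at 518x518
num_prefix = 5                     # 1 cls + 4 reg

cls_token = feats[:, 0]                 # (1, 768)
reg_tokens = feats[:, 1:num_prefix]     # just the 4 regs, (1, 4, 768)
prefix_tokens = feats[:, :num_prefix]   # cls + regs together, (1, 5, 768)
patch_tokens = feats[:, num_prefix:]    # (1, 1369, 768)

# Row-major patch order lets the patch tokens be reshaped onto the spatial grid
grid = int(patch_tokens.shape[1] ** 0.5)  # 37
patch_map = patch_tokens.reshape(1, grid, grid, 768)

print(cls_token.shape, reg_tokens.shape, prefix_tokens.shape, patch_map.shape)
```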