# Source: 3DOI / monoarti / model.py
# Author: shengyi-qian; commit 9afcee2 ("init"); file size 3.41 kB.
from functools import partial
import torch
from .transformer import INTR
from .sam_transformer import SamTransformer
from .sam import ImageEncoderViT, MaskDecoder, PromptEncoder, TwoWayTransformer
def _build_mask_decoder(transformer_dim, properties_on):
    """Construct one SAM-style ``MaskDecoder`` head.

    The demo model uses three decoder heads with identical architecture that
    differ only in ``properties_on``; this helper keeps their shared
    configuration in one place.

    Args:
        transformer_dim: Channel width of the decoder and of its two-way
            transformer (the demo passes the prompt embedding dimension).
        properties_on: Forwarded unchanged to ``MaskDecoder``. Only the main
            mask decoder enables it. NOTE(review): presumably this turns on
            the extra per-query property heads; confirm in ``.sam.MaskDecoder``.

    Returns:
        A freshly constructed ``MaskDecoder``. Each call creates new modules,
        so heads built by separate calls do not share weights.
    """
    return MaskDecoder(
        num_multimask_outputs=3,
        # Built before the MaskDecoder itself, matching the original
        # construction order (keeps parameter-init RNG consumption identical).
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=transformer_dim,
            mlp_dim=2048,
            num_heads=8,
        ),
        transformer_dim=transformer_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        properties_on=properties_on,
    )


def build_demo_model():
    """Build the ``SamTransformer`` used by the demo, with a SAM ViT-B encoder.

    The model is assembled from the following parts:

    * an ``ImageEncoderViT`` configured with the SAM ViT-B hyperparameters;
    * a ``PromptEncoder``;
    * three ``MaskDecoder`` heads of identical architecture, passed as the
      mask, affordance and depth decoders. Only the mask decoder has
      ``properties_on=True``.

    No checkpoint is loaded by this function; the returned model has freshly
    initialised weights.

    Returns:
        The constructed ``SamTransformer`` instance.
    """
    # SAM ViT-B image-encoder hyperparameters.
    encoder_embed_dim = 768
    encoder_depth = 12
    encoder_num_heads = 12
    # Blocks that use global attention; the rest use windowed attention.
    encoder_global_attn_indexes = [2, 5, 8, 11]

    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Side length of the encoder's output grid, in patches (1024 // 16 == 64).
    image_embedding_size = image_size // vit_patch_size

    # Sub-modules are constructed inline, in the same order as before
    # (encoder, prompt encoder, mask/affordance/depth decoders).
    model = SamTransformer(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=_build_mask_decoder(prompt_embed_dim, properties_on=True),
        affordance_decoder=_build_mask_decoder(prompt_embed_dim, properties_on=False),
        depth_decoder=_build_mask_decoder(prompt_embed_dim, properties_on=False),
        transformer_hidden_dim=prompt_embed_dim,
        backbone_name='vit_b',
        # These values match the ImageNet RGB mean/std scaled to the 0-255 range.
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return model