File size: 3,063 Bytes
85456ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Bounding-box helpers for slicing and reweighting rectangular regions of tensors."""
import torch


class BoundingBox:
    """A rectangular bounding box that determines the directed region.

    Pixel bounds are computed from fractional ``box_ratios`` (fractions of
    ``resolution``), expanded by ``margin`` and clamped to
    ``[0, resolution]`` so they are always valid, non-negative slice bounds.
    Each side spans at least one pixel, so ``height``/``width`` always match
    the region that the slicing/masking methods actually touch.
    """

    def __init__(self, resolution, box_ratios, margin=0.0):
        """
        Args:
            resolution(int): the resolution of the 2d spatial input; the same
                value scales both the horizontal and the vertical ratios.
            box_ratios(List[float]): ``[left, top, right, bottom]`` expressed
                as fractions of ``resolution``; requires left < right and
                top < bottom.
            margin(float): fraction of ``resolution`` added on every side of
                the box before converting to pixels (default 0.0).
        Raises:
            AssertionError: if left >= right or top >= bottom (note: these
                checks are skipped when Python runs with ``-O``).
        """
        assert (
            box_ratios[1] < box_ratios[3]
        ), "the boundary top ratio should be less than bottom"
        assert (
            box_ratios[0] < box_ratios[2]
        ), "the boundary left ratio should be less than right"
        self.left, self.right = self._pixel_span(
            box_ratios[0], box_ratios[2], margin, resolution
        )
        self.top, self.bottom = self._pixel_span(
            box_ratios[1], box_ratios[3], margin, resolution
        )
        self.height = self.bottom - self.top
        self.width = self.right - self.left

    @staticmethod
    def _pixel_span(start_ratio, end_ratio, margin, resolution):
        """Convert one ratio interval (plus margin) into clamped pixel bounds."""
        start = int((start_ratio - margin) * resolution)
        end = int((end_ratio + margin) * resolution)
        # Clamp: a negative index would be read by Python slicing as counting
        # from the end, silently selecting an empty or wrong region.
        start = min(max(start, 0), resolution)
        end = min(max(end, 0), resolution)
        if end <= start:
            # int() truncation can collapse a thin box; keep one real pixel so
            # the reported height/width (>= 1) match what gets sliced.
            start = max(min(start, resolution - 1), 0)
            end = start + 1
        return start, end

    def _region(self):
        """Index tuple selecting the box: rows on axis 1, columns on axis 2."""
        # The trailing Ellipsis keeps any remaining axes whole; for a 4d
        # tensor it is equivalent to the ':' used previously.
        return (
            slice(None),
            slice(self.top, self.bottom),
            slice(self.left, self.right),
            Ellipsis,
        )

    def sliced_tensor_in_bbox(self, tensor: torch.Tensor) -> torch.Tensor:
        """Slice the tensor down to the bbox area.

        Args:
            tensor(torch.Tensor): tensor whose axis 1 is indexed by top/bottom
                and axis 2 by left/right (e.g. a 4d tensor); other axes are
                kept whole.
        Returns:
            (torch.Tensor): the reduced tensor inside bbox.
        """
        return tensor[self._region()]

    def mask_reweight_out_bbox(
        self, tensor: torch.Tensor, value: float = 0.0
    ) -> torch.Tensor:
        """Reweight values outside the bbox, leaving the inside unchanged.

        Args:
            tensor(torch.Tensor): tensor indexed as in ``sliced_tensor_in_bbox``.
            value(float): factor applied outside the bbox, default 0.0.
        Returns:
            (torch.Tensor): a new reweighted tensor; the input is not modified.
        """
        # ones_like already matches tensor's device; multiplying by value keeps
        # the same dtype promotion as before (int tensor * float -> float).
        mask = torch.ones_like(tensor) * value
        mask[self._region()] = 1
        return tensor * mask

    def mask_reweight_in_bbox(
        self, tensor: torch.Tensor, value: float = 0.0
    ) -> torch.Tensor:
        """Reweight values inside the bbox, leaving the outside unchanged.

        Args:
            tensor(torch.Tensor): tensor indexed as in ``sliced_tensor_in_bbox``.
            value(float): factor applied inside the bbox, default 0.0.
        Returns:
            (torch.Tensor): a new reweighted tensor; the input is not modified.
        """
        mask = torch.ones_like(tensor)
        mask[self._region()] = value
        return tensor * mask

    def __repr__(self):
        """Return ``Box(L:<left>, R:<right>, T:<top>, B:<bottom>)``."""
        return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"

    __str__ = __repr__
    # Misspelled name from earlier versions, kept so any existing callers work.
    __rerp__ = __repr__


if __name__ == "__main__":
    # Demo: the top-left quarter of a 32x32 grid (the "second quadrant").
    demo_bbox = BoundingBox(
        resolution=32,
        box_ratios=[0.0, 0.0, 0.5, 0.5],  # [left, top, right, bottom]
    )
    print(demo_bbox)
    # Expected output: Box(L:0, R:16, T:0, B:16)