File size: 3,138 Bytes
27763e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch


def foot_contact_by_height(pos, eps=0.25):
    """Flag entries whose component at index 1 of the last axis is near zero.

    :param pos: positions; only ``pos[..., 1]`` is read
        (presumably the vertical coordinate -- confirm against the caller)
    :param eps: half-width of the open interval ``(-eps, eps)`` that counts
        as contact; defaults to the previously hard-coded 0.25
    :return: boolean mask with the shape of ``pos[..., 1]``; True where
        ``-eps < pos[..., 1] < eps``
    """
    height = pos[..., 1]
    return (height > -eps) & (height < eps)


def velocity(pos, padding=False):
    """Return the magnitude of the frame-to-frame displacement of ``pos``.

    Differences are taken along the first axis and the norm is taken over
    the last axis, so an input with ``T`` rows yields ``T - 1`` rows.

    :param pos: tensor whose first axis is differenced and whose last axis
        is reduced by the norm
    :param padding: when True, prepend one all-zero row so the result has
        as many rows as ``pos``
    :return: tensor of displacement norms
    """
    velo = pos[1:, ...] - pos[:-1, ...]
    velo_norm = torch.norm(velo, dim=-1)
    # Build the pad row explicitly instead of slicing it out of velo_norm:
    # for a single-frame input velo_norm has zero rows, so a slice would be
    # empty and the padded result would lose its only frame. This also works
    # when velo_norm is 1-D (pos of shape [T, C]).
    if padding and pos.shape[0] > 0:
        pad = velo_norm.new_zeros((1,) + tuple(velo_norm.shape[1:]))
        velo_norm = torch.cat([pad, velo_norm], dim=0)
    return velo_norm


def foot_contact(pos, ref_height=1., threshold=0.018):
    """Label frames as in contact when the joint speed is below a threshold.

    :param pos: positions passed to ``velocity``; the first axis is
        differenced there
    :param ref_height: not used by this function; kept so existing callers
        that pass it keep working
    :param threshold: a frame is labelled 1 when its displacement norm from
        the previous frame is strictly below this value
    :return: int tensor of 0/1 labels with as many rows as ``pos``; the
        first row is always 0 because it has no previous frame
    """
    velo_norm = velocity(pos)
    contact = (velo_norm < threshold).int()
    if pos.shape[0] == 0:
        # No frames at all: nothing to pad.
        return contact
    # Prepend one explicit zero row. Slicing the row out of a zeros_like
    # tensor yields an empty pad for single-frame input (dropping that
    # frame) and needlessly allocates a full-size tensor.
    first_frame = contact.new_zeros((1,) + tuple(contact.shape[1:]))
    return torch.cat([first_frame, contact], dim=0)


def alpha(t):
    """Evaluate the cubic ``2*t**3 - 3*t**2 + 1``.

    The curve equals 1 at ``t = 0`` and 0 at ``t = 1``.
    """
    cubic_term = 2.0 * t * t * t
    quadratic_term = 3.0 * t * t
    return cubic_term - quadratic_term + 1


def lerp(a, l, r):
    """Linearly blend ``l`` and ``r`` with weight ``a``.

    Returns ``l`` when ``a`` is 0 and ``r`` when ``a`` is 1.
    """
    left_share = (1 - a) * l
    right_share = a * r
    return left_share + right_share


def constrain_from_contact(contact, glb, fid='TBD', L=5):
    """
    Pin joints during contact runs and blend neighbouring frames toward them.

    For every joint listed in ``fid`` two passes are made over the frames:
    first, each run of contact frames is replaced by the mean position of
    that run; then every non-contact frame that has a contact frame within
    ``L`` frames on either side is blended toward the pinned position(s).

    :param contact: contact label, indexed as ``contact[:, i]`` where ``i``
        is the position of the joint inside ``fid``
    :param glb: original global position, indexed as ``glb[frame, joint]``;
        it is modified in place
    :param fid: joint id to fix, corresponding to the order in contact
        (NOTE(review): the default 'TBD' looks like a placeholder -- callers
        presumably always pass an explicit sequence of joint indices)
    :param L: frame to look forward/backward
    :return: ``glb`` itself, after the in-place update
    """
    num_frames = glb.shape[0]

    for column, joint in enumerate(fid):
        labels = contact[:, column]  # one label per frame for this joint

        # Pass 1: collapse every contact run onto its mean position.
        start = 0
        while start < num_frames:
            if labels[start] == 0:
                start += 1
                continue
            end = start
            mean_pos = glb[end, joint].clone()
            while end + 1 < num_frames and labels[end + 1] == 1:
                end += 1
                mean_pos += glb[end, joint].clone()
            mean_pos /= (end - start + 1)

            for frame in range(start, end + 1):
                glb[frame, joint] = mean_pos.clone()
            start = end + 1

        # Pass 2: blend free frames toward the nearest pinned frame(s).
        for frame in range(num_frames):
            if labels[frame] == 1:
                continue

            # Nearest contact frame within L frames before `frame`.
            left = None
            for candidate in range(frame - 1, max(frame - L, 0) - 1, -1):
                if labels[candidate]:
                    left = candidate
                    break

            # Nearest contact frame within L frames after `frame`.
            right = None
            last_candidate = min(frame + L, num_frames - 1)
            for candidate in range(frame + 1, last_candidate + 1):
                if labels[candidate]:
                    right = candidate
                    break

            has_left = left is not None
            has_right = right is not None
            if not (has_left or has_right):
                continue

            if has_left:
                weight = alpha(1.0 * (frame - left + 1) / (L + 1))
                toward_left = lerp(weight, glb[frame, joint], glb[left, joint])
            if has_right:
                weight = alpha(1.0 * (right - frame + 1) / (L + 1))
                toward_right = lerp(weight, glb[frame, joint],
                                    glb[right, joint])

            if has_left and has_right:
                # Between two contact runs: mix the two one-sided blends
                # according to where the frame sits inside the gap.
                weight = alpha(1.0 * (frame - left + 1) / (right - left + 1))
                blended = lerp(weight, toward_right, toward_left)
            elif has_left:
                blended = toward_left
            else:
                blended = toward_right
            glb[frame, joint] = blended.clone()
    return glb