forked from shuaishiliu/SGCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
95 lines (74 loc) · 2.54 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import math
import torch
import numpy as np
def ade(predAll, targetAll, count_):
    """Average Displacement Error, averaged over sequences.

    For each sequence ``s`` only the first ``count_[s]`` agents are evaluated.
    The Euclidean (x, y) distance between prediction and target is taken at
    every time step, then averaged over all (agent, step) pairs of that
    sequence. The result is the unweighted mean of those per-sequence values.

    Args:
        predAll: sequence of predicted trajectories; each item is indexed as
            ``[time, agent, coord]`` and only coords 0 and 1 are used.
        targetAll: ground-truth trajectories with the same layout as
            ``predAll``.
        count_: number of leading agents to evaluate in each sequence.

    Returns:
        The ADE as a Python float.

    Raises:
        ValueError: if ``predAll`` is empty, or a sequence has no
            (agent, step) pairs to evaluate.
    """
    if len(predAll) == 0:
        raise ValueError("ade: predAll is empty")
    per_sequence = []
    for s in range(len(predAll)):
        n_agents = count_[s]
        # float64 so squared differences of float32 inputs keep full precision.
        pred = np.asarray(predAll[s], dtype=np.float64)[:, :n_agents, :2]
        target = np.asarray(targetAll[s], dtype=np.float64)[:, :n_agents, :2]
        dist = np.sqrt(np.sum((pred - target) ** 2, axis=-1))  # [time, agent]
        if dist.size == 0:
            raise ValueError(f"ade: sequence {s} has no agents or time steps")
        per_sequence.append(dist.mean())
    return float(np.mean(per_sequence))
def fde(predAll, targetAll, count_):
    """Final Displacement Error, averaged over sequences.

    For each sequence ``s`` only the first ``count_[s]`` agents are evaluated.
    The Euclidean (x, y) distance between prediction and target is taken at
    the last time step only, then averaged over those agents. The result is
    the unweighted mean of those per-sequence values.

    Args:
        predAll: sequence of predicted trajectories; each item is indexed as
            ``[time, agent, coord]`` and only coords 0 and 1 are used.
        targetAll: ground-truth trajectories with the same layout as
            ``predAll``.
        count_: number of leading agents to evaluate in each sequence.

    Returns:
        The FDE as a Python float.

    Raises:
        ValueError: if ``predAll`` is empty, or a sequence has no agents to
            evaluate.
        IndexError: if a sequence has no time steps.
    """
    if len(predAll) == 0:
        raise ValueError("fde: predAll is empty")
    per_sequence = []
    for s in range(len(predAll)):
        n_agents = count_[s]
        # Only the final time step contributes to FDE.
        pred = np.asarray(predAll[s], dtype=np.float64)[-1, :n_agents, :2]
        target = np.asarray(targetAll[s], dtype=np.float64)[-1, :n_agents, :2]
        dist = np.sqrt(np.sum((pred - target) ** 2, axis=-1))  # [agent]
        if dist.size == 0:
            raise ValueError(f"fde: sequence {s} has no agents")
        per_sequence.append(dist.mean())
    return float(np.mean(per_sequence))
def seq_to_nodes(seq_, max_nodes=88):
    """Scatter a trajectory array into a zero-padded ``[step, node, 2]`` array.

    After ``squeeze()``, ``seq_`` is indexed as ``[node, coord, step]``. Each
    node's coordinates at each step are copied into slot ``node`` of the
    output; slots beyond the number of nodes stay zero. Having more than
    ``max_nodes`` nodes raises ``IndexError``.

    Returns:
        The filled array after ``squeeze()``.
    """
    squeezed = seq_.squeeze()
    num_steps = squeezed.shape[2]
    padded = np.zeros((num_steps, max_nodes, 2))
    for step_idx in range(num_steps):
        # Rows of this slice are the per-node coordinates at step_idx.
        for node_idx, coords in enumerate(squeezed[:, :, step_idx]):
            padded[step_idx, node_idx, :] = coords
    return padded.squeeze()
def nodes_rel_to_nodes_abs(nodes, init_node):
    """Convert per-step relative displacements into absolute positions.

    ``nodes`` is indexed as ``[step, agent, coord]``. The absolute position of
    each agent at step ``s`` is the sum of its displacements over steps
    ``0..s`` plus its starting point ``init_node[agent]``.

    Args:
        nodes: relative displacements, 3-D ``[step, agent, coord]``.
        init_node: starting positions indexed as ``[agent, coord]``; rows
            beyond the number of agents in ``nodes`` are ignored.

    Returns:
        Absolute positions with dtype ``nodes.dtype``, passed through
        ``squeeze()`` so singleton axes are dropped.
    """
    nodes = np.asarray(nodes)
    num_agents = nodes.shape[1]
    # One cumulative sum over the step axis yields every prefix sum in a
    # single pass, instead of re-summing the whole prefix for each step.
    absolute = np.cumsum(nodes, axis=0) + np.asarray(init_node)[:num_agents]
    # Keep the output dtype equal to the input's (e.g. float32 stays float32).
    return absolute.astype(nodes.dtype, copy=False).squeeze()
def closer_to_zero(current, new_v):
    """Return True when ``new_v`` would replace ``current`` as the value nearest zero.

    Both values are ranked by the key ``(abs(value), value)``. On a tie in
    magnitude, the smaller (negative) value wins. When the two keys compare
    equal, ``current`` is kept.
    """
    ranked = ((abs(current), current), (abs(new_v), new_v))
    winner = min(ranked)[1]
    return bool(winner != current)
def bivariate_loss(V_pred, V_trgt):
    """Mean negative log-likelihood of targets under a bivariate Gaussian.

    ``V_pred`` carries five channels on its last axis: mean x, mean y, log
    std x, log std y, and a pre-tanh correlation. The exp/tanh maps these
    into valid ranges. ``V_trgt`` supplies the observed x, y in channels 0
    and 1. Both are indexed as ``[:, :, channel]``; the meaning of the first
    two axes is presumably (time, node), which should be confirmed against
    the caller.

    Each density is clamped to at least 1e-20 before the log. This caps the
    per-element loss at -log(1e-20), about 46.05, and gives zero gradient
    where the clamp is active.

    Returns:
        A scalar tensor: the mean NLL over all elements.
    """
    dx = V_trgt[:, :, 0] - V_pred[:, :, 0]
    dy = V_trgt[:, :, 1] - V_pred[:, :, 1]
    sigma_x = torch.exp(V_pred[:, :, 2])
    sigma_y = torch.exp(V_pred[:, :, 3])
    rho = torch.tanh(V_pred[:, :, 4])

    sigma_xy = sigma_x * sigma_y
    # Quadratic form in the exponent (before division by 2 * (1 - rho^2)).
    quad_form = (dx / sigma_x) ** 2 + (dy / sigma_y) ** 2 - 2 * ((rho * dx * dy) / sigma_xy)
    one_minus_rho_sq = 1 - rho ** 2

    numerator = torch.exp(-quad_form / (2 * one_minus_rho_sq))
    normalizer = 2 * np.pi * (sigma_xy * torch.sqrt(one_minus_rho_sq))
    density = numerator / normalizer

    nll = -torch.log(torch.clamp(density, min=1e-20))
    return torch.mean(nll)