import os
import pickle
import shutil

import imageio
import numpy as np
import torch
import torchvision.transforms as transforms
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset

import flow_transforms  # local module providing array-to-tensor flow transforms


def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"):
    torch.save(state, os.path.join(save_path, filename))
    if is_best:
        shutil.copyfile(
            os.path.join(save_path, filename),
            os.path.join(save_path, "model_best.pth.tar"),
        )
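
# Example usage (a sketch; `model`, `optimizer`, `epoch`, `val_loss`, `best_loss`
# and the "checkpoints" directory are assumptions, not defined in this module):
#
#   state = {"epoch": epoch, "state_dict": model.state_dict(),
#            "optimizer": optimizer.state_dict()}
#   save_checkpoint(state, is_best=val_loss < best_loss, save_path="checkpoints")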


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{:.3f} ({:.3f})".format(self.val, self.avg)


def flow2rgb(flow_map, max_value):
    """Map a 2-channel flow tensor to a 3xHxW RGB array in [0, 1] for visualization."""
    flow_map_np = flow_map.detach().cpu().numpy()
    _, h, w = flow_map_np.shape
    # Mark pixels with exactly zero flow as NaN so they render as background
    flow_map_np[:, (flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float("nan")
    rgb_map = np.ones((3, h, w)).astype(np.float32)
    if max_value is not None:
        normalized_flow_map = flow_map_np / max_value
    else:
        normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
    rgb_map[0] += normalized_flow_map[0]
    rgb_map[1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1])
    rgb_map[2] += normalized_flow_map[1]
    return rgb_map.clip(0, 1)
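
# Example usage (sketch; `flow` is assumed to be a 2xHxW flow tensor, e.g. a
# FlowNet output):
#
#   rgb = flow2rgb(flow, max_value=10)  # 3xHxW float array in [0, 1]
#   imageio.imwrite("flow.png", (255 * rgb.transpose(1, 2, 0)).astype(np.uint8))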


def read_gt_file_to_dict(file_path):
    """Parse a TUM-format groundtruth.txt file.

    Each line is: timestamp tx ty tz qx qy qz qw.
    Returns {timestamp: 12-tuple}, where each tuple is the absolute 6-DoF pose
    (xyz + 3 Euler angles) followed by the relative pose w.r.t. the previous
    entry, i.e. (1,12) = (1,6+6). The first timestamp is skipped because it
    has no predecessor.
    """
    result_dict = {}
    prevValues = []
    with open(file_path, 'r') as file:
        for line in file:
            # Skip comments and empty lines
            if line.startswith('#') or line.strip() == '':
                continue
            # Split the line into parts
            parts = line.strip().split()
            if len(parts) == 8:
                timestamp = parts[0]
                valuesList = list(map(float, parts[1:]))  # convert the rest to float
                values = valuesList[:3]  # translation xyz
                # Quaternion (qx, qy, qz, qw) -> Euler angles in degrees
                angs = R.from_quat(valuesList[3:7]).as_euler('xyz', degrees=True)
                for ang in angs:
                    values.append(ang)
                if len(prevValues) == 0:
                    # don't add an entry at the first timestamp
                    prevValues = values
                else:
                    diffVal = []
                    for i in range(len(values)):
                        diffVal.append(values[i] - prevValues[i])
                    result_dict[float(timestamp)] = tuple(values + diffVal)
                    # keep the current pose as the reference for the next diff
                    prevValues = values
    return result_dict
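
# Example (sketch; the sequence directory name is an assumption, any TUM RGB-D
# sequence works):
#
#   gt = read_gt_file_to_dict("rgbd_dataset_freiburg1_xyz/groundtruth.txt")
#   ts = sorted(gt)[0]
#   absolutePose, relativePose = gt[ts][:6], gt[ts][6:]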


def read_rgb_file_to_dict(file_path):
    """Parse a TUM-format rgb.txt file ("timestamp filename" per line)
    into {timestamp: filename}."""
    result_dict = {}
    with open(file_path, 'r') as file:
        for line in file:
            # Skip comments and empty lines
            if line.startswith('#') or line.strip() == '':
                continue
            # Split the line into timestamp and filename
            parts = line.strip().split()
            if len(parts) == 2:
                timestamp, filename = parts
                result_dict[float(timestamp)] = filename
    return result_dict


class CustomTUMDataset(Dataset):
    """data_list is a list of samples; each sample is a list of tuples:
    the first element is a pair of consecutive image paths,
    the second element is the pose of the images.
    """

    # Default preprocessing: scale to [0, 1], then subtract the mean RGB value
    DEFAULT_TRANSFORM = transforms.Compose(
        [
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]),
        ]
    )

    def __init__(self, data_list, device, transform=None):
        self.data_list = data_list
        self.transform = transform if transform is not None else self.DEFAULT_TRANSFORM
        self.device = device

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        listImgsPoses = []
        items = self.data_list[idx]
        for item in items:
            imagesPaths = item[0]
            pose = item[1]
            image0 = self.transform(imageio.v2.imread(imagesPaths[0])).to(self.device)
            image1 = self.transform(imageio.v2.imread(imagesPaths[1])).to(self.device)
            # Stack the two frames along the channel dimension (6 x H x W)
            image = torch.cat([image0, image1])
            # .to() returns a new tensor, so the result must be assigned
            pose = torch.tensor(pose, dtype=torch.float32).to(self.device)
            listImgsPoses.append((image, pose))
        return listImgsPoses
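
# Example usage (sketch; `datasetsTrain` would come from datasetsGet below):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   train_set = CustomTUMDataset(datasetsTrain, device)
#   sequence = train_set[0]    # list of (6xHxW stacked pair, 12-d pose) tuples
#   image, pose = sequence[0]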


def datasetsGet(trainPaths, testPaths):
    """Build the train/test dataset lists, caching them as pickles in the cwd."""
    datasetsTrain = []
    if not os.path.isfile('trainDatasets.pkl'):
        datasetsTrain = datasetsListGet(trainPaths)
        with open('trainDatasets.pkl', 'wb') as file:
            pickle.dump(datasetsTrain, file)
    else:
        print("loading train pickle file")
        with open('trainDatasets.pkl', 'rb') as file:
            datasetsTrain = pickle.load(file)
    datasetsTest = []
    if not os.path.isfile('testDatasets.pkl'):
        datasetsTest = datasetsListGet(testPaths)
        with open('testDatasets.pkl', 'wb') as file:
            pickle.dump(datasetsTest, file)
    else:
        print("loading test pickle file")
        with open('testDatasets.pkl', 'rb') as file:
            datasetsTest = pickle.load(file)
    return datasetsTrain, datasetsTest
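
# Note: the caches are keyed only by the fixed filenames trainDatasets.pkl /
# testDatasets.pkl, not by the paths they were built from, so delete both files
# whenever trainPaths or testPaths change; otherwise stale data is silently loaded.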


def averagePoseGet(pose1, pose2, distance1, distance2):
    """Interpolate between the two closest poses, weighting each pose by the
    other's distance so the nearer timestamp dominates."""
    newPose = []
    for i in range(len(pose1)):
        # weighted average: the weight of pose1 is distance2 / (distance1 + distance2)
        newPose.append((pose1[i] * distance2 + pose2[i] * distance1) / (distance1 + distance2))
        # alternative: just take the closest timestamp, i.e. pose1
        # newPose.append(pose1[i])
    return tuple(newPose)
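
# Worked example: with pose1 = (0.0,), pose2 = (1.0,), distance1 = 0.2,
# distance2 = 0.8, the result is (0.0 * 0.8 + 1.0 * 0.2) / 1.0 = 0.2, i.e.
# 80% of the weight goes to the nearer pose1.
#
#   averagePoseGet((0.0,), (1.0,), 0.2, 0.8)  # -> (0.2,)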


def are_paths_consistent(file_list):
    """Check that every image pair in file_list lives in the same directory,
    i.e. that the group does not straddle two sequences."""
    if not file_list:
        return False  # an empty list is treated as inconsistent
    # Extract the directory of the first file
    first_path = os.path.dirname(file_list[0][0][0])
    # Compare the directory of each file to the first path
    for files, _ in file_list:
        if os.path.dirname(files[0]) != first_path:
            return False
    return True


def datasetsListGet(paths, NUM_PAIRS_PER_INPUT=10):
    """Capture the images and the ground truth.

    input:  tuple(frame[k-1], frame[k])
    output: tuple(6-DoF pose: xyz + 3 angles, absolute and relative)

    The input to the neural net is a sequence of 11 frames, so the result is a
    list of lists, each holding NUM_PAIRS_PER_INPUT (input, gt) tuples.
    """
    datasetIO = []
    datasetListOfLists = []
    for path in paths:
        gtDict = read_gt_file_to_dict(path + 'groundtruth.txt')
        gtDictSorted = sorted(gtDict)
        rgbDict = read_rgb_file_to_dict(path + 'rgb.txt')
        rgbDictSorted = sorted(rgbDict)
        # Sanity checks: sorted() guarantees ascending order, so these should never fire
        for i in range(len(gtDictSorted) - 1):
            if gtDictSorted[i + 1] < gtDictSorted[i]:
                print('error in order of gt')
        for i in range(len(rgbDictSorted) - 1):
            if rgbDictSorted[i + 1] < rgbDictSorted[i]:
                print('error in order of rgb')
        # gt has more entries than rgb, so for each rgb timestamp we interpolate
        # a pose from the two closest gt timestamps
        for i in range(1, len(rgbDictSorted)):
            frameCurrTimeStamp = rgbDictSorted[i]
            framePrevTimeStamp = rgbDictSorted[i - 1]
            gtSortedRelativeTS = sorted(gtDictSorted, key=lambda x: abs(x - frameCurrTimeStamp))
            closest1, closest2 = gtSortedRelativeTS[:2]
            distance1 = abs(closest1 - frameCurrTimeStamp)
            distance2 = abs(closest2 - frameCurrTimeStamp)
            gtPose = averagePoseGet(gtDict[closest1], gtDict[closest2], distance1, distance2)
            inputPair = (path + rgbDict[framePrevTimeStamp], path + rgbDict[frameCurrTimeStamp])
            datasetIO.append((inputPair, gtPose))
    # Group consecutive pairs into chunks of NUM_PAIRS_PER_INPUT, dropping any
    # chunk whose frames span more than one sequence directory
    count = 0
    tempList = []
    for element in datasetIO:
        tempList.append(element)
        count += 1
        if count == NUM_PAIRS_PER_INPUT:
            if are_paths_consistent(tempList):
                datasetListOfLists.append(tempList)
            tempList = []
            count = 0
    return datasetListOfLists
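

# Minimal end-to-end sketch (the sequence directory names below are assumed
# examples of TUM RGB-D sequences containing groundtruth.txt and rgb.txt; note
# the trailing slashes, since paths are concatenated with filenames directly):
if __name__ == "__main__":
    trainPaths = ["rgbd_dataset_freiburg1_xyz/"]
    testPaths = ["rgbd_dataset_freiburg1_rpy/"]
    datasetsTrain, datasetsTest = datasetsGet(trainPaths, testPaths)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_set = CustomTUMDataset(datasetsTrain, device)
    print(len(train_set), "training sequences of",
          len(train_set[0]) if len(train_set) else 0, "image pairs each")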