Convert events into voxels #12
I suspect that your timestamp file is incorrect. If you have 26 images, the timestamp file should contain 26 lines, one timestamp per image. The generated events array of shape (N, 4) should have its timestamps (ts) range between the timestamp of the first frame and the timestamp of the last frame. Below is a snippet of my code for converting events into bidirectional voxels. Some of the helper functions are described in our DataPreparation.md. Hope it helps you. Thanks!

```python
import bisect

import numpy as np
import torch

def package_bidirectional_event_voxels(self, x, y, t, p, timestamp_list,
                                       backward, bins, sensor_size,
                                       h5_name, error_txt):
    """
    params:
        x: ndarray, x-position of events
        y: ndarray, y-position of events
        t: ndarray, timestamp of events
        p: ndarray, polarity of events
        timestamp_list: list, frame timestamps used to split the events
        backward: bool, whether to build backward (time-reversed) voxels
        bins: int, number of temporal bins per voxel
        sensor_size: (H, W) of the event sensor
        h5_name: name of the HDF5 file (used in error messages)
        error_txt: path of a text file where data errors are logged
    returns:
        no return.
    """
    # Step 1: convert data types.
    assert x.shape == y.shape == t.shape == p.shape
    x = torch.from_numpy(x.astype(np.int16))
    y = torch.from_numpy(y.astype(np.int16))
    t = torch.from_numpy(t.astype(np.float64))
    p = torch.from_numpy(p.astype(np.int16))

    # Step 2: split the events into inter-frame intervals by timestamp.
    temp = t.numpy().tolist()
    output = [
        temp[bisect.bisect_left(temp, timestamp_list[i]):
             bisect.bisect_left(temp, timestamp_list[i + 1])]
        for i in range(len(timestamp_list) - 1)
    ]

    # Debug: check the data for errors.
    assert len(output) == len(timestamp_list) - 1, \
        f"len(output) is {len(output)}, but len(timestamp_list) is {len(timestamp_list)}"
    sum_output = []
    total = 0  # running event count ("sum" would shadow the builtin)
    for i in range(len(output)):
        if len(output[i]) == 0:
            raise ValueError(f"{h5_name} len(output[{i}]) == 0")
        elif len(output[i]) == 1:
            raise ValueError(f"{h5_name} len(output[{i}]) == 1")
        total += len(output[i])
        sum_output.append(total)
    assert len(sum_output) == len(output)

    # Step 3: after checking the data, build one voxel per interval.
    start_idx = 0
    for voxel_idx in range(len(timestamp_list) - 1):
        if len(output[voxel_idx]) <= 1:
            print(f'{h5_name} len(output[{voxel_idx}]): ', len(output[voxel_idx]))
            with open(error_txt, 'a+') as f:
                f.write(h5_name + '\n')
            return
        end_idx = start_idx + len(output[voxel_idx])
        if end_idx > len(t):
            with open(error_txt, 'a+') as f:
                f.write(f"{h5_name} voxel_idx: {voxel_idx}, start_idx {start_idx} "
                        f"end_idx {end_idx} exceeds bound." + '\n')
            print(f"{h5_name} voxel_idx: {voxel_idx}, start_idx {start_idx} "
                  f"end_idx {end_idx} exceeds bound len(t) {len(t)}.")
            return
        xs = x[start_idx:end_idx]
        ys = y[start_idx:end_idx]
        ts = t[start_idx:end_idx]
        ps = p[start_idx:end_idx]
        # Compare the shape (not the tensor itself) against torch.Size.
        if ts.shape in (torch.Size([]), torch.Size([0]), torch.Size([1])):
            with open(error_txt, 'a+') as f:
                f.write(f"{h5_name} len(output[{voxel_idx}]) backward {backward} "
                        f"start_idx {start_idx} end_idx {end_idx} is an error! "
                        f"Please check the data." + '\n')
            print(f"{h5_name} len(output[{voxel_idx}]) backward {backward} "
                  f"start_idx {start_idx} end_idx {end_idx} is an error! "
                  f"Please check the data.")
            return
        if backward:
            # Reverse the events in time: flip the order, mirror the
            # timestamps inside the interval, and invert the polarity.
            t_start = timestamp_list[voxel_idx]
            t_end = timestamp_list[voxel_idx + 1]
            xs = torch.flip(xs, dims=[0])
            ys = torch.flip(ys, dims=[0])
            ts = torch.flip(t_end - ts + t_start, dims=[0])
            ps = torch.flip(-ps, dims=[0])
        voxel = events_to_voxel_torch(xs, ys, ts, ps, bins,
                                      device=None, sensor_size=sensor_size)
        normed_voxel = voxel_normalization(voxel)
        np_voxel = normed_voxel.numpy()
        group = "voxels_b" if backward else "voxels_f"
        self.events_file.create_dataset("{}/{:06d}".format(group, voxel_idx),
                                        data=np_voxel,
                                        dtype=np.dtype(np.float64),
                                        compression="gzip")
        start_idx = end_idx
```
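For reference, `events_to_voxel_torch` and `voxel_normalization` come from the helpers described in DataPreparation.md. The following is only a minimal numpy sketch of what such a voxelization does (my assumption about the idea, not the repository's actual implementation): accumulate event polarities into a `(bins, H, W)` grid with bilinear weights along the time axis.

```python
import numpy as np

def events_to_voxel_sketch(xs, ys, ts, ps, bins, sensor_size):
    """Hypothetical stand-in for events_to_voxel_torch: bilinear
    accumulation of event polarities into a (bins, H, W) grid."""
    H, W = sensor_size
    voxel = np.zeros((bins, H, W), dtype=np.float64)
    ts = np.asarray(ts, dtype=np.float64)
    span = max(ts[-1] - ts[0], 1e-9)          # avoid division by zero
    t_norm = (ts - ts[0]) / span * (bins - 1)  # position on the bin axis
    left = np.clip(np.floor(t_norm).astype(int), 0, bins - 1)
    right = np.clip(left + 1, 0, bins - 1)
    w_right = t_norm - left                    # bilinear temporal weights
    w_left = 1.0 - w_right
    # Unbuffered scatter-add so repeated (bin, y, x) indices accumulate.
    np.add.at(voxel, (left, ys, xs), ps * w_left)
    np.add.at(voxel, (right, ys, xs), ps * w_right)
    return voxel
```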
Thank you. The issue was with the timestamp file. Fixed now.
@DachunKai Hi, when I was building voxels_f with this code, I ran into an error: ValueError: 000.h5 len(output[0]) == 0 or ValueError: 000.h5 len(output[0]) == 1. What does this error mean? Does each inter-frame interval need at least two events? My code is:
@hongsixin Yes. In our paper, Section 3.1, Equation (1): if there is only one event between two frames, then in Equation (1)
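As an illustration of why an interval needs at least two events: a temporal normalization in the spirit of Equation (1) divides by the timestamp span of the interval's events, which is zero for a single event. This is a hedged sketch that mirrors the idea, not the paper's exact formula.

```python
import numpy as np

def normalize_event_times(ts):
    """Map event timestamps in one interval to [0, 1].
    Degenerates (division by zero) when the interval has < 2
    distinct timestamps, hence the >= 2 events requirement."""
    ts = np.asarray(ts, dtype=np.float64)
    span = ts.max() - ts.min()
    if span == 0.0:
        raise ValueError("interval needs >= 2 events with distinct timestamps")
    return (ts - ts.min()) / span
```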
Thank you for sharing the data preparation details. I have created events for an image dataset, and the events.h5 file looks as follows:
events/ps/ : (7013772,) bool
events/ts/ : (7013772,) float64
events/xs/ : (7013772,) int16
events/ys/ : (7013772,) int16
Could you please share the code snippet that can convert these events into voxels in the format specified below?
voxels_b/000000/ : (5, 180, 320) float64
voxels_b/000001/ : (5, 180, 320) float64
voxels_b/000002/ : (5, 180, 320) float64
voxels_b/000003/ : (5, 180, 320) float64
voxels_b/000004/ : (5, 180, 320) float64
I tried using [events_contrast_maximization] but ended up generating voxels in a different, incorrect format. My dataset contains 26 images, but the voxel file has only 5 entries, and the tensor shape is not [bins, H, W]:
voxels_f/000000/ : (720, 1280) float64
voxels_f/000001/ : (720, 1280) float64
voxels_f/000002/ : (720, 1280) float64
voxels_f/000003/ : (720, 1280) float64
voxels_f/000004/ : (720, 1280) float64
The generated voxel file does not include all 26 images. Please assist. Thank you.
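With 26 frame timestamps there should be 25 inter-frame intervals, hence 25 forward voxels. A quick check I ran on the events/ts array against the per-frame timestamps (hypothetical helper of my own, not from the repo) counts the events landing in each interval; the packaging code rejects intervals with fewer than two events.

```python
import numpy as np

def events_per_interval(event_ts, frame_ts):
    """Count events in each inter-frame interval; returns
    len(frame_ts) - 1 counts, one per expected voxel."""
    event_ts = np.asarray(event_ts, dtype=np.float64)
    frame_ts = np.asarray(frame_ts, dtype=np.float64)
    counts, _ = np.histogram(event_ts, bins=frame_ts)
    return counts
```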