forked from ashleve/lightning-hydra-template
Merge pull request #41 from cameron-a-johnson/dev/simple-tcn-input-vector
Adding simple 'locsAndConfs' TCN input vector
Showing 5 changed files with 311 additions and 18 deletions.
@@ -0,0 +1,121 @@
# @package _global_

# To execute this experiment run:
#   python train.py experiment=example
task: "m2"
# feature_version: 6
topic: "medical"

defaults:
  - override /data: ptg
  - override /model: ptg
  - override /callbacks: default
  - override /trainer: gpu
  - override /paths: default
  #- override /logger: aim
  - override /logger: csv

# All parameters below will be merged with parameters from the default
# configurations set above. This allows you to overwrite only specified
# parameters.

# Change this name to something descriptive and unique for this experiment.
# This will differentiate the run logs and outputs from other experiments
# that may have been run under the same configuration.
# Setting this value influences:
#   - the name of the directory under `${paths.root_dir}/logs/` in which
#     training run files are stored.
# Default is "train", set in the "configs/train.yaml" file.
#task_name:

# Simply provide a checkpoint path to resume training.
#ckpt_path: null

tags: ["m2", "ms_tcn", "debug"]

seed: 12345

trainer:
  min_epochs: 50
  max_epochs: 500
  log_every_n_steps: 1

model:
  compile: false
  net:
    # Length of the feature vector for a single frame.
    # Currently derived from the feature version and other hyperparameters.
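    # With the LocsAndConfs vectorizer configured below (22 pose joints with
    # per-joint confidences plus one overall pose score, 7 object classes,
    # top_k=1), this works out to 3*22 + 1 + 5*1*7 = 102.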
    dim: 102
    num_classes: 9

data:
  coco_train_activities: "${paths.coco_file_root}/TRAIN-activity_truth.coco.json"
  coco_train_objects: "${paths.coco_file_root}/TRAIN-object_detections.coco.json"
  coco_train_poses: "${paths.coco_file_root}/TRAIN-pose_estimates.coco.json"

  coco_validation_activities: "${paths.coco_file_root}/VALIDATION-activity_truth.coco.json"
  coco_validation_objects: "${paths.coco_file_root}/VALIDATION-object_detections.coco.json"
  coco_validation_poses: "${paths.coco_file_root}/VALIDATION-pose_estimates.coco.json"

  coco_test_activities: "${paths.coco_file_root}/TEST-activity_truth.coco.json"
  coco_test_objects: "${paths.coco_file_root}/TEST-object_detections.coco.json"
  coco_test_poses: "${paths.coco_file_root}/TEST-pose_estimates.coco.json"

  batch_size: 16384
  num_workers: 16
  target_framerate: 15  # BBN Hololens2 framerate
  epoch_length: 200000

  train_dataset:
    window_size: 25
    vectorizer:
      _target_: tcn_hpl.data.vectorize.locs_and_confs.LocsAndConfs
      top_k: 1
      num_classes: 7
      use_joint_confs: True
      use_pixel_norm: True
      use_joint_obj_offsets: False
      background_idx: 0
    transform:
      transforms: []  # no transforms
        # - _target_: tcn_hpl.data.components.augmentations.MoveCenterPts
        #   hand_dist_delta: 0.05
        #   obj_dist_delta: 0.05
        #   joint_dist_delta: 0.025
        #   im_w: 1280
        #   im_h: 720
        #   num_obj_classes: 42
        #   feat_version: 2
        #   top_k_objects: 1
        # - _target_: tcn_hpl.data.components.augmentations.NormalizePixelPts
        #   im_w: 1280
        #   im_h: 720
        #   num_obj_classes: 42
        #   feat_version: 2
        #   top_k_objects: 1
  val_dataset:
    transform:
      transforms: []  # no transforms
        # - _target_: tcn_hpl.data.components.augmentations.NormalizePixelPts
        #   im_w: 1280
        #   im_h: 720
        #   num_obj_classes: 42
        #   feat_version: 2
        #   top_k_objects: 1
  # The test dataset is usually configured the same as val, unless there is
  # some different set of transforms that should be used during
  # test/prediction.

paths:
  # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL/"
  # root_dir: "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/training_root"
  root_dir: "/home/local/KHQ/cameron.johnson/code/TCN_HPL/tcn_hpl/train-TCN-M2_bbn_hololens/training_root"

  # Convenience variable for where your train/val/test split COCO file
  # datasets are stored.
  # coco_file_root: "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens"
  coco_file_root: "/home/local/KHQ/cameron.johnson/code/TCN_HPL/train-TCN-M2_bbn_hololens"

#exp_name: "tcn_training_revive"
#logger:
#  aim:
#    experiment: ${task_name}
#    capture_terminal_logs: true
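For reference, a config like this is launched through the template's Hydra entrypoint, as the comment at the top of the file indicates. A minimal sketch, assuming this file is saved under configs/experiment/ with a hypothetical name like m2_locs_and_confs.yaml:

python train.py experiment=m2_locs_and_confs
# Hydra also allows overriding individual values from the command line, e.g.:
python train.py experiment=m2_locs_and_confs trainer.max_epochs=100 data.batch_size=4096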
@@ -0,0 +1,161 @@
import numpy as np
from numpy import typing as npt

from tcn_hpl.data.vectorize._interface import Vectorize, FrameData

NUM_POSE_JOINTS = 22


class LocsAndConfs(Vectorize):
    """
    Previous manual approach to vectorization.

    Arguments:
        top_k: The number of top per-class examples to use in vector
            construction.
        num_classes: The number of classes in the object detector.
        use_joint_confs: Use the confidence of each pose joint.
            (Changes the length of the input vector, which needs to be
            manually updated if this flag changes.)
        use_pixel_norm: Normalize pixel coordinates by dividing by the
            frame width and height, respectively. Normalized values are
            between 0 and 1. Does not change the input vector length.
        use_joint_obj_offsets: Add abs(X and Y offsets) between joints and
            each object.
            (Changes the length of the input vector, which needs to be
            manually updated if this flag changes.)
        background_idx: Label index of the background class in the object
            detector.
    """

    def __init__(
        self,
        top_k: int = 1,
        num_classes: int = 7,
        use_joint_confs: bool = True,
        use_pixel_norm: bool = True,
        use_joint_obj_offsets: bool = False,
        background_idx: int = 0,
    ):
        super().__init__()

        self._top_k = top_k
        self._num_classes = num_classes
        self._use_joint_confs = use_joint_confs
        self._use_pixel_norm = use_pixel_norm
        self._use_joint_obj_offsets = use_joint_obj_offsets
        self._background_idx = background_idx

    @staticmethod
    def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind):
        """
        Find all instances of a label index in the object detections, sort
        them by score, and return the indices of the top k.

        Inputs:
            f_dets: Object detections for a frame, with parallel `labels`
                and `scores` sequences.
            k: Maximum number of detection indices to return.
            label_ind: Label index of the object class to filter for.

        Returns the top-k detection indices for the class, highest score
        first, or None if the class was not detected. For example,
        labels=[0, 2, 2] with scores=[0.9, 0.3, 0.8], k=1, and label_ind=2
        yields [2].
        """
        labels = f_dets.labels
        scores = f_dets.scores
        # Get the indices of all detections of this object type.
        filtered_idxs = [i for i, e in enumerate(labels) if e == label_ind]
        if not filtered_idxs:
            return None
        filtered_scores = [scores[i] for i in filtered_idxs]
        # Sort the filtered indices by descending score value.
        sorted_inds = [
            i for _, i in sorted(zip(filtered_scores, filtered_idxs), reverse=True)
        ]
        return sorted_inds[:k]

    @staticmethod
    def append_vector(frame_feat, i, number):
        # Write `number` at index `i` of the feature vector and return the
        # advanced write index.
        frame_feat[i] = number
        return frame_feat, i + 1

    def determine_vector_length(self, data: FrameData) -> int:
        #########################
        # Feature vector
        #########################
        # Length: pose confs * 22, pose X's * 22, pose Y's * 22,
        #         obj confs * num_objects (7 for M2),
        #         obj X * num_objects (7 for M2),
        #         obj Y * num_objects (7 for M2),
        #         obj W * num_objects (7 for M2),
        #         obj H * num_objects (7 for M2),
        #         casualty conf * 1
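        # With the defaults (use_joint_confs=True, 22 joints, top_k=1,
        # num_classes=7) this totals 3*22 + 1 + 5*1*7 = 102, matching
        # `model.net.dim: 102` in the experiment config.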
        vector_length = 0
        # Joint confidences
        if self._use_joint_confs:
            vector_length += NUM_POSE_JOINTS
        # X and Y for each joint
        vector_length += 2 * NUM_POSE_JOINTS
        # Overall pose (casualty) confidence score
        vector_length += 1
        # [Conf, X, Y, W, H] for k instances of each object class.
        vector_length += 5 * self._top_k * self._num_classes
        return vector_length

    def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]:
        vector_len = self.determine_vector_length(data)
        frame_feat = np.zeros(vector_len, dtype=np.float32)
        # TODO: Instead of carrying around this vector_ind, we should
        # directly compute the offset of each feature we add to the TCN
        # input vector. That would be much easier to debug.
        vector_ind = 0
        if self._use_pixel_norm:
            W = data.size[0]
            H = data.size[1]
        else:
            W = 1
            H = 1
        f_dets = data.object_detections

        # Loop through all classes: populate obj conf, obj X, obj Y, obj W,
        # obj H.
        # Assumption: class labels are [0, 1, 2, ..., num_classes-1].
        for obj_ind in range(self._num_classes):
            top_k_idxs = self.get_top_k_indexes_of_one_obj_type(
                f_dets, self._top_k, obj_ind
            )
            # None is returned if there were no detections for this class.
            top_k_idxs = top_k_idxs or []
            for idx in top_k_idxs:
                # Conf
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.scores[idx])
                # X
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][0] / W)
                # Y
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][1] / H)
                # W
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][2] / W)
                # H
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][3] / H)
            # Zero-pad [Conf, X, Y, W, H] for any of the top-k instances of
            # this class that were not detected.
            for _ in range((self._top_k - len(top_k_idxs)) * 5):
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0)

        f_poses = data.poses
        if f_poses:
            # Use the most confident body detection.
            confident_pose_idx = np.argmax(f_poses.scores)
            num_joints = f_poses.joint_positions.shape[1]
            frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.scores[confident_pose_idx])

            for joint_ind in range(num_joints):
                # Conf
                if self._use_joint_confs:
                    frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_scores[confident_pose_idx][joint_ind])
                # X
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][0] / W)
                # Y
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][1] / H)
        else:
            # No pose for this frame: zero-pad the pose score plus the
            # per-joint features.
            if self._use_joint_confs:
                rows_per_joint = 3
            else:
                rows_per_joint = 2
            for _ in range(NUM_POSE_JOINTS * rows_per_joint + 1):
                frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0)

        assert vector_ind == vector_len

        return frame_feat
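A minimal usage sketch, under assumptions: the real FrameData, detection, and pose types live in tcn_hpl.data.vectorize._interface, so the SimpleNamespace stand-ins below only mimic the attributes this class actually reads (labels, scores, boxes in [X, Y, W, H] order, size, joint_positions, joint_scores), and it is assumed the Vectorize interface requires nothing beyond what is shown above.

import numpy as np
from types import SimpleNamespace

from tcn_hpl.data.vectorize.locs_and_confs import LocsAndConfs

# Two detections: one of class 0 and one of class 2; classes 1 and 3-6 absent.
dets = SimpleNamespace(
    labels=[0, 2],
    scores=[0.9, 0.75],
    boxes=[[100, 200, 50, 80], [400, 300, 60, 40]],
)
# One pose with 22 joints at placeholder positions.
poses = SimpleNamespace(
    scores=np.array([0.8]),
    joint_positions=np.zeros((1, 22, 2)),
    joint_scores=np.ones((1, 22)),
)
frame = SimpleNamespace(size=(1280, 720), object_detections=dets, poses=poses)

vec = LocsAndConfs(top_k=1, num_classes=7).vectorize(frame)
print(vec.shape)  # (102,) with these defaults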