Ideally, we want to predict shape parameters by minimizing some divergence measure between a ground truth (GT) mask/SDF and the predicted SDF. In many ways, this is identical to predicting shape parameters directly, except that the forward and backward passes and the loss propagate through the shape decoder/embedder.
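For concreteness, one common choice of divergence for SDF regression is the clamped L1 loss from DeepSDF-style training. The sketch below is only an assumption about what such a loss could look like; the actual loss used by the training code may differ.

import torch

def clamped_l1_sdf_loss(pred_sdf, gt_sdf, clamp=0.1):
    # Clamp both predicted and GT SDF values so the loss focuses on the
    # region near the surface, then take the mean absolute difference.
    pred = torch.clamp(pred_sdf, -clamp, clamp)
    gt = torch.clamp(gt_sdf, -clamp, clamp)
    return torch.mean(torch.abs(pred - gt))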
We'll start with a translation-only prediction model.
Because we work with full 3D volumes, we probably can't use the original resolution for all steps. For estimating translation, which has to operate on the full volume, we can usually get away with a coarser resolution, e.g., a spacing of [4,4,4].
voxel_convert.resample_cts(in_folder, out_folder, out_spacing=[4,4,4])
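For intuition, the resampling step is roughly equivalent to the following SimpleITK sketch (illustrative only; voxel_convert.resample_cts handles the folder iteration and any extra details itself):

import SimpleITK as sitk

def resample_to_spacing(img, out_spacing=(4.0, 4.0, 4.0)):
    # Compute the output size so the physical extent of the volume is preserved.
    in_spacing = img.GetSpacing()
    in_size = img.GetSize()
    out_size = [int(round(sz * s / o)) for sz, s, o in zip(in_size, in_spacing, out_spacing)]
    return sitk.Resample(img, out_size, sitk.Transform(), sitk.sitkLinear,
                         img.GetOrigin(), out_spacing, img.GetDirection(),
                         -1024, img.GetPixelID())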
resample_cts also prints out the max size of the resampled volumes. To make constant-sized volumes, each volume should be padded to this max size. Here, we assume CTs, so we use a constant value of -1024 to pad the volumes.
voxel_convert.pad_cts(in_folder, out_folder, out_size=MAX_SIZE, constant_values=-1024)
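A minimal sketch of what this padding amounts to, assuming volumes are centered inside the padded output (the exact centering behavior of voxel_convert.pad_cts may differ):

import numpy as np

def pad_to_size(vol, out_size, constant_value=-1024):
    # Pad each axis symmetrically up to the target size with a constant value.
    pads = []
    for dim, target in zip(vol.shape, out_size):
        total = max(target - dim, 0)
        pads.append((total // 2, total - total // 2))
    return np.pad(vol, pads, mode="constant", constant_values=constant_value)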
To calculate a loss through the shape decoder, we must have GT coordinate/SDF pairs. To create them, we follow a workflow similar to the one used to create the resampled CTs.
First we must pad the original masks to ensure that we don't have any areas where they are cut off (similar to what we had to do when converting them to meshes).
voxel_convert.pad_cts(orig_mask_folder, padded_mask_folder, out_size=[1000,1000,1000], constant_values=0)
Here we pick an output size that will definitely give us enough space. This is overkill, and a smaller pad size would probably work just as well. With the padded masks created, we can then compute SDF versions of the masks. We also want them resampled to the same space as the resampled CTs we use as input to our prediction model.
voxel_convert.create_resample_sdfs(padded_mask_folder, sdf_folder, resampled_padded_ct_folder, anchor_mesh)
This function will compute the SDFs (distance values based on world coordinate spacing) and normalize them using the anchor mesh. This anchor mesh will typically be the first mesh in your folder of meshes, i.e., liver_0.obj in Meshes_Simplify. This function will also resample each resulting SDF to the same coordinate space as the corresponding CT found in resampled_padded_ct_folder, meaning it will do both the padding and the resampling.
Now with appropriate SDF volumes created, we can sample the SDF values inside each volume to create a list of SDF value and coordinate pairs that are compatible with the DeepSDF decoder model, i.e., a set of .npz files that follow the format of the files we created in CONVERT.md. Note, the files used to train the shape embedding model were in canonical [-1, 1] space, but here we create coordinates specified in pixel space. The corresponding SDF values are still normalized to canonical space because we normalize them using the anchor mesh.
voxel_convert.create_sdf_voxel_samples(sdf_folder, sdf_sample_folder)
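To sanity check the output, you can inspect one of the generated files. The sketch below assumes the DeepSDF-style layout with "pos" and "neg" arrays of [x, y, z, sdf] rows and uses a hypothetical file name:

import numpy as np

data = np.load("liver_0.npz")         # hypothetical file name
pos, neg = data["pos"], data["neg"]   # samples with positive / negative SDF
print(pos.shape)                      # (N, 4): pixel-space x, y, z plus a normalized SDF value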
We can use the pycpd package to estimate what translation, rotation, and scale is required to align the mean shape to each individual shape in pixel space. This doesn't give us perfect GT, because pycpd only estimates isotropic scale and is an approximation that requires inferring correspondences between shapes, but it at least gives us an idea of what we should aim for in our own pose estimation process.
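For reference, pycpd's rigid registration looks roughly like this (a sketch with placeholder point clouds; in practice these would be vertices of the mean mesh and of an individual mesh, both already in pixel space):

import numpy as np
from pycpd import RigidRegistration

# Placeholder point clouds standing in for mesh vertices.
source_points = np.random.rand(500, 3) * 100.0   # mean mesh vertices
target_points = np.random.rand(500, 3) * 100.0   # individual mesh vertices

reg = RigidRegistration(X=target_points, Y=source_points)
aligned_points, (s, R, t) = reg.register()       # isotropic scale, rotation, translation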
First, create a mean mesh, as discussed in DEEPSDF.md. Then simplify it by a factor of 0.01 using the Fast Quadric method. The resulting mean mesh will be normalized to have vertices in the range of [-1, 1] and centered at [0,0,0] in some canonical space.
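As a sketch, the simplification could be done with Open3D's quadric decimation as a stand-in for the Fast Quadric method (paths are hypothetical):

import open3d as o3d

mesh = o3d.io.read_triangle_mesh("mean_mesh.obj")                  # hypothetical path
target = max(1, int(0.01 * len(mesh.triangles)))                   # keep ~1% of triangles
simplified = mesh.simplify_quadric_decimation(target_number_of_triangles=target)
o3d.io.write_triangle_mesh("mean_mesh_simplified.obj", simplified)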
align_shape.create_scale_translation_json(in_folder, im_folder, mean_mesh_file, scale_factor)
This function will project all meshes in in_folder to pixel space corresponding to the matching CTs in im_folder. Here in_folder should be the folder of your super-simplified meshes from CONVERT.md and im_folder should be the resized and padded CTs from above. mean_mesh_file should be the simplified mean mesh .obj file. scale_factor should be the scale written to scale.txt in the SDF generation step in CONVERT.md.
After projecting to pixel space, the function will align the mean mesh to each individual mesh using pycpd. It will then save a json file in im_folder that contains the estimated translation/rotation/scale in pixel space, which is the space we need to operate in when we perform pose estimation. These rough GT values can be useful for debugging.
The final step is to create a json file that we can use for training. Here we'll make sure the training json also includes the "ground truth" translations, scales, and rotations calculated with pycpd using the align_shape.create_scale_translation_json(in_folder, im_folder, mean_mesh_file) function seen in DIRECT_PREDICT. We do this for debugging and training purposes.
voxel_convert.create_sample_jsonv2(resampled_padded_ct_folder, sdf_sample_folder, json_out_path, gt_scale_json_path)
The resulting training json will look like this:
[
  {
    "path": "path_to_SDF_npz",
    "im": "path_to_resampled_padded_CT",
    "t": [t_x, t_y, t_z],  # GT translation
    "s": s,                # GT scale
    "R": [[...]]           # GT rotation matrix
  },
  ...
]
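A quick way to sanity check the generated file (field names as in the example above; the path is hypothetical):

import json

with open("train_samples.json") as f:
    samples = json.load(f)

first = samples[0]
print(first["path"], first["im"])     # paths to the SDF .npz and the resampled CT
print(first["t"], first["s"])         # rough GT translation and scale from pycpd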
We would like to incorporate the concept of "state" into the network, so that the network knows what the initial guess is and can then refine that guess. Here we use an SDF volume of the mean mesh to represent the guess, i.e., the current pose state for rigid pose estimation. We use the center of the image as the initial guess and start with the scale found in scale.txt from the SDF generation step in CONVERT.md.
ins.infer_mean_aligned_sdf(shape_embedding_ckpt, shape_embedding_config, out_path, example_CT, scale=scale_factor)
Here example_CT can be any resampled and padded CT, which allows the function to create a mean SDF sampled in the same coordinate space. We will use this mean SDF file as an extra channel for the prediction step.
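Conceptually, the mean SDF simply becomes a second input channel alongside the CT. A minimal sketch with assumed tensor shapes (the training code wires this up through its own transforms):

import torch

ct = torch.randn(1, 1, 96, 96, 96)        # resampled + padded CT, (B, 1, D, H, W)
mean_sdf = torch.randn(1, 1, 96, 96, 96)  # mean-shape SDF on the same voxel grid
x = torch.cat([ct, mean_sdf], dim=1)      # (B, 2, D, H, W) input to the pose encoder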
Currently only translation has been well-tested. The basic idea is this:
- We ask the encoder to predict a translation of the sampled GT SDF pixel coordinates to center them. An affine transform is then applied to map them into the canonical shape embedding space ([-1, 1]); see the sketch after this list.
- An extra channel is concatenated to the image, representing the initial guess. This is the mean SDF volume.
- We use the mean embedding in the shape decoder and minimize the divergence between the shape decoder's SDF output and the GT SDF values.
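A minimal sketch of the coordinate handling in the first bullet, with assumed conventions (the exact affine used by the code may differ, e.g., in how the scale factor enters):

import torch

def pixel_to_canonical(coords_px, pred_t, volume_shape, scale_factor):
    # coords_px: (N, 3) GT SDF sample coordinates in pixel space
    # pred_t:    (3,) translation predicted by the encoder
    centered = coords_px - pred_t                          # center the samples
    half_extent = torch.tensor(volume_shape, dtype=coords_px.dtype) / 2.0
    return centered / (half_extent * scale_factor)         # map roughly into [-1, 1]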
Several augmentations are applied:
- In addition to intensity-based augmentations, a random affine transform is also applied to the image only, before it is concatenated to the initial-guess mean SDF channel. Note, any affine transformation of the image must be accounted for in the coordinate values of the GT SDF samples in the .npz files. An additional transform (ApplyAffineToPoints) ensures that this is done; see the sketch after this list.
- In addition to augmenting the image, the initial guess should also be randomized, so the network can learn how to refine different initial guesses of the same image. The AddMaskChannel transform takes care of this. It accepts a universal starting guess applied to all images and also an image-specific starting guess. In addition, it will optionally randomly jitter the initial guess. After computing an initial guess, it applies any corresponding transforms to the mean SDF channel before concatenating it to the image. Thus, changes in the initial state are represented by changes in the mean SDF channel. The initial guess is also applied to the GT SDF coordinates, so any initial guess is accounted for both in the mean SDF channel and in how the SDF coordinates are transformed.
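The bookkeeping in the first bullet amounts to transforming the GT sample coordinates with the same affine that was applied to the image. A hypothetical helper (not the repo's ApplyAffineToPoints) might look like:

import numpy as np

def apply_affine_to_points(points_px, affine_4x4):
    # points_px: (N, 3) coordinates; affine_4x4: the same 4x4 matrix applied to the image
    homog = np.concatenate([points_px, np.ones((points_px.shape[0], 1))], axis=1)
    return (homog @ affine_4x4.T)[:, :3]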
To train, run
python train_episodic --im_root CT_FOLDER --sdf_sample_root SDF_NPZ_FOLDER --yaml_file configs/config_predict_48.yml --save_path SAVE_PATH --json_list JSON_LIST_PATH --embed_model_path DEEPSDF_CKPT --embed_yaml_file config.yml --mean_sdf_file MEAN_SDF_NIFTY --train_json_list TRAIN_JSON --val_json_list VAL_JSON --scale_factor SCALE_FACTOR
where CT_FOLDER corresponds to the resampled and padded CTs, SDF_NPZ_FOLDER corresponds to the SDF .npz files extracted from the SDFs resampled to the same resolution as the CTs, and SCALE_FACTOR is the scale value in scale.txt.
You can run the trained model on the training and validation sets to obtain initial translation-only guesses. These will then be ingested by the scale+translation encoder.
predict_params.create_post_training_affines(model_encoder_path, im_root, json_list_path, mean_sdf_file, output_json_path, steps=10, do_scale=False, do_rotate=False, do_pca=False)
You can do this for both the training and validation jsons, and the script will predict a translation for each image and save the result in output_json_path. Make sure to set the step count to an appropriate number for translation (during training this is set to 10 by default, and you can use the same number at inference).
To generate an SDF volume in inference you can run:
predict_params.render_sdf(model_encoder_path_trans, model_decoder_path, decoder_config, volume_file, mean_sdf_file, save_path, scale, steps=10)
You must provide the pose encoder checkpoint for translation that you just trained, the path to the shape embedding decoder and its config, the CT volume file, and the mean shape SDF that you used in training. The scale factor must also be provided. Note, this will only rigidly translate the mean shape to its optimal location, so accuracy will probably not be very good. To obtain a mask, you can threshold the SDF.
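For example, assuming render_sdf wrote the SDF as a NIfTI volume with negative values inside the organ, a mask could be obtained like this (paths are hypothetical):

import SimpleITK as sitk

sdf = sitk.ReadImage("rendered_sdf.nii.gz")
mask = sitk.BinaryThreshold(sdf, lowerThreshold=-1e6, upperThreshold=0.0,
                            insideValue=1, outsideValue=0)
sitk.WriteImage(mask, "predicted_mask.nii.gz")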