Better expose and document batch prediction from dataloaders #849
Labels: `API` (small improvements to the readability and usability of the Python API), `good first issue` (good for newcomers)
The current API prioritizes simple functions over performance. When many images are already yielded from a dataloader, it's quite annoying to either save them to file just to use main.predict_file, or to manipulate the dataloader to get the images preprocessed as expected. We could loop through them one at a time and call predict_image, but that is really wasteful given modern GPU memory. Not great.

Instead, a batch prediction mechanism is already sitting there in the codebase, but it isn't obvious.
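A minimal sketch of the pattern, with a stand-in model so it runs without torch or deepforest installed (the class and helper names below are illustrative assumptions, not deepforest API; in practice the call is PyTorch Lightning's `trainer.predict(model, dataloader)`):

```python
# Illustrative stand-in: shows the batch-at-a-time shape of trainer.predict.
# Names here are assumptions for the sketch, not real deepforest API.

class ToyModel:
    def predict_step(self, batch, batch_idx):
        # In deepforest this would run the network on a whole batch of images;
        # here we just transform each element to keep the sketch runnable.
        return [x + 1 for x in batch]

def run_predict(model, dataloader):
    # Mirrors what trainer.predict does: one predict_step call per batch,
    # collecting the per-batch results into a list.
    return [model.predict_step(batch, i) for i, batch in enumerate(dataloader)]

dataloader = [[1, 2], [3, 4]]  # two batches of two "images"
predictions = run_predict(ToyModel(), dataloader)
print(predictions)  # [[2, 3], [4, 5]]
```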
where predictions is a list of results that have been formatted into dataframes from tensors, like the other predict family of functions. predict_step is a PyTorch Lightning method reserved for trainer.predict; it can't be renamed, but it could be wrapped in another function if we wanted.

This pathway isn't really covered in the docs, and it would take an astute user to recognize it. It should be much faster, since batches can be quite large if you have a big GPU.
Next steps
- Document this behavior.
- Consider a predict_batch function that mirrors the format of predict_file, predict_image, and predict_tile to help guide users through this.
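A hypothetical predict_batch wrapper could look like the sketch below. The name, signature, and `predict_fn` parameter are assumptions for discussion, not existing deepforest API: the idea is just to chunk an in-memory list of images and hand each chunk to a batch-capable prediction callable, as the predict_step pathway does.

```python
def predict_batch(predict_fn, images, batch_size=8):
    """Hypothetical helper (not deepforest API): predict in-memory images in chunks.

    predict_fn is assumed to accept a batch (a list of images) and return
    one result per image, so the GPU sees batch_size images per call
    instead of one image at a time.
    """
    results = []
    for start in range(0, len(images), batch_size):
        # One forward pass per chunk rather than per image
        results.extend(predict_fn(images[start:start + batch_size]))
    return results

# Usage with a toy batch predictor that doubles each "image":
out = predict_batch(lambda batch: [x * 2 for x in batch], [1, 2, 3, 4, 5], batch_size=2)
print(out)  # [2, 4, 6, 8, 10]
```

Results could then be concatenated into a single dataframe to match the other predict functions' outputs.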