Better expose and document batch prediction from dataloaders #849

Open
bw4sz opened this issue Dec 11, 2024 · 3 comments
Labels
API This tag is used for small improvements to the readability and usability of the python API. good first issue Good for newcomers

Comments

@bw4sz
Collaborator

bw4sz commented Dec 11, 2024

The current API prioritizes simple functions over performance. When we already have lots of images coming out of a dataloader, it's quite annoying to either save them to file so we can use main.predict_file, or to manipulate the dataloader so the images are preprocessed as expected. We could loop through them individually and call predict_image, but that is really wasteful given modern GPU memory.

```python
for batch in test_loader:
    metadata, images, targets = batch
    for image in images:
        # Preprocess each image first; for example, DeepForest likes 0-255 data, channels first
        pred = m.predict_image(image)
```

Not great.

Instead, a batch prediction mechanism is already sitting in the codebase, but it's not obvious.

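```python
import torch

m.eval()  # torchvision-based detectors return losses instead of predictions in training mode
with torch.no_grad():
    for idx, batch in enumerate(test_loader):
        metadata, images, targets = batch
        # Preprocessing here
        predictions = m.predict_step(images, idx)
```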

where predictions is a list of results, one per image, that have been converted from tensors into dataframes, like the rest of the predict family of functions. predict_step is a PyTorch Lightning hook called by trainer.predict, so it can't be renamed, but it could be wrapped in another function if we wanted to.
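To make that return value concrete, here is a minimal, framework-free sketch of the per-image formatting step described above. It assumes each raw detector output is a torchvision-style dict with `boxes`, `labels` and `scores`, and that the formatted table uses the columns `xmin`, `ymin`, `xmax`, `ymax`, `label`, `score`. Plain lists stand in for tensors and lists of dicts stand in for DataFrames; `format_batch` is a hypothetical name, not a DeepForest function.

```python
from typing import Dict, List, Sequence

# Column order assumed to match DeepForest's prediction tables
COLUMNS = ["xmin", "ymin", "xmax", "ymax", "label", "score"]


def format_batch(batch_outputs: Sequence[Dict[str, list]]) -> List[List[dict]]:
    """Turn one batch of raw detector outputs into one table per image.

    Hypothetical helper: each element of batch_outputs is assumed to be a
    torchvision-style dict with "boxes" ([[xmin, ymin, xmax, ymax], ...]),
    "labels" and "scores"; lists of dicts stand in for DataFrames.
    """
    results = []
    for output in batch_outputs:
        rows = [
            dict(zip(COLUMNS, [*box, label, score]))
            for box, label, score in zip(output["boxes"], output["labels"], output["scores"])
        ]
        results.append(rows)  # one entry per image, even if it has no detections
    return results


batch_outputs = [
    {"boxes": [[10, 20, 50, 60]], "labels": [0], "scores": [0.9]},
    {"boxes": [], "labels": [], "scores": []},  # an image with no detections
]
predictions = format_batch(batch_outputs)
print(len(predictions))  # 2: one result per image in the batch
```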

This pathway isn't really in the docs, and it would take an astute user to recognize it. It should also be much faster, since batches can be quite large if you have a big GPU.
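The speed argument mostly comes down to how many forward passes are made. The toy sketch below uses a hypothetical `CountingModel` that records how often it is called; it is not DeepForest code, just an illustration that batching yields the same per-image outputs while cutting model calls from one per image to one per batch.

```python
class CountingModel:
    """Hypothetical stand-in for a detector: 'predicts' one value per image and counts calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, images):
        self.calls += 1
        return [sum(image) for image in images]  # one fake prediction per image


def batches(items, batch_size):
    """Yield consecutive chunks of at most batch_size items, like a DataLoader."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


images = [[i, i + 1] for i in range(100)]  # 100 tiny fake "images"

per_image = CountingModel()
slow = [per_image([image])[0] for image in images]  # one call per image

batched = CountingModel()
fast = [pred for chunk in batches(images, 32) for pred in batched(chunk)]

print(slow == fast, per_image.calls, batched.calls)  # True 100 4
```

On a real GPU each batched call costs more than a single-image call, but far less than 32 separate ones, which is where the savings come from.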

Next steps

  1. Document this behavior

  2. Consider a predict_batch function that mirrors the format of predict_file, predict_image, predict_tile to help guide users through this.
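One possible shape for the predict_batch idea in step 2 is sketched below. Everything here is hypothetical: `predict_batch` is not an existing DeepForest function, and the sketch assumes batches unpack as `(metadata, images, targets)` like `test_loader` above, that the model exposes `eval()` and `predict_step(images, batch_idx)` returning one result per image, and that each metadata entry identifies its image. A fake model is included so the sketch runs on its own.

```python
from contextlib import nullcontext
from typing import Any, Callable, ContextManager, Iterable, List, Tuple


def predict_batch(model: Any,
                  dataloader: Iterable,
                  grad_context: Callable[[], ContextManager] = nullcontext) -> List[Tuple[Any, Any]]:
    """Hypothetical wrapper around predict_step for a whole dataloader.

    Assumes each batch unpacks as (metadata, images, targets) and that
    model.predict_step(images, batch_idx) returns one result per image.
    With PyTorch, pass grad_context=torch.no_grad to skip gradient tracking.
    """
    model.eval()  # torchvision-style detectors return losses, not boxes, in training mode
    results = []
    with grad_context():
        for batch_idx, (metadata, images, _targets) in enumerate(dataloader):
            predictions = model.predict_step(images, batch_idx)
            # Pair each prediction with the metadata of the image it came from
            results.extend(zip(metadata, predictions))
    return results


class FakeModel:
    """Stand-in exposing only the two methods the sketch relies on."""

    def eval(self):
        return self

    def predict_step(self, images, batch_idx):
        return [f"boxes for image {image}" for image in images]


loader = [(["a.png", "b.png"], [1, 2], None), (["c.png"], [3], None)]
results = predict_batch(FakeModel(), loader)
for image_path, prediction in results:
    print(image_path, prediction)
```

Returning (metadata, prediction) pairs keeps each result aligned with its source image; a fuller version might instead concatenate the per-image DataFrames into a single table keyed by image, to mirror the other predict functions.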

@bw4sz bw4sz added good first issue Good for newcomers API This tag is used for small improvements to the readability and usability of the python API. labels Dec 11, 2024
@RohitP2005

Hey @bw4sz, I would like to solve this issue if it's still unassigned!

@bw4sz
Collaborator Author

bw4sz commented Dec 11, 2024 via email

@RohitP2005

Fine, then I will start working on it.
