
Fine tuning SAM, 'The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4.', ' got torch.Size([2, 4]).' #34862

Open
leemorton opened this issue Nov 21, 2024 · 2 comments

Comments

@leemorton

Hi,

I am trying to fine-tune SAM on custom images and masks, but I am struggling and hoping someone can point me in the right direction.

I have been referencing this code:
https://github.com/bnsreenu/python_for_microscopists/blob/master/331_fine_tune_SAM_mito.ipynb
which is based on this:
https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb

I cannot get the training to work, as I get this message at the forward-pass step:
'The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4.', ' got torch.Size([2, 4]).'

I think the input_boxes argument is wrong somehow?

[screenshot omitted]

The images I am using are colour PNG images rather than the TIFF images in the reference code, and they show up with 3 channels here:
[screenshot omitted]

My SAMDataset code is:

import numpy as np
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  """
  This class is used to create a dataset that serves input images and masks.
  It takes a dataset and a processor as input and overrides the __len__ and __getitem__ methods of the Dataset class.
  """
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["label"])

    # get bounding box prompt
    # prompt = get_bounding_box(ground_truth_mask)
    prompt = item["bounding_box"]

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs

and this is where I run into trouble...
[screenshot of the error omitted]
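
For reference, a minimal standalone snippet I would use to check the nesting the processor expects for input_boxes (the checkpoint name and the box coordinates below are just placeholders for illustration):

from transformers import SamProcessor
from PIL import Image
import numpy as np

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
dummy_image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

# input_boxes is nested per image -> per box -> 4 coordinates,
# so a single box on a single image needs three levels of brackets
box = [75.0, 275.0, 1725.0, 850.0]  # [x_min, y_min, x_max, y_max], made-up values
inputs = processor(dummy_image, input_boxes=[[box]], return_tensors="pt")
print(inputs["input_boxes"].shape)  # expected: torch.Size([1, 1, 4])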

@Rocketknight1
Member

cc @NielsRogge for the tutorial!

@FerRomeroGalvan

FerRomeroGalvan commented Nov 22, 2024

I am having the same issue :(

Edit: Actually, I was able to resolve my issue, and I suspect the OP is having the same problem. Mine was caused by a custom bounding_box function I wrote: it returned multiple boxes, but the individual boxes were not in the required structure [[x, y, z, w]].
