Skip to content

Commit

Permalink
[CHERRYPICK] PIL fill len 1 seq / float fill for int images (pytorch#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Sep 8, 2023
1 parent eab7cfb commit a90e584
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 50 deletions.
26 changes: 16 additions & 10 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,12 @@ def adapt_fill(value, *, dtype):
return value

max_value = get_max_value(dtype)
value_type = float if dtype.is_floating_point else int

if isinstance(value, (int, float)):
return type(value)(value * max_value)
return value_type(value * max_value)
elif isinstance(value, (list, tuple)):
return type(value)(type(v)(v * max_value) for v in value)
return type(value)(value_type(v * max_value) for v in value)
else:
raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")

Expand Down Expand Up @@ -414,6 +415,10 @@ def affine_bounding_boxes(bounding_boxes):
)


# turns all warnings into errors for this module
pytestmark = pytest.mark.filterwarnings("error")


class TestResize:
INPUT_SIZE = (17, 11)
OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
Expand Down Expand Up @@ -2575,18 +2580,19 @@ def test_functional_image_correctness(self, kwargs):
def test_transform(self, param, value, make_input):
input = make_input(self.INPUT_SIZE)

kwargs = {param: value}
if param == "fill":
# 1. size is required
# 2. the fill parameter only has an affect if we need padding
kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]

if isinstance(input, PIL.Image.Image) and isinstance(value, (tuple, list)) and len(value) == 1:
pytest.xfail("F._pad_image_pil does not support sequences of length 1 for fill.")

if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")

kwargs = dict(
# 1. size is required
# 2. the fill parameter only has an affect if we need padding
size=[s + 4 for s in self.INPUT_SIZE],
fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
)
else:
kwargs = {param: value}

check_transform(
transforms.RandomCrop(**kwargs, pad_if_needed=True),
input,
Expand Down
37 changes: 0 additions & 37 deletions test/transforms_v2_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import collections.abc

import pytest
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors
Expand Down Expand Up @@ -112,32 +110,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
multi_crop_skips.append(skip_dispatch_tv_tensor)


def xfails_pil(reason, *, condition=None):
return [
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
for test_name in ["test_dispatch_pil", "test_pil_output_type"]
]


def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
try:
fill = kwargs["fill"]
except KeyError:
return False

if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
return False

return image_loader.num_channels > 1


xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
"PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
condition=fill_sequence_needs_broadcast,
)


DISPATCHER_INFOS = [
DispatcherInfo(
F.resized_crop,
Expand All @@ -159,14 +131,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
*xfails_pil(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
),
xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
xfail_jit_python_scalar_arg("padding"),
],
Expand All @@ -181,7 +145,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
test_marks=[
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("fill"),
],
),
Expand Down
6 changes: 4 additions & 2 deletions torchvision/transforms/_functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,13 @@ def _parse_fill(
if isinstance(fill, (int, float)) and num_channels > 1:
fill = tuple([fill] * num_channels)
if isinstance(fill, (list, tuple)):
if len(fill) != num_channels:
if len(fill) == 1:
fill = fill * num_channels
elif len(fill) != num_channels:
msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
raise ValueError(msg.format(len(fill), num_channels))

fill = tuple(fill)
fill = tuple(fill) # type: ignore[arg-type]

if img.mode != "F":
if isinstance(fill, (list, tuple)):
Expand Down
6 changes: 5 additions & 1 deletion torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,11 @@ def _pad_with_vector_fill(

output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

# We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an implicit
# float -> int conversion. That happens for example for the valid input of a uint8 image with floating point fill
# value.
fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)

if top > 0:
output[..., :top, :] = fill
Expand Down

0 comments on commit a90e584

Please sign in to comment.