
can't convert Depth-Anything with flexible shape #2382

Open
czkoko opened this issue Nov 1, 2024 · 1 comment
Labels
bug Unexpected behaviour that should be corrected (type) PyTorch (traced)

Comments


czkoko commented Nov 1, 2024

🐞Describing the bug

When the input shape is fixed, the model can be successfully converted to Core ML, but when the input has a flexible shape, the following error occurs:

ValueError: Cannot add const 518.0001/is96
Python native vals (list, tuple), np.array that are operation inputs cannot have symbolic values. Consider feeding symbolic shape in through placeholder and use mb.shape() operator. Input input.337_scale_factor_height: 518.0001/is96

Maybe the problem is in upsample_bilinear2d? I tried to modify it, but my modification was not correct: the conversion succeeded, but the converted model failed to run.

To Reproduce

  • With a fixed input shape, the following code successfully converts the model; with the flexible shape defined below, conversion fails with the error above (a fixed-shape variant of the input definition is sketched right after the script).
from transformers import AutoModelForDepthEstimation
import coremltools as ct
import numpy as np
import torch
import torchvision

model_version = "Base"
model_id = f"depth-anything/Depth-Anything-V2-{model_version}-hf"
model = AutoModelForDepthEstimation.from_pretrained(model_id)
model = model.eval()

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, pixel_values):
        """pixel_values are floats in the range `[0, 255]`"""
        # Apply ImageNet normalization
        n_mean = [0.485*255, 0.456*255, 0.406*255]
        n_std = [0.229*255, 0.224*255, 0.225*255]
        pixel_values = torchvision.transforms.functional.normalize(pixel_values, mean=n_mean, std=n_std)

        outputs = self.model(pixel_values, return_dict=False)
        # Normalize output to `[0, 1]` and add batch size dimension
        normalized = outputs[0] / outputs[0].max()
        return normalized.unsqueeze(0)
        
to_trace = Wrapper(model)
inputs_shape = torch.rand(1, 3, 518, 518)
traced_model = torch.jit.trace(to_trace.eval(), inputs_shape)
    
flexible_shape = ct.Shape(
   shape=(1, 3, ct.RangeDim(lower_bound=259, upper_bound=1036, default=518), 
          ct.RangeDim(lower_bound=259, upper_bound=1036, default=518))
)

input_types = [ct.ImageType(name="image", shape=flexible_shape, color_layout=ct.colorlayout.RGB)]
output_types = [ct.ImageType("depth", color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)]

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op

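# Workaround attempt mentioned above: translate torch's bicubic upsample
# with the bilinear MIL op (whose scale factors must be compile-time constants).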
@register_torch_op
def upsample_bicubic2d(context, node):
    a = context[node.inputs[0]]
    align_corners = context[node.inputs[2]].val
    scale = context[node.inputs[3]]
    scale_h = scale.val[0]
    scale_w = scale.val[1]

    x = mb.upsample_bilinear(x=a, scale_factor_height=scale_h, scale_factor_width=scale_w, align_corners=align_corners, name=node.name)
    context.add(x)
    
coreml_model = ct.convert(
    traced_model,
    minimum_deployment_target = ct.target.macOS14,
    inputs = input_types,
    outputs = output_types,
    convert_to="mlprogram",
    compute_units = ct.ComputeUnit.CPU_ONLY,
    compute_precision = ct.precision.FLOAT16,
)

model_name = f"DepthAnythingV2-{model_version}"

coreml_model.name = model_name
coreml_model.version = "2.0"
coreml_model.short_description = "Depth Anything V2 is a state-of-the-art deep learning model for depth estimation."
coreml_model.author = "Original Paper: Lihe Yang et al. (Depth Anything V2)"
coreml_model.license = "Apache 2"
coreml_model.input_description["image"] = "Input image whose depth will be estimated."
coreml_model.output_description["depth"] = "Estimated depth map, as a grayscale output image."

coreml_model.user_defined_metadata["com.apple.coreml.model.preview.type"] = "depthEstimation"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.category"] = "image"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.name"] = f"{model_name}.mlpackage"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.version"] = "2.0"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.release-date"] = "2024-06"

coreml_model.save(f"{model_name}.mlpackage")
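
For comparison, here is the fixed-shape variant of the input definition that converts successfully; only this part of the script changes, the rest stays the same.

# Fixed-shape variant: use the 518x518 default from the flexible range above
# as a constant input shape. With this input, ct.convert succeeds.
fixed_shape = ct.Shape(shape=(1, 3, 518, 518))
input_types = [ct.ImageType(name="image", shape=fixed_shape, color_layout=ct.colorlayout.RGB)]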

System environment (please complete the following information):

  • coremltools version: v8.0
  • OS (e.g. MacOS version or Linux type): MacOS 15
  • transformers==4.40.0
  • torch==2.4.0
@czkoko czkoko added the bug Unexpected behaviour that should be corrected (type) label Nov 1, 2024
jakesabathia2 (Collaborator) commented Nov 19, 2024

By definition,
scale_factor_height and scale_factor_width must both be constants at compile time;
however, when the input has a flexible shape, these two arguments are no longer constant.
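
The error message hints at one possible direction: feed the symbolic output size through the graph (e.g. via mb.shape()) instead of a constant scale factor. Below is a rough, unverified sketch of that idea, reading the op's output_size input rather than its scale factors; the internal op name torch_upsample_bilinear and its argument names are assumptions about how the torch frontend handles dynamic upsampling and may differ between coremltools versions, so treat this as a starting point rather than a fix.

from coremltools.converters.mil import Builder as mb, register_torch_op

# Unverified sketch: translate bicubic upsampling by passing the (possibly
# symbolic) output size through the graph instead of constant scale factors.
# `torch_upsample_bilinear` and its argument names are assumptions about
# coremltools internals and may not exist in this form in every release.
@register_torch_op(override=True)
def upsample_bicubic2d(context, node):
    x = context[node.inputs[0]]
    output_size = context[node.inputs[1]]    # (H_out, W_out); symbolic for flexible shapes
    align_corners = context[node.inputs[2]].val

    # Split the size tensor into scalar height / width values.
    out_h = mb.slice_by_index(x=output_size, begin=[0], end=[1], squeeze_mask=[True])
    out_w = mb.slice_by_index(x=output_size, begin=[1], end=[2], squeeze_mask=[True])

    res = mb.torch_upsample_bilinear(        # assumed internal op, lowered by a later pass
        x=x,
        output_height=out_h,
        output_width=out_w,
        align_corners=align_corners,
        name=node.name,
    )
    context.add(res)

Whether the later lowering passes accept symbolic output sizes on all deployment targets is not something I have verified.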
