🐞Describing the bug
When the input shape is fixed, the model converts to Core ML successfully, but when the input has a flexible shape, the following error occurs:
```
ValueError: Cannot add const 518.0001/is96
Python native vals (list, tuple), np.array that are operation inputs cannot have symbolic values. Consider feeding symbolic shape in through placeholder and use mb.shape() operator. Input input.337_scale_factor_height: 518.0001/is96
```
Maybe the problem is in `upsample_bilinear2d`? I tried to modify it, but my change was incorrect: the model converted successfully, but running it failed.
To Reproduce
If the input shape is fixed, the following code successfully converts the model.
```python
from transformers import AutoModelForDepthEstimation
import coremltools as ct
import numpy as np
import torch
import torchvision

model_version = "Base"
model_id = f"depth-anything/Depth-Anything-V2-{model_version}-hf"
model = AutoModelForDepthEstimation.from_pretrained(model_id)
model = model.eval()

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, pixel_values):
        """pixel_values are floats in the range `[0, 255]`"""
        # Apply ImageNet normalization
        n_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        n_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
        pixel_values = torchvision.transforms.functional.normalize(pixel_values, mean=n_mean, std=n_std)
        outputs = self.model(pixel_values, return_dict=False)
        # Normalize output to `[0, 1]` and add batch size dimension
        normalized = outputs[0] / outputs[0].max()
        return normalized.unsqueeze(0)

to_trace = Wrapper(model)
inputs_shape = torch.rand(1, 3, 518, 518)
traced_model = torch.jit.trace(to_trace.eval(), inputs_shape)

flexible_shape = ct.Shape(
    shape=(1, 3, ct.RangeDim(lower_bound=259, upper_bound=1036, default=518),
           ct.RangeDim(lower_bound=259, upper_bound=1036, default=518))
)
input_types = [ct.ImageType(name="image", shape=flexible_shape, color_layout=ct.colorlayout.RGB)]
output_types = [ct.ImageType("depth", color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)]

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op

@register_torch_op
def upsample_bicubic2d(context, node):
    a = context[node.inputs[0]]
    align_corners = context[node.inputs[2]].val
    scale = context[node.inputs[3]]
    scale_h = scale.val[0]
    scale_w = scale.val[1]
    x = mb.upsample_bilinear(x=a, scale_factor_height=scale_h, scale_factor_width=scale_w, align_corners=align_corners, name=node.name)
    context.add(x)

coreml_model = ct.convert(
    traced_model,
    minimum_deployment_target=ct.target.macOS14,
    inputs=input_types,
    outputs=output_types,
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.CPU_ONLY,
    compute_precision=ct.precision.FLOAT16,
)

model_name = f"DepthAnythingV2-{model_version}"
coreml_model.name = model_name
coreml_model.version = "2.0"
coreml_model.short_description = "Depth Anything V2 is a state-of-the-art deep learning model for depth estimation."
coreml_model.author = "Original Paper: Lihe Yang et al. (Depth Anything V2)"
coreml_model.license = "Apache 2"
coreml_model.input_description["image"] = "Input image whose depth will be estimated."
coreml_model.output_description["depth"] = "Estimated depth map, as a grayscale output image."
coreml_model.user_defined_metadata["com.apple.coreml.model.preview.type"] = "depthEstimation"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.category"] = "image"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.name"] = f"{model_name}.mlpackage"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.version"] = "2.0"
coreml_model.user_defined_metadata["com.apple.developer.machine-learning.models.release-date"] = "2024-06"
coreml_model.save(f"{model_name}.mlpackage")
```
System environment (please complete the following information):
- coremltools version: 8.0
- OS (e.g. MacOS version or Linux type): macOS 15
- transformers == 4.40.0
- torch == 2.4.0
By definition, `scale_factor_height` and `scale_factor_width` must both be constants at compile time; however, if the input has a flexible shape, these two arguments are no longer constant.
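One direction that follows from this, as the error message itself hints ("use mb.shape() operator"): translate the op from the node's `output_size` input, which may stay symbolic, instead of folding the scale factors into constants. Below is a minimal, untested sketch of that idea. It assumes the traced graph supplies `output_size` as a tensor in `node.inputs[1]`, and it leans on coremltools' internal torch dialect op `mb.torch_upsample_bilinear` (the op coremltools' builtin `upsample_bilinear2d` translation uses for dynamic sizes), which accepts non-constant heights/widths; bicubic is again approximated with bilinear, as in the repro above.

```python
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op

@register_torch_op(override=True)  # replace the registration from the repro
def upsample_bicubic2d(context, node):
    # Traced signature: upsample_bicubic2d(input, output_size, align_corners, scale_factors)
    a = context[node.inputs[0]]
    output_size = context[node.inputs[1]]
    align_corners = context[node.inputs[2]].val
    scale = context[node.inputs[3]]

    if scale is not None and scale.val is not None:
        # Fixed input shape: the scale factors are compile-time constants,
        # so the original translation still applies.
        x = mb.upsample_bilinear(x=a, scale_factor_height=scale.val[0],
                                 scale_factor_width=scale.val[1],
                                 align_corners=align_corners, name=node.name)
    else:
        # Flexible input shape: keep the target size symbolic instead of
        # baking it into a const. Pull H and W out of the output_size tensor.
        output_height = mb.slice_by_index(x=output_size, begin=[0], end=[1], squeeze_mask=[True])
        output_width = mb.slice_by_index(x=output_size, begin=[1], end=[2], squeeze_mask=[True])
        x = mb.torch_upsample_bilinear(x=a, output_height=output_height,
                                       output_width=output_width,
                                       align_corners=align_corners, name=node.name)
    context.add(x)
```

With this registered, the `ct.convert` call from the repro stays unchanged. Whether the dynamically sized resize then lowers and runs correctly for the macOS14 target is exactly what still needs verification, since a run-time failure after a successful conversion is the failure mode reported above.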