Keep getting input type with weight type not match error after quantization #89

Open
skill-diver opened this issue Dec 15, 2024 · 3 comments

@skill-diver

Hi,

I keep getting this error after quantizing the model:

Traceback (most recent call last):
File "/home/my/roma_quant/experiments/quant.py", line 145, in <module>
test_mega1500(model, "quant")
File "/home/my/roma_quant/experiments/quant.py", line 49, in test_mega1500
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/benchmarks/megadepth_pose_estimation_benchmark.py", line 73, in benchmark
dense_matches, dense_certainty = model.match(
^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/models/matcher.py", line 704, in match
corresps = self.forward_symmetric(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/models/matcher.py", line 559, in forward_symmetric
corresps, dec_timer = self.decoder(f_q_pyramid,
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/models/matcher.py", line 370, in forward
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 219, in forward
input = module(input)
^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/nn/qconv2d.py", line 55, in forward
return self._conv_forward(input, self.qweight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 272, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 324, in __torch_dispatch__
return qfallback(op, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/qtensor.py", line 29, in qfallback
return callable(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/_ops.py", line 1061, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

Do you know what setting I could change to avoid this error?

@Parskatt
Owner

I would run a debugger and set a breakpoint around the error. From what I remember, there is a .float() call around there that may cause issues when autocast is off. Do you have autocast on or off?
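As an alternative to stepping through with a debugger, a forward pre-hook can report which layer first sees an input whose dtype differs from its weight dtype. The following is only a minimal sketch: `find_dtype_mismatches` is a hypothetical helper (not part of RoMa or optimum-quanto), and it assumes each layer exposes its parameters as a `weight` attribute, which may not hold for every quantized module.

```python
import torch
import torch.nn as nn


def find_dtype_mismatches(model, example_input):
    """Run one forward pass and list (name, input dtype, weight dtype) mismatches.

    Hypothetical debugging helper; assumes layers expose a `weight` attribute.
    """
    mismatches = []
    handles = []

    def make_hook(name):
        def hook(module, args):
            weight = getattr(module, "weight", None)
            if weight is None or not args or not torch.is_tensor(args[0]):
                return
            if args[0].dtype != weight.dtype:
                mismatches.append((name, args[0].dtype, weight.dtype))
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_pre_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(example_input)
    except RuntimeError:
        # The mismatch itself usually raises; the hook has already recorded it.
        pass
    finally:
        for handle in handles:
            handle.remove()
    return mismatches


# Toy stand-in for the projection layer: float32 weights, float16 input.
toy_model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1))
toy_input = torch.randn(1, 3, 8, 8).half()
report = find_dtype_mismatches(toy_model, toy_input)
```

Each entry in `report` names the submodule together with the two dtypes, which points at where a cast or an autocast region is missing.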

@skill-diver
Author

Yes. I found that turning off autocast in matcher.py and setting torch.float32 when defining the model resolves this error.

@skill-diver
Author

To clarify: you need to delete the "with autocast" line before the code block in matcher.py.
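The idea behind this workaround can be sketched on a toy layer: keep the weights in float32, disable autocast for the region, and make sure the input is float32 as well, so the convolution sees matching dtypes. This is an illustration under assumptions, not the actual code in matcher.py; `run_in_float32` is a hypothetical helper and the `nn.Conv2d` below merely stands in for the quantized projection layer.

```python
import torch
import torch.nn as nn

# Stand-in for the projection layer; weights are float32 by default.
conv = nn.Conv2d(3, 4, kernel_size=1)

# A half-precision input, like the one reported in the traceback.
x_half = torch.randn(1, 3, 8, 8).half()


def run_in_float32(module, x):
    """Call `module` with autocast disabled and the input cast to float32."""
    with torch.autocast(device_type=x.device.type, enabled=False):
        return module(x.float())


with torch.no_grad():
    out = run_in_float32(conv, x_half)
```

Here `out` is a float32 tensor. Running everything in float32 gives up the speed and memory savings of mixed precision, so this trades performance for a consistent dtype.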
