Using distributed or parallel set-up in script: No
Information
I am using the transformers CLIP model and compiling it with JIT. The compilation code works fine when adapter-transformers is not installed; I hit the error only after installing the adapter-transformers package. I can't avoid using adapter-transformers because other modules require it.
To reproduce
Steps to reproduce the behavior:
1. Create a conda environment: conda create -n exp python==3.10
2. conda activate exp
3. pip install torch, pip install transformers, pip install adapter-transformers, pip install ipython
4. Run the tracing code below in ipython.

Error:
File ~/miniconda3/envs/exp2/lib/python3.10/site-packages/torch/_jit_internal.py:758, in module_has_exports(mod)
756 def module_has_exports(mod):
757 for name in dir(mod):
--> 758 if hasattr(mod, name):
759 item = getattr(mod, name)
760 if callable(item):
File ~/miniconda3/envs/exp2/lib/python3.10/site-packages/transformers/adapters/model_mixin.py:316, in EmbeddingAdaptersWrapperMixin.active_embeddings(self)
314 @property
315 def active_embeddings(self):
--> 316 return self.base_model.active_embeddings
File ~/miniconda3/envs/exp2/lib/python3.10/site-packages/transformers/adapters/model_mixin.py:316, in EmbeddingAdaptersWrapperMixin.active_embeddings(self)
314 @property
315 def active_embeddings(self):
--> 316 return self.base_model.active_embeddings
[... skipping similar frames: EmbeddingAdaptersWrapperMixin.active_embeddings at line 316 (2971 times)]
File ~/miniconda3/envs/exp2/lib/python3.10/site-packages/transformers/adapters/model_mixin.py:316, in EmbeddingAdaptersWrapperMixin.active_embeddings(self)
314 @property
315 def active_embeddings(self):
--> 316 return self.base_model.active_embeddings
File ~/miniconda3/envs/exp2/lib/python3.10/site-packages/transformers/modeling_utils.py:1117, in PreTrainedModel.base_model(self)
1112 @property
1113 def base_model(self) -> nn.Module:
1114 """
1115 `torch.nn.Module`: The main body of the model.
1116 """
-> 1117 return getattr(self, self.base_model_prefix, self)
File ~/miniconda3/envs/exp2/lib/python3.10/site-packages/torch/nn/modules/module.py:1695, in Module.__getattr__(self, name)
1693 if name in modules:
1694 return modules[name]
-> 1695 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
RecursionError: maximum recursion depth exceeded while calling a Python object
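The traceback shows EmbeddingAdaptersWrapperMixin.active_embeddings calling itself through base_model: when no attribute with the name given by base_model_prefix exists on the object, getattr(self, self.base_model_prefix, self) falls back to self, so the property recurses without bound, and torch.jit's module_has_exports scan triggers it via hasattr(). Here is a minimal, self-contained sketch of that mechanism (class and attribute names mirror the traceback; this is not the real library code):

```python
class Wrapper:
    base_model_prefix = "clip"  # no attribute named "clip" is ever set

    @property
    def base_model(self):
        # mirrors PreTrainedModel.base_model: falls back to self
        # when the prefix attribute is absent
        return getattr(self, self.base_model_prefix, self)

    @property
    def active_embeddings(self):
        # mirrors EmbeddingAdaptersWrapperMixin.active_embeddings:
        # base_model returned self, so this calls itself forever
        return self.base_model.active_embeddings


def module_has_exports(mod):
    # mirrors torch._jit_internal.module_has_exports: hasattr()
    # evaluates every property, which triggers the recursion
    for name in dir(mod):
        if hasattr(mod, name):
            item = getattr(mod, name)
            if callable(item):
                pass


try:
    module_has_exports(Wrapper())
except RecursionError:
    print("RecursionError reproduced")
```

hasattr() only swallows AttributeError, so the RecursionError propagates out of the scan exactly as in the traceback above.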
Expected behavior
The expected outcome is a properly saved jit_model.pth. I have tried increasing the recursion limit to 5000 and it still doesn't work. Ideally, adapters/model_mixin.py should not be called at all, since my CLIP model depends only on transformers. It looks like the adapter-transformers CLIP model is being invoked instead.
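For reference, the recursion-limit change attempted above can be sketched as follows; raising the limit cannot help here, because the base_model/active_embeddings cycle adds stack frames without bound:

```python
import sys

# Raise Python's recursion limit, as attempted in this issue.
# This only delays the failure: the property recursion never terminates,
# so any finite limit is eventually exceeded.
sys.setrecursionlimit(5000)
print(sys.getrecursionlimit())
```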
Any help here would be appreciated. Thank you so much!!
The adapter-transformers package is deprecated; the new, actively maintained package we provide is the adapters package (see #584 for more information).
However, I tried to reproduce the error with the new package, and it still occurs; here is my slimmed-down version to reproduce it with adapters:
from transformers import CLIPModel
import torch
from adapters import init

testmodel = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
testmodel.eval()

# enable adapter support
init(testmodel)

shape = [1, 3, 224, 224]
_random_example = torch.rand(shape, requires_grad=False, dtype=torch.float32)
x = torch.index_select(
    _random_example,
    dim=0,
    index=torch.Tensor([0]).to(dtype=torch.int32),
)
print(x)

_inputs = {"X": x}
traced_model = torch.jit.trace(
    testmodel,
    example_inputs=(_inputs["X"],),
)
torch.jit.save(traced_model, "jit_model.pth")
Since this problem still exists in the new package, I will have a look into it and get back to you.
If you have gained any new insights since you posted this issue, please let me know.
Environment info
adapter-transformers
version: 3.2.1