Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Wrapper comments by BertP #5

Merged
merged 6 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_incorrect_leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def test_close_incorrect_leaf_name_recommendation_on_top_level(self):
with self.assertRaises(AttributeError) as context:
self.query.filter(lambda x: x.id == 247).select(lambda x: x.parent).get()

self.assertEqual(str(context.exception), "Field 'parent' not found. Did you mean 'parent_id'?")
self.assertEqual(str(context.exception), "Field 'parent' not found. Try one of the following 'parent_id,department_id,parent_user_id'")

def test_very_incorrect_leaf_name_recommendation_on_top_level(self):
with self.assertRaises(AttributeError) as context:
Expand All @@ -17,7 +17,7 @@ def test_close_incorrect_leaf_name_recommendation_on_nested(self):
with self.assertRaises(AttributeError) as context:
self.query.filter(lambda x: x.id == 247).select(lambda x: x.parent_id.namm).get()

self.assertEqual(str(context.exception), "Field 'namm' not found in 'hr.employee'. Did you mean 'name'?")
self.assertEqual(str(context.exception), "Field 'namm' not found in 'hr.employee'. Try one of the following 'name'")

def test_very_incorrect_leaf_name_recommendation_on_nested(self):
with self.assertRaises(AttributeError) as context:
Expand Down
82 changes: 30 additions & 52 deletions wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,26 @@ def __getattr__(self, attr: str) -> "FieldProxy":
"""Handle attribute access for relational fields."""
field_def = self.fields.get(self.field_name)
if not field_def:
closest_match = difflib.get_close_matches(self.field_name, self.fields.keys(), n=1)
if closest_match:
raise AttributeError(f"Field '{self.field_name}' not found. Did you mean '{closest_match[0]}'?")
closest_matches = difflib.get_close_matches(self.field_name, self.fields.keys())
if closest_matches:
raise AttributeError(f"Field '{self.field_name}' not found. Try one of the following '{','.join(closest_matches)}'")
else:
raise AttributeError(f"Field '{self.field_name}' not found.")

if field_def.get("type") in {"many2one", "one2many", "many2many"}:
relation = field_def.get("relation")
if not relation:
raise AttributeError(f"No relation found for field '{self.field_name}'.")
raise AttributeError(f"'{self.field_name}' has no relation. This indicates an issue with your odoo database, contact your odoo administrator.")

related_fields = self.model.query.orm.fields_cache.setdefault(
relation, self.model.query.orm._introspect_fields(relation)
)
if relation not in self.model.query.orm.fields_cache:
self.model.query.orm.fields_cache[relation] = self.model.query.orm._introspect_fields(relation)
related_fields = self.model.query.orm.fields_cache[relation]

# Allow 'id' and 'external_id' even if they are not in related_fields
if attr not in related_fields and attr not in {"id", "external_id"}:
closest_match = difflib.get_close_matches(attr, related_fields.keys(), n=1)
if closest_match:
raise AttributeError(f"Field '{attr}' not found in '{relation}'. Did you mean '{closest_match[0]}'?")
closest_matches = difflib.get_close_matches(attr, related_fields.keys())
if closest_matches:
raise AttributeError(f"Field '{attr}' not found in '{relation}'. Try one of the following '{','.join(closest_matches)}'")
else:
raise AttributeError(f"Field '{attr}' not found in '{relation}'")

Expand All @@ -71,56 +71,44 @@ def __getattr__(self, attr: str) -> "FieldProxy":
export_field_path=f"{self.export_field_path}/{attr}",
)
else:
raise AttributeError(
f"Field '{self.field_name}' is not a relational field and has no attribute '{attr}'."
)
raise AttributeError(f"'{self.field_name}' has no attributes. Remove '.{attr}'")



def __eq__(self, other: Any) -> "ModelProxy":
def __eq__(self, other: Any):
self.model._register_condition((self.field_path, "=", other))
return self.model

def __ne__(self, other: Any) -> "ModelProxy":

def __ne__(self, other: Any):
self.model._register_condition((self.field_path, "!=", other))
return self.model

def __contains__(self, other: Any) -> "ModelProxy":
def __contains__(self, other: Any):
operator = "ilike" if isinstance(other, str) else "in"
value = other if isinstance(other, str) else ([other] if not isinstance(other, list) else other)
self.model._register_condition((self.field_path, operator, value))
return self.model

def __lt__(self, other: Any) -> "ModelProxy":
def __lt__(self, other: Any):
self.model._register_condition((self.field_path, "<", other))
return self.model

def __le__(self, other: Any) -> "ModelProxy":
def __le__(self, other: Any):
self.model._register_condition((self.field_path, "<=", other))
return self.model

def __gt__(self, other: Any) -> "ModelProxy":
def __gt__(self, other: Any):
self.model._register_condition((self.field_path, ">", other))
return self.model

def __ge__(self, other: Any) -> "ModelProxy":
def __ge__(self, other: Any):
self.model._register_condition((self.field_path, ">=", other))
return self.model


class ModelProxy:
def __init__(self, fields: Dict, query: "OdooQuery", collect_accesses: bool = False):
def __init__(self, fields: Dict, query: "OdooQuery"):
self.fields = fields
self.query = query
self.conditions = []
self.accesses = []
self.collect_accesses = collect_accesses

def __getattr__(self, item: str) -> FieldProxy:
"""Handle attribute access to dynamically return a FieldProxy."""
if item in self.fields or item in {"id", "external_id"}:
if self.collect_accesses:
self.accesses.append(item)
self.accesses.append(item)

return FieldProxy(
field_name=item,
Expand All @@ -130,9 +118,9 @@ def __getattr__(self, item: str) -> FieldProxy:
export_field_path=item,
)
else:
closest_match = difflib.get_close_matches(item, self.fields.keys(), n=1)
if closest_match:
raise AttributeError(f"Field '{item}' not found. Did you mean '{closest_match[0]}'?")
closest_matches = difflib.get_close_matches(item, self.fields.keys())
if closest_matches:
raise AttributeError(f"Field '{item}' not found. Try one of the following '{','.join(closest_matches)}'")
else:
raise AttributeError(f"Field '{item}' not found.")

Expand All @@ -158,7 +146,8 @@ def _authenticate(self) -> int:
return self.common_proxy.authenticate(self.db, self.username, self.password, {})

def __getitem__(self, model_name: str) -> "OdooQuery":
self.fields_cache.setdefault(model_name, self._introspect_fields(model_name))
if model_name not in self.fields_cache:
self.fields_cache[model_name] = self._introspect_fields(model_name)
return OdooQuery(self, model_name)

def _introspect_fields(self, model_name: str) -> Dict:
Expand All @@ -176,18 +165,7 @@ def _introspect_fields(self, model_name: str) -> Dict:
def set_context(self, **kwargs):
"""Set the context for subsequent queries."""
self.context.update(kwargs)

def close(self):
"""Close the ServerProxy connections."""
if self.common_proxy:
self.common_proxy("close") if hasattr(self.common_proxy, "close") else None
self.common_proxy = None
if self.object_proxy:
self.object_proxy("close") if hasattr(self.object_proxy, "close") else None
self.object_proxy = None

def __exit__(self):
self.close()


class OdooQuery:
def __init__(self, orm: Client, model_name: str):
Expand All @@ -203,7 +181,7 @@ def __init__(self, orm: Client, model_name: str):

def select(self, select_func) -> "OdooQuery":
"""Apply a projection."""
proxy = ModelProxy(self.fields, self, collect_accesses=False)
proxy = ModelProxy(self.fields, self)
result = select_func(proxy)

def collect_projections(res) -> List[FieldProxy]:
Expand All @@ -223,14 +201,14 @@ def collect_projections(res) -> List[FieldProxy]:

def filter(self, filter_func) -> "OdooQuery":
"""Apply filter conditions using a lambda function."""
proxy = ModelProxy(self.fields, self, collect_accesses=True)
proxy = ModelProxy(self.fields, self)
filter_func(proxy)
self.filters.extend(proxy.conditions)
return self

def order_by(self, order_func, descending: bool = False) -> "OdooQuery":
"""Apply ordering on fields."""
proxy = ModelProxy(self.fields, self, collect_accesses=False)
proxy = ModelProxy(self.fields, self)
result = order_func(proxy)

def collect_fields(res) -> List[FieldProxy]:
Expand Down
Loading