Skip to content

Commit

Permalink
Merge branch '3.x' into 2.0-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
josephmancuso authored Nov 25, 2024
2 parents 5da7c38 + 1f788da commit ede0ee0
Show file tree
Hide file tree
Showing 24 changed files with 167 additions and 97 deletions.
4 changes: 4 additions & 0 deletions src/masoniteorm/connections/BaseConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ def enable_disable_foreign_keys(self):
self._connection.execute(platform.enable_foreign_key_constraints())
elif foreign_keys is not None:
self._connection.execute(platform.disable_foreign_key_constraints())

def get_row_count(self):
return self._cursor.rowcount

1 change: 0 additions & 1 deletion src/masoniteorm/connections/MSSQLConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def query(self, query, bindings=(), results="*"):
for q in query:
self.statement(q, ())
return
query = query.replace("'?'", "?")
self.statement(query, bindings)
if results == 1:
if not cursor.description:
Expand Down
4 changes: 2 additions & 2 deletions src/masoniteorm/connections/MySQLConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def query(self, query, bindings=(), results="*"):
with self._cursor as cursor:
if isinstance(query, list):
for q in query:
q = q.replace("'?'", "%s")
q = q.replace("?", "%s")
self.statement(q, ())
return

query = query.replace("'?'", "%s")
query = query.replace("?", "%s")
self.statement(query, bindings)
if results == 1:
return self.format_cursor_results(cursor.fetchone())
Expand Down
2 changes: 1 addition & 1 deletion src/masoniteorm/connections/PostgresConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def query(self, query, bindings=(), results="*"):
self.statement(q, ())
return

query = query.replace("'?'", "%s")
query = query.replace("?", "%s")
self.statement(query, bindings)
if results == 1:
return dict(cursor.fetchone() or {})
Expand Down
1 change: 0 additions & 1 deletion src/masoniteorm/connections/SQLiteConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def query(self, query, bindings=(), results="*"):
for query in query:
self.statement(query)
else:
query = query.replace("'?'", "?")
self.statement(query, bindings)
if results == 1:
result = [dict(row) for row in self._cursor.fetchall()]
Expand Down
31 changes: 22 additions & 9 deletions src/masoniteorm/models/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def get_builder(self):
self.builder = QueryBuilder(
connection=self.__connection__,
table=self.get_table_name(),
connection_details=self.get_connection_details(),
# connection_details=self.get_connection_details(),
model=self,
scopes=self._scopes.get(self.__class__),
dry=self.__dry__,
Expand Down Expand Up @@ -544,13 +544,9 @@ def new_collection(cls, data):
return Collection(data)

@classmethod
def create(
cls,
dictionary: Dict[str, Any] = None,
query: bool = False,
cast: bool = False,
**kwargs,
):


def create(cls, dictionary=None, query=False, cast=True, **kwargs)>>>> 3.x
"""Creates new records based off of a dictionary as well as data set on the model
such as fillable values.
Expand Down Expand Up @@ -882,9 +878,13 @@ def save(self, query=False):

if not query:
if self.is_loaded():

result = builder.update(
self.__dirty_attributes__, ignore_mass_assignment=True
)

builder.update(self.__dirty_attributes__)

else:
result = self.create(
self.__dirty_attributes__,
Expand All @@ -893,8 +893,9 @@ def save(self, query=False):
ignore_mass_assignment=True,
)
self.observe_events(self, "saved")
self.fill(result.__attributes__)
self.__dirty_attributes__ = {}
if self.is_loaded():
return self
return result

if self.is_loaded():
Expand Down Expand Up @@ -967,6 +968,18 @@ def _set_cast_attribute(self, attribute, value):

return cast_method(value)

def transform_dict(self, attributes: dict):
new_dict = {}
for key, value in attributes.items():
if key in self.get_dates():
new_dict.update({key: self.get_new_datetime_string(value)})
elif key in self.__casts__:
new_dict.update({key: self._cast_attribute(key, value)})
else:
new_dict.update({key: value})

return new_dict

@classmethod
def load(cls, *loads):
cls.boot()
Expand Down
73 changes: 33 additions & 40 deletions src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,14 @@ def delete(self, column=None, value=None, query=False):
self.where(model.get_primary_key(), model.get_primary_key_value())
self.observe_events(model, "deleting")

result = self.new_connection().query(self.to_qmark(), self._bindings)
connection = self.new_connection()

connection.query(self.to_qmark(), self._bindings)

if model:
self.observe_events(model, "deleted")

return result
return connection.get_row_count()

def where(self, column, *args):
"""Specifies a where expression.
Expand Down Expand Up @@ -1420,54 +1422,40 @@ def update(
additional.update({model.get_primary_key(): model.get_primary_key_value()})

self.observe_events(model, "updating")

if model:
if not model.__force_update__ and not force:
# Filter updates to only those with changes
updates = {
attr: value
for attr, value in updates.items()
if (
value is None
or model.__original_attributes__.get(attr, None) != value
)
}

# Do not execute query if no changes
if not updates:
return self if dry or self.dry else model

# Cast date fields
date_fields = model.get_dates()
for key, value in updates.items():
if key in date_fields:
if value:
updates[key] = model.get_new_datetime_string(value)
else:
updates[key] = value
# Cast value if necessary
if cast:
if value:
updates[key] = model.cast_value(value)
else:
updates[key] = value
elif not updates:
# Do not perform query if there are no updates
return self
# update only attributes with changes
if model and not model.__force_update__ and not force:
changes = {}
for attribute, value in updates.items():
if (
model.__original_attributes__.get(attribute, None) != value
or value is None
):
changes.update({attribute: value})
updates = changes

if model and updates:
updates = model.transform_dict(updates)

# do not perform update query if no changes
if len(updates.keys()) == 0:
if dry or self.dry:
return self
return 0

self._updates = (UpdateQueryExpression(updates),)
self.set_action("update")
if dry or self.dry:
return self

additional.update(updates)
connection = self.new_connection()

self.new_connection().query(self.to_qmark(), self._bindings)
connection.query(self.to_qmark(), self._bindings)
if model:
model.fill(updates)
self.observe_events(model, "updated")
model.fill_original(updates)
return model
return connection.get_row_count()
return additional

def force_update(self, updates: dict, dry=False):
Expand All @@ -1488,7 +1476,7 @@ def set_updates(self, updates: dict, dry=False):
self._updates += (UpdateQueryExpression(updates),)
return self

def increment(self, column, value=1):
def increment(self, column, value=1, dry=False):
"""Increments a column's value.
Arguments:
Expand Down Expand Up @@ -1521,13 +1509,16 @@ def increment(self, column, value=1):
)

self.set_action("update")
if dry:
return self

results = self.new_connection().query(self.to_qmark(), self._bindings)
processed_results = self.get_processor().get_column_value(
self, column, results, id_key, id_value
)
return processed_results

def decrement(self, column, value=1):
def decrement(self, column, value=1, dry=False):
"""Decrements a column's value.
Arguments:
Expand Down Expand Up @@ -1560,6 +1551,8 @@ def decrement(self, column, value=1):
)

self.set_action("update")
if dry:
return self
result = self.new_connection().query(self.to_qmark(), self._bindings)
processed_results = self.get_processor().get_column_value(
self, column, result, id_key, id_value
Expand Down
36 changes: 20 additions & 16 deletions src/masoniteorm/query/grammars/BaseGrammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def columnize_bulk_values(self, columns=[], qmark=False):
if qmark:
self.add_binding(y)
inner += (
"'?', "
"?, "
if qmark
else self.value_string().format(value=y, separator=", ")
)
Expand All @@ -190,7 +190,7 @@ def columnize_bulk_values(self, columns=[], qmark=False):
if qmark:
self.add_binding(x)
sql += (
"'?', "
"?, "
if qmark
else self.process_value_string().format(
value="?" if qmark else x, separator=", "
Expand Down Expand Up @@ -264,7 +264,7 @@ def process_joins(self, qmark=False):
)
else:
if qmark:
value = "'?'"
value = "?"
self.add_binding(clause.value)
else:
value = self._compile_value(clause.value)
Expand Down Expand Up @@ -312,7 +312,7 @@ def _compile_key_value_equals(self, qmark=False):
else:
sql += sql_string.format(
column=self._table_column_string(key),
value=value if not qmark else "?",
value=self.value_string().format(value=value, separator="") if not qmark else "?",
separator=", ",
)

Expand All @@ -321,7 +321,7 @@ def _compile_key_value_equals(self, qmark=False):
else:
sql += sql_string.format(
column=self._table_column_string(column),
value=value if not qmark else "?",
value=self.value_string().format(value=value, separator=", ") if not qmark else "?",
separator=", ",
)
if qmark:
Expand Down Expand Up @@ -588,19 +588,23 @@ def process_wheres(self, qmark=False, strip_first_where=False):
if qmark:
self.add_binding(low)
self.add_binding(high)
low = "?"
high = "?"

sql_string = self.between_string().format(
low=self._compile_value(low),
high=self._compile_value(high),
low=self._compile_value(low) if not qmark else "?",
high=self._compile_value(high) if not qmark else "?",
column=self._table_column_string(where.column),
keyword=keyword,
)
elif equality == "NOT BETWEEN":
low = where.low
high = where.high
if qmark:
self.add_binding(low)
self.add_binding(high)

sql_string = self.not_between_string().format(
low=self._compile_value(where.low),
high=self._compile_value(where.high),
low=self._compile_value(low) if not qmark else "?",
high=self._compile_value(high) if not qmark else "?",
column=self._table_column_string(where.column),
keyword=keyword,
)
Expand Down Expand Up @@ -657,7 +661,7 @@ def process_wheres(self, qmark=False, strip_first_where=False):
query_value = "("
for val in value:
if qmark:
query_value += "'?', "
query_value += "?, "
self.add_binding(val)
else:
query_value += self.value_string().format(
Expand All @@ -671,7 +675,7 @@ def process_wheres(self, qmark=False, strip_first_where=False):
sql_string = self.get_false_column_string()
query_value = 0
elif qmark and value_type != "column":
query_value = "'?'"
query_value = "?"
if (
value is not True
and value_type != "value_equals"
Expand All @@ -681,7 +685,7 @@ def process_wheres(self, qmark=False, strip_first_where=False):
self.add_binding(value)
elif value_type == "value":
if qmark:
query_value = "'?'"
query_value = "?"
else:
query_value = self.value_string().format(value=value, separator="")

Expand Down Expand Up @@ -833,14 +837,14 @@ def process_values(self, separator="", qmark=False):
for column, value in dict(c).items():
if qmark:
self.add_binding(value)
sql += f"'?'{separator}".strip()
sql += f"?{separator}".strip()
else:
sql += self._compile_value(value, separator=separator)
else:
for column, value in dict(self._columns).items():
if qmark:
self.add_binding(value)
sql += f"'?'{separator}".strip()
sql += f"?{separator}".strip()
else:
sql += self._compile_value(value, separator=separator)

Expand Down
2 changes: 1 addition & 1 deletion src/masoniteorm/query/grammars/MSSQLGrammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def aggregate_string_with_alias(self):
return "{aggregate_function}({column}) AS {alias}"

def key_value_string(self):
return "{column} = '{value}'{separator}"
return "{column} = {value}{separator}"

def column_value_string(self):
return "{column} = {value}{separator}"
Expand Down
3 changes: 2 additions & 1 deletion src/masoniteorm/query/grammars/MySQLGrammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,13 @@ def subquery_alias_string(self):
return "AS {alias}"

def key_value_string(self):
return "{column} = '{value}'{separator}"
return "{column} = {value}{separator}"

def column_value_string(self):
return "{column} = {value}{separator}"

def increment_string(self):

return "{column} = {column} + '{value}'{separator}"

def decrement_string(self):
Expand Down
Loading

0 comments on commit ede0ee0

Please sign in to comment.