Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow override of model default selects #912

Open
wants to merge 3 commits into
base: 2.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/masoniteorm/models/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def get_builder(self):
dry=self.__dry__,
)

return self.builder.select(*self.get_selects())
return self.builder

def get_selects(self):
return self.__selects__
Expand Down
8 changes: 6 additions & 2 deletions src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import inspect
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Callable
from typing import Any, Callable, Dict, List, Optional

from ..collection.Collection import Collection
from ..config import load_config
from ..exceptions import (
HTTP404,
ConnectionNotRegistered,
InvalidArgument,
ModelNotFound,
MultipleRecordsFound,
InvalidArgument,
)
from ..expressions.expressions import (
AggregateExpression,
Expand Down Expand Up @@ -1229,6 +1229,7 @@ def or_where_doesnt_have(self, relationship, callback):
return self

def with_count(self, relationship, callback=None):
self.select(*self._model.get_selects())
return getattr(self._model, relationship).get_with_count_query(
self, callback=callback
)
Expand Down Expand Up @@ -2067,6 +2068,9 @@ def get_grammar(self):

# Either _creates when creating, otherwise use columns
columns = self._creates or self._columns
if not columns and not self._aggregates and self._model:
self.select(*self._model.get_selects())
columns = self._columns

return self.grammar(
columns=columns,
Expand Down
6 changes: 3 additions & 3 deletions src/masoniteorm/query/grammars/BaseGrammar.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import re

from ...expressions.expressions import (
SubGroupExpression,
SubSelectExpression,
SelectExpression,
JoinClause,
OnClause,
SelectExpression,
SubGroupExpression,
SubSelectExpression,
)


Expand Down
3 changes: 2 additions & 1 deletion src/masoniteorm/query/grammars/PostgresGrammar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .BaseGrammar import BaseGrammar
import re

from .BaseGrammar import BaseGrammar


class PostgresGrammar(BaseGrammar):
"""Postgres grammar class."""
Expand Down
3 changes: 2 additions & 1 deletion src/masoniteorm/query/grammars/SQLiteGrammar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .BaseGrammar import BaseGrammar
import re

from .BaseGrammar import BaseGrammar


class SQLiteGrammar(BaseGrammar):
"""SQLite grammar class."""
Expand Down
6 changes: 3 additions & 3 deletions src/masoniteorm/query/grammars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .SQLiteGrammar import SQLiteGrammar
from .PostgresGrammar import PostgresGrammar
from .MySQLGrammar import MySQLGrammar
from .MSSQLGrammar import MSSQLGrammar
from .MySQLGrammar import MySQLGrammar
from .PostgresGrammar import PostgresGrammar
from .SQLiteGrammar import SQLiteGrammar
2 changes: 1 addition & 1 deletion src/masoniteorm/query/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .MSSQLPostProcessor import MSSQLPostProcessor
from .MySQLPostProcessor import MySQLPostProcessor
from .PostgresPostProcessor import PostgresPostProcessor
from .SQLitePostProcessor import SQLitePostProcessor
from .MSSQLPostProcessor import MSSQLPostProcessor
22 changes: 20 additions & 2 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ModelTestForced(Model):
__force_update__ = True

class BaseModel(Model):
__dry__ = True
def get_selects(self):
return [f"{self.get_table_name()}.*"]

Expand Down Expand Up @@ -267,9 +268,26 @@ def test_model_can_provide_default_select(self):
"""SELECT `users`.* FROM `users`""",
)

def test_model_can_add_to_default_select(self):
def test_model_can_override_to_default_select(self):
sql = ModelWithBaseModel.select(["products.name", "products.id", "store.name"]).to_sql()
self.assertEqual(
sql,
"""SELECT `users`.*, `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""",
"""SELECT `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""",
)

def test_model_can_use_aggregate_funcs_with_default_selects(self):
sql = ModelWithBaseModel.count().to_sql()
self.assertEqual(
sql,
"""SELECT COUNT(*) AS m_count_reserved FROM `users`""",
)
sql = ModelWithBaseModel.max("id").to_sql()
self.assertEqual(
sql,
"""SELECT MAX(`users`.`id`) AS id FROM `users`""",
)
sql = ModelWithBaseModel.min("id").to_sql()
self.assertEqual(
sql,
"""SELECT MIN(`users`.`id`) AS id FROM `users`""",
)
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_builder(self, table="users"):
connection_class=connection,
connection="mssql",
table=table,
model=User,
model=User(),
)

def test_has(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/mysql/builder/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestTransactions(unittest.TestCase):
pass
# def get_builder(self, table="users"):
# connection = ConnectionFactory().make("default")
# return QueryBuilder(MySQLGrammar, connection, table=table, model=User)
# return QueryBuilder(MySQLGrammar, connection, table=table, model=User())

# def test_can_start_transaction(self, table="users"):
# builder = self.get_builder()
Expand Down
1 change: 0 additions & 1 deletion tests/postgres/builder/test_postgres_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def get_builder(self, table="users"):
grammar=PostgresGrammar,
connection=connection,
table=table,
# model=User,
connection_details=DATABASES,
).on("postgres")

Expand Down
3 changes: 0 additions & 3 deletions tests/sqlite/builder/test_sqlite_builder_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory


class User(Model):
Expand All @@ -26,7 +24,6 @@ def get_builder(self, table="users"):
connection_class=connection,
connection="dev",
table=table,
# model=User,
connection_details=DATABASES,
).on("dev")

Expand Down
4 changes: 1 addition & 3 deletions tests/sqlite/builder/test_sqlite_builder_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory


class User(Model):
Expand All @@ -17,7 +15,7 @@ class User(Model):
class BaseTestQueryRelationships(unittest.TestCase):
maxDiff = None

def get_builder(self, table="users", model=User):
def get_builder(self, table="users", model=User()):
connection = ConnectionFactory().make("sqlite")
return QueryBuilder(
grammar=SQLiteGrammar,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to, has_many
from tests.utils import MockConnectionFactory


class Logo(Model):
Expand Down Expand Up @@ -58,7 +57,7 @@ def profile(self):
class BaseTestQueryRelationships(unittest.TestCase):
maxDiff = None

def get_builder(self, table="users", model=User):
def get_builder(self, table="users", model=User()):
connection = ConnectionFactory().make("sqlite")
return QueryBuilder(
grammar=SQLiteGrammar,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import inspect
import unittest

from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
Expand Down Expand Up @@ -47,7 +45,7 @@ class BaseTestQueryRelationships(unittest.TestCase):
def get_builder(self, table="users"):
connection = MockConnectionFactory().make("sqlite")
return QueryBuilder(
grammar=SQLiteGrammar, connection_class=connection, table=table, model=User
grammar=SQLiteGrammar, connection_class=connection, table=table, model=User()
)

def test_has(self):
Expand Down
5 changes: 1 addition & 4 deletions tests/sqlite/builder/test_sqlite_transaction.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import inspect
import unittest

from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
from tests.integrations.config.database import DB
from src.masoniteorm.collection import Collection

Expand All @@ -26,7 +23,7 @@ def get_builder(self, table="users"):
grammar=SQLiteGrammar,
connection="dev",
table=table,
model=User,
model=User(),
connection_details=DATABASES,
).on("dev")

Expand Down