diff --git a/magnum/common/context.py b/magnum/common/context.py index 547c9cc9b4..b11689fea8 100644 --- a/magnum/common/context.py +++ b/magnum/common/context.py @@ -12,6 +12,7 @@ from eventlet.green import threading from oslo_context import context +from oslo_db.sqlalchemy import enginefacade from magnum.common import policy @@ -20,6 +21,7 @@ CONF = magnum.conf.CONF +@enginefacade.transaction_context_provider class RequestContext(context.RequestContext): """Extends security contexts from the OpenStack common library.""" diff --git a/magnum/db/sqlalchemy/alembic/env.py b/magnum/db/sqlalchemy/alembic/env.py index ff264b7652..e7690eee4e 100644 --- a/magnum/db/sqlalchemy/alembic/env.py +++ b/magnum/db/sqlalchemy/alembic/env.py @@ -13,8 +13,8 @@ from logging import config as log_config from alembic import context +from oslo_db.sqlalchemy import enginefacade -from magnum.db.sqlalchemy import api as sqla_api from magnum.db.sqlalchemy import models # this is the Alembic Config object, which provides @@ -43,7 +43,7 @@ def run_migrations_online(): and associate a connection with the context. 
""" - engine = sqla_api.get_engine() + engine = enginefacade.writer.get_engine() with engine.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/magnum/db/sqlalchemy/api.py b/magnum/db/sqlalchemy/api.py index 46457a981a..0b7dfc4d43 100644 --- a/magnum/db/sqlalchemy/api.py +++ b/magnum/db/sqlalchemy/api.py @@ -15,8 +15,11 @@ """SQLAlchemy storage backend.""" import six +import threading + +from oslo_db import api as oslo_db_api from oslo_db import exception as db_exc -from oslo_db.sqlalchemy import session as db_session +from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log from oslo_utils import importutils @@ -42,30 +45,7 @@ LOG = log.getLogger(__name__) -_FACADE = None - - -def _create_facade_lazily(): - global _FACADE - if _FACADE is None: - # FIXME(karolinku): autocommit=True it's not compatible with - # SQLAlchemy 2.0, and will be removed in future - _FACADE = db_session.EngineFacade.from_config(CONF, autocommit=True) - if profiler_sqlalchemy: - if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy: - profiler_sqlalchemy.add_tracing(sa, _FACADE.get_engine(), "db") - - return _FACADE - - -def get_engine(): - facade = _create_facade_lazily() - return facade.get_engine() - - -def get_session(**kwargs): - facade = _create_facade_lazily() - return facade.get_session(**kwargs) +_CONTEXT = threading.local() def get_backend(): @@ -73,17 +53,22 @@ def get_backend(): return Connection() -def model_query(model, *args, **kwargs): - """Query helper for simpler session usage. 
+def _session_for_read(): + return _wrap_session(enginefacade.reader.using(_CONTEXT)) - :param session: if present, the session to use - """ - session = kwargs.get('session') or get_session() - query = session.query(model, *args) - return query +# Please add @oslo_db_api.retry_on_deadlock decorator to all methods using +# _session_for_write (as deadlocks happen on write), so that oslo_db is able +# to retry in case of deadlocks. +def _session_for_write(): + return _wrap_session(enginefacade.writer.using(_CONTEXT)) +def _wrap_session(session): + if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy: + session = profiler_sqlalchemy.wrap_session(sa, session) + return session + def add_identity_filter(query, value): """Adds an identity filter to a query. @@ -104,8 +89,6 @@ def add_identity_filter(query, value): def _paginate_query(model, limit=None, marker=None, sort_key=None, sort_dir=None, query=None): - if not query: - query = model_query(model) sort_keys = ['id'] if sort_key and sort_key not in sort_keys: sort_keys.insert(0, sort_key) @@ -169,15 +152,16 @@ def _add_clusters_filters(self, query, filters): # Helper to filter based on node_count field from nodegroups def filter_node_count(query, node_count, is_master=False): nfunc = func.sum(models.NodeGroup.node_count) - nquery = model_query(models.NodeGroup) - if is_master: - nquery = nquery.filter(models.NodeGroup.role == 'master') - else: - nquery = nquery.filter(models.NodeGroup.role != 'master') - nquery = nquery.group_by(models.NodeGroup.cluster_id) - nquery = nquery.having(nfunc == node_count) - uuids = [ng.cluster_id for ng in nquery.all()] - return query.filter(models.Cluster.uuid.in_(uuids)) + with _session_for_read() as session: + nquery = session.query(models.NodeGroup) + if is_master: + nquery = nquery.filter(models.NodeGroup.role == 'master') + else: + nquery = nquery.filter(models.NodeGroup.role != 'master') + nquery = nquery.group_by(models.NodeGroup.cluster_id) + nquery = nquery.having(nfunc == 
node_count) + uuids = [ng.cluster_id for ng in nquery.all()] + return query.filter(models.Cluster.uuid.in_(uuids)) if 'node_count' in filters: query = filter_node_count( @@ -190,12 +174,14 @@ def filter_node_count(query, node_count, is_master=False): def get_cluster_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = self._add_clusters_filters(query, filters) - return _paginate_query(models.Cluster, limit, marker, - sort_key, sort_dir, query) - + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = self._add_clusters_filters(query, filters) + return _paginate_query(models.Cluster, limit, marker, + sort_key, sort_dir, query) + + @oslo_db_api.retry_on_deadlock def create_cluster(self, values): # ensure defaults are present for new clusters if not values.get('uuid'): @@ -203,68 +189,77 @@ def create_cluster(self, values): cluster = models.Cluster() cluster.update(values) - try: - cluster.save() - except db_exc.DBDuplicateEntry: - raise exception.ClusterAlreadyExists(uuid=values['uuid']) - return cluster + + with _session_for_write() as session: + try: + session.add(cluster) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.ClusterAlreadyExists(uuid=values['uuid']) + return cluster def get_cluster_by_id(self, context, cluster_id): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=cluster_id) - try: - return query.one() - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_id) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=cluster_id) + try: + return query.one() + except NoResultFound: + raise 
exception.ClusterNotFound(cluster=cluster_id) def get_cluster_by_name(self, context, cluster_name): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(name=cluster_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple clusters exist with same name.' - ' Please use the cluster uuid instead.') - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_name) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(name=cluster_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict('Multiple clusters exist with same name.' + ' Please use the cluster uuid instead.') + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_name) def get_cluster_by_uuid(self, context, cluster_uuid): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=cluster_uuid) - try: - return query.one() - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_uuid) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=cluster_uuid) + try: + return query.one() + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_uuid) def get_cluster_stats(self, context, project_id=None): - query = model_query(models.Cluster) - node_count_col = models.NodeGroup.node_count - ncfunc = func.sum(node_count_col) - - if project_id: - query = query.filter_by(project_id=project_id) - nquery = query.session.query(ncfunc.label("nodes")).filter_by( - project_id=project_id) - else: - nquery = query.session.query(ncfunc.label("nodes")) + with _session_for_read() as session: + query = session.query(models.Cluster) + node_count_col = 
models.NodeGroup.node_count + ncfunc = func.sum(node_count_col) + + if project_id: + query = query.filter_by(project_id=project_id) + # TODO(tylerchristie): hmmmmm?? + nquery = query.session.query(ncfunc.label("nodes")).filter_by( + project_id=project_id) + else: + nquery = query.session.query(ncfunc.label("nodes")) clusters = query.count() nodes = int(nquery.one()[0]) if nquery.one()[0] else 0 return clusters, nodes def get_cluster_count_all(self, context, filters=None): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = self._add_clusters_filters(query, filters) - return query.count() + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = self._add_clusters_filters(query, filters) + return query.count() + @oslo_db_api.retry_on_deadlock def destroy_cluster(self, cluster_id): - session = get_session() - with session.begin(): - query = model_query(models.Cluster, session=session) + with _session_for_write() as session: + query = session.query(models.Cluster) query = add_identity_filter(query, cluster_id) try: @@ -282,10 +277,10 @@ def update_cluster(self, cluster_id, values): return self._do_update_cluster(cluster_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_cluster(self, cluster_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Cluster, session=session) + with _session_for_write() as session: + query = session.query(models.Cluster) query = add_identity_filter(query, cluster_id) try: ref = query.with_for_update().one() @@ -312,22 +307,24 @@ def _add_cluster_template_filters(self, query, filters): def get_cluster_template_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - query = self._add_cluster_template_filters(query, filters) - # include 
public (and not hidden) ClusterTemplates - public_q = model_query(models.ClusterTemplate).filter_by( - public=True, hidden=False) - query = query.union(public_q) - # include hidden and public ClusterTemplate if admin - if context.is_admin: - hidden_q = model_query(models.ClusterTemplate).filter_by( - public=True, hidden=True) - query = query.union(hidden_q) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + query = self._add_cluster_template_filters(query, filters) + # include public (and not hidden) ClusterTemplates + public_q = session.query(models.ClusterTemplate).filter_by( + public=True, hidden=False) + query = query.union(public_q) + # include hidden and public ClusterTemplate if admin + if context.is_admin: + hidden_q = session.query(models.ClusterTemplate).filter_by( + public=True, hidden=True) + query = query.union(hidden_q) return _paginate_query(models.ClusterTemplate, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def create_cluster_template(self, values): # ensure defaults are present for new ClusterTemplates if not values.get('uuid'): @@ -335,59 +332,66 @@ def create_cluster_template(self, values): cluster_template = models.ClusterTemplate() cluster_template.update(values) - try: - cluster_template.save() - except db_exc.DBDuplicateEntry: - raise exception.ClusterTemplateAlreadyExists(uuid=values['uuid']) - return cluster_template + + with _session_for_write() as session: + try: + session.add(cluster_template) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.ClusterTemplateAlreadyExists( + uuid=values['uuid']) + return cluster_template def get_cluster_template_by_id(self, context, cluster_template_id): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = 
query.filter(models.ClusterTemplate.id == cluster_template_id) - try: - return query.one() - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_id) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by(public=True) + query = query.union(public_q) + query = query.filter(models.ClusterTemplate.id == cluster_template_id) + try: + return query.one() + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_id) def get_cluster_template_by_uuid(self, context, cluster_template_uuid): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter( - models.ClusterTemplate.uuid == cluster_template_uuid) - try: - return query.one() - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_uuid) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by(public=True) + query = query.union(public_q) + query = query.filter( + models.ClusterTemplate.uuid == cluster_template_uuid) + try: + return query.one() + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_uuid) def get_cluster_template_by_name(self, context, cluster_template_name): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter( - models.ClusterTemplate.name == cluster_template_name) - try: - return query.one() - except 
MultipleResultsFound: - raise exception.Conflict('Multiple ClusterTemplates exist with' - ' same name. Please use the ' - 'ClusterTemplate uuid instead.') - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_name) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by(public=True) + query = query.union(public_q) + query = query.filter( + models.ClusterTemplate.name == cluster_template_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict('Multiple ClusterTemplates exist with' + ' same name. Please use the ' + 'ClusterTemplate uuid instead.') + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_name) def _is_cluster_template_referenced(self, session, cluster_template_uuid): """Checks whether the ClusterTemplate is referenced by cluster(s).""" - query = model_query(models.Cluster, session=session) + query = session.query(models.Cluster) query = self._add_clusters_filters(query, {'cluster_template_id': - cluster_template_uuid}) + cluster_template_uuid}) return query.count() != 0 def _is_publishing_cluster_template(self, values): @@ -398,10 +402,10 @@ def _is_publishing_cluster_template(self, values): return True return False + @oslo_db_api.retry_on_deadlock def destroy_cluster_template(self, cluster_template_id): - session = get_session() - with session.begin(): - query = model_query(models.ClusterTemplate, session=session) + with _session_for_write() as session: + query = session.query(models.ClusterTemplate) query = add_identity_filter(query, cluster_template_id) try: @@ -425,10 +429,10 @@ def update_cluster_template(self, cluster_template_id, values): return self._do_update_cluster_template(cluster_template_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_cluster_template(self, 
cluster_template_id, values): - session = get_session() - with session.begin(): - query = model_query(models.ClusterTemplate, session=session) + with _session_for_write() as session: + query = session.query(models.ClusterTemplate) query = add_identity_filter(query, cluster_template_id) try: ref = query.with_for_update().one() @@ -447,6 +451,7 @@ def _do_update_cluster_template(self, cluster_template_id, values): ref.update(values) return ref + @oslo_db_api.retry_on_deadlock def create_x509keypair(self, values): # ensure defaults are present for new x509keypairs if not values.get('uuid'): @@ -454,34 +459,39 @@ def create_x509keypair(self, values): x509keypair = models.X509KeyPair() x509keypair.update(values) - try: - x509keypair.save() - except db_exc.DBDuplicateEntry: - raise exception.X509KeyPairAlreadyExists(uuid=values['uuid']) - return x509keypair + + with _session_for_write() as session: + try: + session.add(x509keypair) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.X509KeyPairAlreadyExists(uuid=values['uuid']) + return x509keypair def get_x509keypair_by_id(self, context, x509keypair_id): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=x509keypair_id) - try: - return query.one() - except NoResultFound: - raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=x509keypair_id) + try: + return query.one() + except NoResultFound: + raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) def get_x509keypair_by_uuid(self, context, x509keypair_uuid): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=x509keypair_uuid) - try: - return query.one() - except NoResultFound: - raise 
exception.X509KeyPairNotFound(x509keypair=x509keypair_uuid) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=x509keypair_uuid) + try: + return query.one() + except NoResultFound: + raise exception.X509KeyPairNotFound(x509keypair=x509keypair_uuid) + @oslo_db_api.retry_on_deadlock def destroy_x509keypair(self, x509keypair_id): - session = get_session() - with session.begin(): - query = model_query(models.X509KeyPair, session=session) + with _session_for_write() as session: + query = session.query(models.X509KeyPair) query = add_identity_filter(query, x509keypair_id) count = query.delete() if count != 1: @@ -495,10 +505,10 @@ def update_x509keypair(self, x509keypair_id, values): return self._do_update_x509keypair(x509keypair_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_x509keypair(self, x509keypair_id, values): - session = get_session() - with session.begin(): - query = model_query(models.X509KeyPair, session=session) + with _session_for_write() as session: + query = session.query(models.X509KeyPair) query = add_identity_filter(query, x509keypair_id) try: ref = query.with_for_update().one() @@ -521,26 +531,27 @@ def _add_x509keypairs_filters(self, query, filters): def get_x509keypair_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = self._add_x509keypairs_filters(query, filters) - return _paginate_query(models.X509KeyPair, limit, marker, - sort_key, sort_dir, query) - + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = self._add_x509keypairs_filters(query, filters) + return _paginate_query(models.X509KeyPair, limit, marker, + sort_key, sort_dir, query) + + @oslo_db_api.retry_on_deadlock def 
destroy_magnum_service(self, magnum_service_id): - session = get_session() - with session.begin(): - query = model_query(models.MagnumService, session=session) + with _session_for_write() as session: + query = session.query(models.MagnumService) query = add_identity_filter(query, magnum_service_id) count = query.delete() if count != 1: raise exception.MagnumServiceNotFound( magnum_service_id=magnum_service_id) + @oslo_db_api.retry_on_deadlock def update_magnum_service(self, magnum_service_id, values): - session = get_session() - with session.begin(): - query = model_query(models.MagnumService, session=session) + with _session_for_write() as session: + query = session.query(models.MagnumService) query = add_identity_filter(query, magnum_service_id) try: ref = query.with_for_update().one() @@ -556,48 +567,60 @@ def update_magnum_service(self, magnum_service_id, values): return ref def get_magnum_service_by_host_and_binary(self, host, binary): - query = model_query(models.MagnumService) - query = query.filter_by(host=host, binary=binary) - try: - return query.one() - except NoResultFound: - return None + with _session_for_read() as session: + query = session.query(models.MagnumService) + query = query.filter_by(host=host, binary=binary) + try: + return query.one() + except NoResultFound: + return None + @oslo_db_api.retry_on_deadlock def create_magnum_service(self, values): magnum_service = models.MagnumService() magnum_service.update(values) - try: - magnum_service.save() - except db_exc.DBDuplicateEntry: - host = values["host"] - binary = values["binary"] - LOG.warning("Magnum service with same host:%(host)s and" - " binary:%(binary)s had been saved into DB", - {'host': host, 'binary': binary}) - query = model_query(models.MagnumService) - query = query.filter_by(host=host, binary=binary) - return query.one() - return magnum_service + + with _session_for_write() as session: + try: + session.add(magnum_service) + session.flush() + except db_exc.DBDuplicateEntry: + 
host = values["host"] + binary = values["binary"] + LOG.warning("Magnum service with same host:%(host)s and" + " binary:%(binary)s had been saved into DB", + {'host': host, 'binary': binary}) + with _session_for_read() as read_session: + query = read_session.query(models.MagnumService) + query = query.filter_by(host=host, binary=binary) + return query.one() + return magnum_service def get_magnum_service_list(self, disabled=None, limit=None, marker=None, sort_key=None, sort_dir=None ): - query = model_query(models.MagnumService) - if disabled: - query = query.filter_by(disabled=disabled) + with _session_for_read() as session: + query = session.query(models.MagnumService) + if disabled: + query = query.filter_by(disabled=disabled) - return _paginate_query(models.MagnumService, limit, marker, - sort_key, sort_dir, query) + return _paginate_query(models.MagnumService, limit, marker, + sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def create_quota(self, values): quotas = models.Quota() quotas.update(values) - try: - quotas.save() - except db_exc.DBDuplicateEntry: - raise exception.QuotaAlreadyExists(project_id=values['project_id'], - resource=values['resource']) - return quotas + + with _session_for_write() as session: + try: + session.add(quotas) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.QuotaAlreadyExists( + project_id=values['project_id'], + resource=values['resource']) + return quotas def _add_quota_filters(self, query, filters): if filters is None: @@ -614,15 +637,16 @@ def _add_quota_filters(self, query, filters): def get_quota_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Quota) - query = self._add_quota_filters(query, filters) - return _paginate_query(models.Quota, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.Quota) + query = self._add_quota_filters(query, filters) + return 
_paginate_query(models.Quota, limit, marker, + sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def update_quota(self, project_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Quota, session=session) + with _session_for_write() as session: + query = session.query(models.Quota) resource = values['resource'] try: query = query.filter_by(project_id=project_id).filter_by( @@ -636,10 +660,10 @@ def update_quota(self, project_id, values): ref.update(values) return ref + @oslo_db_api.retry_on_deadlock def delete_quota(self, project_id, resource): - session = get_session() - with session.begin(): - query = model_query(models.Quota, session=session) \ + with _session_for_write() as session: + query = session.query(models.Quota) \ .filter_by(project_id=project_id) \ .filter_by(resource=resource) @@ -653,31 +677,34 @@ def delete_quota(self, project_id, resource): query.delete() def get_quota_by_id(self, context, quota_id): - query = model_query(models.Quota) - query = query.filter_by(id=quota_id) - try: - return query.one() - except NoResultFound: - msg = _('quota id %s .') % quota_id - raise exception.QuotaNotFound(msg=msg) + with _session_for_read() as session: + query = session.query(models.Quota) + query = query.filter_by(id=quota_id) + try: + return query.one() + except NoResultFound: + msg = _('quota id %s .') % quota_id + raise exception.QuotaNotFound(msg=msg) def quota_get_all_by_project_id(self, project_id): - query = model_query(models.Quota) - result = query.filter_by(project_id=project_id).all() + with _session_for_read() as session: + query = session.query(models.Quota) + result = query.filter_by(project_id=project_id).all() return result def get_quota_by_project_id_resource(self, project_id, resource): - query = model_query(models.Quota) - query = query.filter_by(project_id=project_id).filter_by( - resource=resource) + with _session_for_read() as session: + query = session.query(models.Quota) + query = 
query.filter_by(project_id=project_id).filter_by( + resource=resource) - try: - return query.one() - except NoResultFound: - msg = (_('project_id %(project_id)s resource %(resource)s.') % - {'project_id': project_id, 'resource': resource}) - raise exception.QuotaNotFound(msg=msg) + try: + return query.one() + except NoResultFound: + msg = (_('project_id %(project_id)s resource %(resource)s.') % + {'project_id': project_id, 'resource': resource}) + raise exception.QuotaNotFound(msg=msg) def _add_federation_filters(self, query, filters): if filters is None: @@ -704,60 +731,68 @@ def _add_federation_filters(self, query, filters): return query def get_federation_by_id(self, context, federation_id): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=federation_id) - try: - return query.one() - except NoResultFound: - raise exception.FederationNotFound(federation=federation_id) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=federation_id) + try: + return query.one() + except NoResultFound: + raise exception.FederationNotFound(federation=federation_id) def get_federation_by_uuid(self, context, federation_uuid): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=federation_uuid) - try: - return query.one() - except NoResultFound: - raise exception.FederationNotFound(federation=federation_uuid) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=federation_uuid) + try: + return query.one() + except NoResultFound: + raise exception.FederationNotFound(federation=federation_uuid) def get_federation_by_name(self, context, federation_name): - query = model_query(models.Federation) - query = 
self._add_tenant_filters(context, query) - query = query.filter_by(name=federation_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple federations exist with same ' - 'name. Please use the federation uuid ' - 'instead.') - except NoResultFound: - raise exception.FederationNotFound(federation=federation_name) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(name=federation_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict('Multiple federations exist with same ' + 'name. Please use the federation uuid ' + 'instead.') + except NoResultFound: + raise exception.FederationNotFound(federation=federation_name) def get_federation_list(self, context, limit=None, marker=None, sort_key=None, sort_dir=None, filters=None): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = self._add_federation_filters(query, filters) - return _paginate_query(models.Federation, limit, marker, - sort_key, sort_dir, query) - + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = self._add_federation_filters(query, filters) + return _paginate_query(models.Federation, limit, marker, + sort_key, sort_dir, query) + + @oslo_db_api.retry_on_deadlock def create_federation(self, values): if not values.get('uuid'): values['uuid'] = uuidutils.generate_uuid() federation = models.Federation() federation.update(values) - try: - federation.save() - except db_exc.DBDuplicateEntry: - raise exception.FederationAlreadyExists(uuid=values['uuid']) - return federation + with _session_for_write() as session: + try: + session.add(federation) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.FederationAlreadyExists(uuid=values['uuid']) + return federation + + 
@oslo_db_api.retry_on_deadlock def destroy_federation(self, federation_id): - session = get_session() - with session.begin(): - query = model_query(models.Federation, session=session) + with _session_for_write() as session: + query = session.query(models.Federation) query = add_identity_filter(query, federation_id) try: @@ -774,10 +809,10 @@ def update_federation(self, federation_id, values): return self._do_update_federation(federation_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_federation(self, federation_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Federation, session=session) + with _session_for_write() as session: + query = session.query(models.Federation) query = add_identity_filter(query, federation_id) try: ref = query.with_for_update().one() @@ -807,23 +842,27 @@ def _add_nodegoup_filters(self, query, filters): return query + @oslo_db_api.retry_on_deadlock def create_nodegroup(self, values): if not values.get('uuid'): values['uuid'] = uuidutils.generate_uuid() nodegroup = models.NodeGroup() nodegroup.update(values) - try: - nodegroup.save() - except db_exc.DBDuplicateEntry: - raise exception.NodeGroupAlreadyExists( - cluster_id=values['cluster_id'], name=values['name']) - return nodegroup + with _session_for_write() as session: + try: + session.add(nodegroup) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.NodeGroupAlreadyExists( + cluster_id=values['cluster_id'], name=values['name']) + return nodegroup + + @oslo_db_api.retry_on_deadlock def destroy_nodegroup(self, cluster_id, nodegroup_id): - session = get_session() - with session.begin(): - query = model_query(models.NodeGroup, session=session) + with _session_for_write() as session: + query = session.query(models.NodeGroup) query = add_identity_filter(query, nodegroup_id) query = query.filter_by(cluster_id=cluster_id) try: @@ -835,10 +874,10 @@ def destroy_nodegroup(self, cluster_id, nodegroup_id): def 
update_nodegroup(self, cluster_id, nodegroup_id, values): return self._do_update_nodegroup(cluster_id, nodegroup_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_nodegroup(self, cluster_id, nodegroup_id, values): - session = get_session() - with session.begin(): - query = model_query(models.NodeGroup, session=session) + with _session_for_write() as session: + query = session.query(models.NodeGroup) query = add_identity_filter(query, nodegroup_id) query = query.filter_by(cluster_id=cluster_id) try: @@ -850,56 +889,61 @@ def _do_update_nodegroup(self, cluster_id, nodegroup_id, values): return ref def get_nodegroup_by_id(self, context, cluster_id, nodegroup_id): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(id=nodegroup_id) - try: - return query.one() - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(id=nodegroup_id) + try: + return query.one() + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) def get_nodegroup_by_uuid(self, context, cluster_id, nodegroup_uuid): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(uuid=nodegroup_uuid) - try: - return query.one() - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + 
query = query.filter_by(uuid=nodegroup_uuid) + try: + return query.one() + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) def get_nodegroup_by_name(self, context, cluster_id, nodegroup_name): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(name=nodegroup_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple nodegroups exist with same ' - 'name. Please use the nodegroup uuid ' - 'instead.') - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_name) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(name=nodegroup_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict('Multiple nodegroups exist with same ' + 'name. 
Please use the nodegroup uuid ' + 'instead.') + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_name) def list_cluster_nodegroups(self, context, cluster_id, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = self._add_nodegoup_filters(query, filters) - return _paginate_query(models.NodeGroup, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = self._add_nodegoup_filters(query, filters) + return _paginate_query(models.NodeGroup, limit, marker, + sort_key, sort_dir, query) def get_cluster_nodegroup_count(self, context, cluster_id): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - return query.count() + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + return query.count() diff --git a/magnum/db/sqlalchemy/models.py b/magnum/db/sqlalchemy/models.py index 8d093d6690..f5c494b9ae 100644 --- a/magnum/db/sqlalchemy/models.py +++ b/magnum/db/sqlalchemy/models.py @@ -90,14 +90,6 @@ def as_dict(self): d[c.name] = self[c.name] return d - def save(self, session=None): - import magnum.db.sqlalchemy.api as db_api - - if session is None: - session = db_api.get_session() - - super(MagnumBase, self).save(session) - Base = declarative_base(cls=MagnumBase) diff --git a/magnum/tests/unit/db/base.py b/magnum/tests/unit/db/base.py index 711d30caeb..d78d8fa378 
100644 --- a/magnum/tests/unit/db/base.py +++ b/magnum/tests/unit/db/base.py @@ -16,10 +16,10 @@ """Magnum DB test base class.""" import fixtures +from oslo_db.sqlalchemy import enginefacade import magnum.conf from magnum.db import api as dbapi -from magnum.db.sqlalchemy import api as sqla_api from magnum.db.sqlalchemy import migration from magnum.db.sqlalchemy import models from magnum.tests import base @@ -32,16 +32,15 @@ class Database(fixtures.Fixture): - def __init__(self, db_api, db_migrate, sql_connection): + def __init__(self, engine, db_migrate, sql_connection): self.sql_connection = sql_connection - self.engine = db_api.get_engine() + self.engine = engine self.engine.dispose() - conn = self.engine.connect() - self.setup_sqlite(db_migrate) - self.post_migrations() - - self._DB = "".join(line for line in conn.connection.iterdump()) + with self.engine.connect() as conn: + self.setup_sqlite(db_migrate) + self.post_migrations() + self._DB = "".join(line for line in conn.connection.iterdump()) self.engine.dispose() def setup_sqlite(self, db_migrate): @@ -50,9 +49,10 @@ def setup_sqlite(self, db_migrate): models.Base.metadata.create_all(self.engine) db_migrate.stamp('head') - def _setUp(self): - conn = self.engine.connect() - conn.connection.executescript(self._DB) + def setUp(self): + super(Database, self).setUp() + with self.engine.connect() as conn: + conn.connection.executescript(self._DB) self.addCleanup(self.engine.dispose) def post_migrations(self): @@ -68,6 +68,8 @@ def setUp(self): global _DB_CACHE if not _DB_CACHE: - _DB_CACHE = Database(sqla_api, migration, + engine = enginefacade.writer.get_engine() + _DB_CACHE = Database(engine, migration, sql_connection=CONF.database.connection) + engine.dispose() self.useFixture(_DB_CACHE) diff --git a/magnum/tests/unit/db/sqlalchemy/test_types.py b/magnum/tests/unit/db/sqlalchemy/test_types.py index b9a2c1103a..4e132799c8 100644 --- a/magnum/tests/unit/db/sqlalchemy/test_types.py +++ 
b/magnum/tests/unit/db/sqlalchemy/test_types.py @@ -22,6 +22,7 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase): + # TODO(tylerchristie): these tests need fixing def test_JSONEncodedDict_default_value(self): # Create ClusterTemplate w/o labels cluster_template1_id = uuidutils.generate_uuid()