Skip to content

Commit

Permalink
Mem (1/3): Core DAL improvements
Browse files Browse the repository at this point in the history
- Improve type hints
  • Loading branch information
cjao committed Sep 25, 2023
1 parent f792432 commit cb8ae1d
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 76 deletions.
26 changes: 18 additions & 8 deletions covalent_dispatcher/_dal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

"""Base class for server-side analogues of workflow data types"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Union
from abc import abstractmethod
from typing import Any, Dict, Generator, Generic, List, Type, TypeVar, Union

from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
Expand All @@ -30,8 +30,14 @@
from . import controller
from .asset import FIELDS, Asset

# Metadata
MetaType = TypeVar("MetaType", bound=controller.Record)

class DispatchedObject(ABC):
# Asset links
AssetLinkType = TypeVar("AssetLinkType", bound=controller.Record)


class DispatchedObject(Generic[MetaType, AssetLinkType]):
"""Base class for types with both metadata and assets.
Each subclass must define two properties:
Expand All @@ -42,13 +48,13 @@ class DispatchedObject(ABC):

@classmethod
@property
def meta_type(cls) -> type(controller.Record):
def meta_type(cls) -> Type[MetaType]:
"""Returns the metadata controller class."""
raise NotImplementedError

@classmethod
@property
def asset_link_type(cls) -> type(controller.Record):
def asset_link_type(cls) -> Type[AssetLinkType]:
"""Returns the asset link controller class"""
raise NotImplementedError

Expand All @@ -64,7 +70,7 @@ def query_keys(self) -> set:

@property
@abstractmethod
def metadata(self) -> controller.Record:
def metadata(self) -> MetaType:
raise NotImplementedError

@property
Expand All @@ -85,13 +91,17 @@ def get_asset_ids(self, session: Session, keys: List[str]) -> Dict[str, int]:
)
return {x.key: x.asset_id for x in records}

def associate_asset(self, session: Session, key: str, asset_id: int):
def associate_asset(
self, session: Session, key: str, asset_id: int, flush: bool = False
) -> AssetLinkType:
asset_link_kwargs = {
"meta_id": self._id,
"asset_id": asset_id,
"key": key,
}
type(self).asset_link_type.insert(session, insert_kwargs=asset_link_kwargs, flush=False)
return type(self).asset_link_type.create(
session, insert_kwargs=asset_link_kwargs, flush=flush
)

@property
@abstractmethod
Expand Down
13 changes: 7 additions & 6 deletions covalent_dispatcher/_dal/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import annotations

from typing import Generic, TypeVar
from typing import Generic, Type, TypeVar

from sqlalchemy import select, update
from sqlalchemy.orm import Session, load_only
Expand All @@ -38,7 +38,7 @@ class Record(Generic[T]):

@classmethod
@property
def model(cls) -> type(T):
def model(cls) -> Type[T]:
raise NotImplementedError

def __init__(self, session: Session, record: models.Base, *, fields: list):
Expand Down Expand Up @@ -89,16 +89,17 @@ def get_by_primary_key(
return session.get(cls.model, primary_key, with_for_update=for_update)

@classmethod
def insert(cls, session: Session, *, insert_kwargs: dict, flush: bool = True) -> T:
"""INSERT a record into the DB.
def create(cls, session: Session, *, insert_kwargs: dict, flush: bool = True) -> T:
"""Create a new record.
Args:
session: SQLalchemy session
insert_kwargs: kwargs to pass to the model constructor
flush: Whether to flush the session immediately
Returns:
The bound record
Returns: A SQLAlchemy model of type T. If `flush=False`, the
model will need to be added to the session manually.
"""

new_record = cls.model(**insert_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion covalent_dispatcher/_dal/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ElectronAsset(Record[models.ElectronAsset]):
model = models.ElectronAsset


class Electron(DispatchedObject):
class Electron(DispatchedObject[ElectronMeta, ElectronAsset]):
meta_type = ElectronMeta
asset_link_type = ElectronAsset

Expand Down
6 changes: 3 additions & 3 deletions covalent_dispatcher/_dal/importers/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def import_electron(
node_storage_path = asset_recs["function"].storage_path

electron_kwargs = _get_electron_meta(e, lat, node_storage_path, job_id)
electron_row = ElectronMeta.insert(session, insert_kwargs=electron_kwargs, flush=False)
electron_row = ElectronMeta.create(session, insert_kwargs=electron_kwargs, flush=False)

return (
electron_row,
Expand Down Expand Up @@ -158,7 +158,7 @@ def import_electron_assets(
"remote_uri": asset.uri,
"size": asset.size,
}
asset_recs[asset_key] = Asset.insert(session, insert_kwargs=asset_kwargs, flush=False)
asset_recs[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False)

# Send this back to the client
asset.digest = None
Expand All @@ -179,7 +179,7 @@ def import_electron_assets(
"remote_uri": asset.uri,
"size": asset.size,
}
asset_recs[asset_key] = Asset.insert(session, insert_kwargs=asset_kwargs, flush=False)
asset_recs[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False)

# Send this back to the client
asset.remote_uri = f"file://{local_uri}" if asset.digest else ""
Expand Down
13 changes: 10 additions & 3 deletions covalent_dispatcher/_dal/importers/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def import_lattice_assets(
"remote_uri": asset.uri,
"size": asset.size,
}
asset_ids[asset_key] = Asset.insert(session, insert_kwargs=asset_kwargs, flush=False)
asset_ids[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False)

# Send this back to the client
asset.digest = None
Expand All @@ -141,15 +141,22 @@ def import_lattice_assets(
"remote_uri": asset.uri,
"size": asset.size,
}
asset_ids[asset_key] = Asset.insert(session, insert_kwargs=asset_kwargs, flush=False)
asset_ids[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False)

# Send this back to the client
asset.remote_uri = f"file://{local_uri}" if asset.digest else ""
asset.digest = None

session.flush()

# Write asset records to DB
session.flush()

# Link assets to lattice
lattice_asset_links = []
for key, asset_rec in asset_ids.items():
record.associate_asset(session, key, asset_rec.id)
lattice_asset_links.append(record.associate_asset(session, key, asset_rec.id))

session.flush()

return lat.assets
103 changes: 67 additions & 36 deletions covalent_dispatcher/_dal/importers/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""Functions to transform ResultSchema -> Result"""

import os
from datetime import datetime
from typing import List, Optional, Tuple

from sqlalchemy.orm import Session
Expand Down Expand Up @@ -69,7 +70,7 @@ def import_result(
membership_filters={},
)
if len(records) > 0:
return _connect_result_to_electron(res, electron_id)
return _connect_result_to_electron(session, res, electron_id)

# Main case: insert new lattice, electron, edge, and job records

Expand All @@ -80,7 +81,8 @@ def import_result(
lattice_record_kwargs.update(_get_lattice_meta(res.lattice, storage_path))

with Result.session() as session:
lattice_row = ResultMeta.insert(session, insert_kwargs=lattice_record_kwargs, flush=True)
st = datetime.now()
lattice_row = ResultMeta.create(session, insert_kwargs=lattice_record_kwargs, flush=True)
res_record = Result(session, lattice_row, True)
res_assets = import_result_assets(session, res, res_record, local_store)

Expand All @@ -91,7 +93,11 @@ def import_result(
res_record.lattice,
local_store,
)
et = datetime.now()
delta = (et - st).total_seconds()
app_log.debug(f"{dispatch_id}: Inserting lattice took {delta} seconds")

st = datetime.now()
tg = import_transport_graph(
session,
dispatch_id,
Expand All @@ -100,54 +106,64 @@ def import_result(
local_store,
electron_id,
)
et = datetime.now()
delta = (et - st).total_seconds()
app_log.debug(f"{dispatch_id}: Inserting transport graph took {delta} seconds")

lat = LatticeSchema(metadata=res.lattice.metadata, assets=lat_assets, transport_graph=tg)

output = ResultSchema(metadata=res.metadata, assets=res_assets, lattice=lat)
return _filter_remote_uris(output)
st = datetime.now()
filtered_uris = _filter_remote_uris(output)
et = datetime.now()
delta = (et - st).total_seconds()
app_log.debug(f"{dispatch_id}: Filtering URIs took {delta} seconds")
return filtered_uris


def _connect_result_to_electron(res: ResultSchema, parent_electron_id: int) -> ResultSchema:
def _connect_result_to_electron(
session: Session, res: ResultSchema, parent_electron_id: int
) -> ResultSchema:
"""Link a sublattice dispatch to its parent electron"""

# Update the `electron_id` lattice field and propagate the
# `Job.cancel_requested` to the sublattice dispatch's jobs.

app_log.debug("connecting previously submitted subdispatch to parent electron")
sub_result = Result.from_dispatch_id(res.metadata.dispatch_id, bare=True)
with Result.session() as session:
sub_result.set_value("electron_id", parent_electron_id, session)
sub_result.set_value("root_dispatch_id", res.metadata.root_dispatch_id, session)

parent_electron_record = ElectronMeta.get(
session,
fields={"id", "parent_lattice_id", "job_id"},
equality_filters={"id": parent_electron_id},
membership_filters={},
)[0]
parent_job_record = Job.get(
session,
fields={"id", "cancel_requested"},
equality_filters={"id": parent_electron_record.job_id},
membership_filters={},
)[0]
cancel_requested = parent_job_record.cancel_requested

sub_electron_records = ElectronMeta.get(
session,
fields={"id", "parent_lattice_id", "job_id"},
equality_filters={"parent_lattice_id": sub_result._lattice_id},
membership_filters={},
)
sub_result.set_value("electron_id", parent_electron_id, session)
sub_result.set_value("root_dispatch_id", res.metadata.root_dispatch_id, session)

parent_electron_record = ElectronMeta.get(
session,
fields={"id", "parent_lattice_id", "job_id"},
equality_filters={"id": parent_electron_id},
membership_filters={},
)[0]
parent_job_record = Job.get(
session,
fields={"id", "cancel_requested"},
equality_filters={"id": parent_electron_record.job_id},
membership_filters={},
)[0]
cancel_requested = parent_job_record.cancel_requested

sub_electron_records = ElectronMeta.get(
session,
fields={"id", "parent_lattice_id", "job_id"},
equality_filters={"parent_lattice_id": sub_result._lattice_id},
membership_filters={},
)

job_ids = [rec.job_id for rec in sub_electron_records]
job_ids = [rec.job_id for rec in sub_electron_records]

Job.update_bulk(
session,
values={"cancel_requested": cancel_requested},
equality_filters={},
membership_filters={"id": job_ids},
)
Job.update_bulk(
session,
values={"cancel_requested": cancel_requested},
equality_filters={},
membership_filters={"id": job_ids},
)

return res

Expand Down Expand Up @@ -237,16 +253,31 @@ def import_result_assets(
"remote_uri": asset.uri,
"size": asset.size,
}
asset_ids[asset_key] = Asset.insert(session, insert_kwargs=asset_kwargs, flush=False)
asset_ids[asset_key] = Asset.create(session, insert_kwargs=asset_kwargs, flush=False)

# Send this back to the client
asset.digest = None
asset.remote_uri = f"file://{local_uri}"

# Write asset records to DB
n_records = len(asset_ids)

st = datetime.now()
session.flush()
et = datetime.now()
delta = (et - st).total_seconds()
app_log.debug(f"Inserting {n_records} asset records took {delta} seconds")

result_asset_links = []
for key, asset_rec in asset_ids.items():
record.associate_asset(session, key, asset_rec.id)
result_asset_links.append(record.associate_asset(session, key, asset_rec.id))

n_records = len(result_asset_links)
st = datetime.now()
session.flush()
et = datetime.now()
delta = (et - st).total_seconds()
app_log.debug(f"Inserting {n_records} asset links took {delta} seconds")

return manifest.assets

Expand Down
Loading

0 comments on commit cb8ae1d

Please sign in to comment.