Skip to content

Commit

Permalink
[db] Fix DB opening and closing
Browse files Browse the repository at this point in the history
This PR fixes issues that could occur when accessing the DB using
multiple processes:
- When opening the DB, check if the DB already exists.
- When closing the DB, properly close the connection to the DB
  engine.

Moreover, also check if the user accidentally did not provide the
DB file path with the .db extension. If this happens, simply add
it to the path.

Signed-off-by: Pascal Nasahl <[email protected]>
  • Loading branch information
nasahlpa committed Jan 16, 2024
1 parent 6cf6140 commit 1cdcc67
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
24 changes: 19 additions & 5 deletions capture/project_library/ot_trace_library/trace_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,16 @@ class TraceLibrary:
"""
def __init__(self, db_name, trace_threshold, wave_datatype = np.uint16,
overwrite = False):
# If .db extension is not provided, add it to the file.
if not db_name.endswith(".db"):
db_name = db_name + ".db"
# If overwrite flag is set, delete existing DB.
if overwrite:
trace_lib_file = Path(db_name + ".db")
trace_lib_file.unlink(missing_ok=True)
self.engine = db.create_engine("sqlite:///" + db_name + ".db")
db_utils.create_database(self.engine.url)
Path(db_name).unlink(missing_ok=True)
# Create or open database.
self.engine = db.create_engine("sqlite:///" + db_name)
if not db_utils.database_exists(self.engine.url):
db_utils.create_database(self.engine.url)
self.session = sessionmaker(self.engine)()
self.metadata = db.MetaData()
self.traces_table = db.Table(
Expand All @@ -71,11 +76,20 @@ def __init__(self, db_name, trace_threshold, wave_datatype = np.uint16,
self.metadata,
db.Column("data", db.PickleType)
)
self.metadata.create_all(self.engine)
self.metadata.create_all(self.engine, checkfirst=True)
self.trace_mem = []
self.trace_mem_thr = trace_threshold
self.wave_datatype = wave_datatype

def close(self, save: bool):
""" Close database.
Args:
save: Save data to database.
"""
if save:
self.flush_to_disk()
self.engine.dispose()

def flush_to_disk(self):
""" Writes traces from memory into database.
"""
Expand Down
2 changes: 1 addition & 1 deletion capture/project_library/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def close(self, save: bool) -> None:
if self.project_cfg.type == "cw":
self.project.close(save = save)
elif self.project_cfg.type == "ot_trace_library":
self.project.flush_to_disk()
self.project.close(save = save)

self.project = None

Expand Down

0 comments on commit 1cdcc67

Please sign in to comment.