diff --git a/capture/project_library/ot_trace_library/trace_library.py b/capture/project_library/ot_trace_library/trace_library.py index 4c10c89b..c910fd73 100644 --- a/capture/project_library/ot_trace_library/trace_library.py +++ b/capture/project_library/ot_trace_library/trace_library.py @@ -47,11 +47,16 @@ class TraceLibrary: """ def __init__(self, db_name, trace_threshold, wave_datatype = np.uint16, overwrite = False): + # If .db extension is not provided, add it to the file. + if not db_name.endswith(".db"): + db_name = db_name + ".db" + # If overwrite flag is set, delete existing DB. if overwrite: - trace_lib_file = Path(db_name + ".db") - trace_lib_file.unlink(missing_ok=True) - self.engine = db.create_engine("sqlite:///" + db_name + ".db") - db_utils.create_database(self.engine.url) + Path(db_name).unlink(missing_ok=True) + # Create or open database. + self.engine = db.create_engine("sqlite:///" + db_name) + if not db_utils.database_exists(self.engine.url): + db_utils.create_database(self.engine.url) self.session = sessionmaker(self.engine)() self.metadata = db.MetaData() self.traces_table = db.Table( @@ -71,11 +76,20 @@ def __init__(self, db_name, trace_threshold, wave_datatype = np.uint16, self.metadata, db.Column("data", db.PickleType) ) - self.metadata.create_all(self.engine) + self.metadata.create_all(self.engine, checkfirst=True) self.trace_mem = [] self.trace_mem_thr = trace_threshold self.wave_datatype = wave_datatype + def close(self, save: bool): + """ Close database. + Args: + save: Save data to database. + """ + if save: + self.flush_to_disk() + self.engine.dispose() + def flush_to_disk(self): """ Writes traces from memory into database. """ diff --git a/capture/project_library/project.py b/capture/project_library/project.py index 9b35bd44..2866268c 100644 --- a/capture/project_library/project.py +++ b/capture/project_library/project.py @@ -72,7 +72,7 @@ def close(self, save: bool) -> None: if self.project_cfg.type == "cw": self.project.close(save = save) elif self.project_cfg.type == "ot_trace_library": - self.project.flush_to_disk() + self.project.close(save = save) self.project = None