diff --git a/neetbox/server/db/_history.py b/neetbox/server/db/_history.py index d2007933..580d3eb2 100644 --- a/neetbox/server/db/_history.py +++ b/neetbox/server/db/_history.py @@ -11,11 +11,12 @@ from datetime import datetime from typing import Union -from ._condition import * from neetbox._protocol import * -from neetbox.utils import ResourceLoader from neetbox.logging import LogStyle from neetbox.logging.logger import Logger +from neetbox.utils import ResourceLoader + +from ._condition import * logger = Logger("NEETBOX", LogStyle(skip_writers=["ws"])) @@ -243,15 +244,21 @@ def get_series_of_table(self, table_name, run_id=None): result, _ = self._query(sql_query, *args, fetch=DbQueryFetchType.ALL) return [result for (result,) in result] - def do_limit_num_row_for(self, table_name: str, run_id: str, num_row_limit: int): + def do_limit_num_row_for( + self, table_name: str, run_id: str, num_row_limit: int, series: str = None + ): if num_row_limit <= 0: # no limit or random not triggered return sql_query = f"SELECT count(*) from {table_name} WHERE {RUN_ID_COLUMN_NAME} = {run_id}" # count rows for runid in specific table + if series is not None: + sql_query += f" AND {SERIES_COLUMN_NAME} = '{series}'" num_rows, _ = self._query(sql_query, fetch=DbQueryFetchType.ONE) if num_rows[0] > num_row_limit: # num rows exceeded limit sql_query = f"SELECT {ID_COLUMN_NAME} from {table_name} WHERE {RUN_ID_COLUMN_NAME} = {run_id} ORDER BY {ID_COLUMN_NAME} DESC LIMIT 1 OFFSET {num_row_limit - 1}" # get max id to delete of row for specific run id max_id_to_del, _ = self._query(sql_query, fetch=DbQueryFetchType.ONE) sql_query = f"DELETE FROM {table_name} WHERE {ID_COLUMN_NAME} < {max_id_to_del[0]} AND {RUN_ID_COLUMN_NAME} = {run_id}" + if series is not None: + sql_query += f" AND {SERIES_COLUMN_NAME} = '{series}'" self._query(sql_query) # delete rows with smaller id and specific run id def write_json( @@ -278,7 +285,9 @@ def write_json( if isinstance(json_data, dict): json_data = json.dumps(json_data) _, lastrowid = self._execute(sql_query, timestamp, series, run_id, json_data) - self.do_limit_num_row_for(table_name, run_id, num_row_limit) + self.do_limit_num_row_for( + table_name=table_name, run_id=run_id, num_row_limit=num_row_limit, series=series + ) return lastrowid def read_json(self, table_name: str, condition: QueryCondition = None): @@ -366,7 +375,9 @@ def write_blob( sql_query = f"INSERT INTO {table_name}({TIMESTAMP_COLUMN_NAME}, {SERIES_COLUMN_NAME}, {RUN_ID_COLUMN_NAME}, {METADATA_COLUMN_NAME}, {BLOB_COLUMN_NAME}) VALUES (?, ?, ?, ?, ?)" _, lastrowid = self._execute(sql_query, timestamp, series, run_id, meta_data, blob_data) - self.do_limit_num_row_for(table_name, run_id, num_row_limit) + self.do_limit_num_row_for( + table_name=table_name, run_id=run_id, num_row_limit=num_row_limit, series=series + ) return lastrowid def read_blob(self, table_name: str, condition: QueryCondition = None, meta_only=False): diff --git a/tests/client/test.py b/tests/client/test.py index d04f4a2c..85ccdc77 100644 --- a/tests/client/test.py +++ b/tests/client/test.py @@ -132,11 +132,20 @@ def log_with_some_prefix(): logger.err("some error") -train_config = {"epoch": 9999} +train_config = {"epoch": 99, "batch_size": 10} + + +def train_epoch(config): + def train_batch_in_epoch(num_batch): + for i in neetbox.progress(num_batch): + time.sleep(1) + train(i) + + with neetbox.progress(config["epoch"]) as progress: + for _ in progress: + train_batch_in_epoch(config["batch_size"]) + if __name__ == "__main__": neetbox.add_hyperparams(train_config) - with neetbox.progress(train_config["epoch"]) as epochs: - for i in epochs: - sleep(1) - train(i) + train_epoch(train_config)