Skip to content

Commit

Permalink
Fix: history DB `do_limit_num_row_for` did not apply the row limit correctly (now limits per series)
Browse files Browse the repository at this point in the history
  • Loading branch information
visualDust committed Dec 20, 2023
1 parent c0bb43f commit e1be05a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
21 changes: 16 additions & 5 deletions neetbox/server/db/_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from datetime import datetime
from typing import Union

from ._condition import *
from neetbox._protocol import *
from neetbox.utils import ResourceLoader
from neetbox.logging import LogStyle
from neetbox.logging.logger import Logger
from neetbox.utils import ResourceLoader

from ._condition import *

logger = Logger("NEETBOX", LogStyle(skip_writers=["ws"]))

Expand Down Expand Up @@ -243,15 +244,21 @@ def get_series_of_table(self, table_name, run_id=None):
result, _ = self._query(sql_query, *args, fetch=DbQueryFetchType.ALL)
return [result for (result,) in result]

def do_limit_num_row_for(
    self, table_name: str, run_id: str, num_row_limit: int, series: str = None
):
    """Trim the oldest rows of `table_name` so that at most `num_row_limit`
    rows remain for the given `run_id` (and, when given, the given `series`).

    Args:
        table_name: name of the history table to trim.
        run_id: run identifier whose rows are counted and trimmed.
        num_row_limit: maximum number of rows to keep; a non-positive value
            disables trimming entirely.
        series: optional series name; when provided, the limit is applied
            per (run_id, series) pair instead of per run_id only.
    """
    if num_row_limit <= 0:  # non-positive limit means "no limit"
        return
    # Bind run_id/series as query parameters instead of interpolating them
    # into the SQL text: this avoids SQL injection and the quoting bug where
    # an unquoted string run_id produces invalid SQL. Identifiers (table and
    # column names) cannot be parameterized, so they stay in the f-string.
    where_clause = f"{RUN_ID_COLUMN_NAME} = ?"
    args = [run_id]
    if series is not None:
        where_clause += f" AND {SERIES_COLUMN_NAME} = ?"
        args.append(series)
    # Count rows currently stored for this run (and series, if any).
    sql_query = f"SELECT count(*) FROM {table_name} WHERE {where_clause}"
    num_rows, _ = self._query(sql_query, *args, fetch=DbQueryFetchType.ONE)
    if num_rows[0] > num_row_limit:  # row count exceeded the limit
        # Walking ids in descending order, the id at offset (limit - 1) is the
        # oldest row we want to KEEP; everything with a smaller id is deleted.
        sql_query = (
            f"SELECT {ID_COLUMN_NAME} FROM {table_name} WHERE {where_clause} "
            f"ORDER BY {ID_COLUMN_NAME} DESC LIMIT 1 OFFSET {num_row_limit - 1}"
        )
        max_id_to_del, _ = self._query(sql_query, *args, fetch=DbQueryFetchType.ONE)
        sql_query = (
            f"DELETE FROM {table_name} WHERE {ID_COLUMN_NAME} < ? AND {where_clause}"
        )
        # Delete rows older than the cutoff id for this run (and series).
        self._query(sql_query, max_id_to_del[0], *args)

def write_json(
Expand All @@ -278,7 +285,9 @@ def write_json(
if isinstance(json_data, dict):
json_data = json.dumps(json_data)
_, lastrowid = self._execute(sql_query, timestamp, series, run_id, json_data)
self.do_limit_num_row_for(table_name, run_id, num_row_limit)
self.do_limit_num_row_for(
table_name=table_name, run_id=run_id, num_row_limit=num_row_limit, series=series
)
return lastrowid

def read_json(self, table_name: str, condition: QueryCondition = None):
Expand Down Expand Up @@ -366,7 +375,9 @@ def write_blob(

sql_query = f"INSERT INTO {table_name}({TIMESTAMP_COLUMN_NAME}, {SERIES_COLUMN_NAME}, {RUN_ID_COLUMN_NAME}, {METADATA_COLUMN_NAME}, {BLOB_COLUMN_NAME}) VALUES (?, ?, ?, ?, ?)"
_, lastrowid = self._execute(sql_query, timestamp, series, run_id, meta_data, blob_data)
self.do_limit_num_row_for(table_name, run_id, num_row_limit)
self.do_limit_num_row_for(
table_name=table_name, run_id=run_id, num_row_limit=num_row_limit, series=series
)
return lastrowid

def read_blob(self, table_name: str, condition: QueryCondition = None, meta_only=False):
Expand Down
19 changes: 14 additions & 5 deletions tests/client/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,20 @@ def log_with_some_prefix():
logger.err("some error")


# Hyperparameters for the demo training run: 99 epochs of 10 batches each.
# (The scraped diff showed both the old and new assignment; only the
# post-change value is kept.)
train_config = {"epoch": 99, "batch_size": 10}


def train_epoch(config):
    """Drive a demo training session: one progress bar over the epochs,
    each epoch running a fixed number of simulated batches."""

    def run_batches(batch_count):
        # One epoch's worth of work: step through every batch with a
        # one-second pause to simulate computation.
        for batch_idx in neetbox.progress(batch_count):
            time.sleep(1)
            train(batch_idx)

    with neetbox.progress(config["epoch"]) as epoch_bar:
        for _ in epoch_bar:
            run_batches(config["batch_size"])


if __name__ == "__main__":
    # Register the run's hyperparameters with neetbox, then launch the
    # demo training loop. (The scraped diff interleaved the removed inline
    # loop with the new call; only the post-change body is kept.)
    neetbox.add_hyperparams(train_config)
    train_epoch(train_config)

0 comments on commit e1be05a

Please sign in to comment.