Skip to content

Commit

Permalink
Fix: Simplify challenge_state update logic to prevent missing data (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiichi3227 authored Oct 11, 2024
1 parent 55b4aab commit 5832a7c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 108 deletions.
50 changes: 50 additions & 0 deletions migration/202410101010_challenge_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
async def dochange(db, rs):
    """Replace the incremental challenge_state refresh with a per-challenge upsert.

    Drops the ``last_update_time`` table, the ``test.last_modified`` column and
    the index/trigger/functions that served the old incremental refresh, then
    creates an index on ``challenge_state.chal_id`` and the
    ``update_challenge_state(p_chal_id)`` PL/pgSQL function.

    Args:
        db: Database handle exposing an awaitable ``execute(sql)``.
        rs: Unused by this migration; kept for the migration signature.
    """
    # Every drop uses IF EXISTS (and the index IF NOT EXISTS) so a partially
    # applied or repeated run does not abort halfway through.
    await db.execute('DROP TABLE IF EXISTS last_update_time;')
    await db.execute('DROP INDEX IF EXISTS idx_test_last_modified;')
    await db.execute('ALTER TABLE test DROP COLUMN IF EXISTS last_modified')
    await db.execute('DROP TRIGGER IF EXISTS test_last_modified_trigger ON test')
    await db.execute('DROP FUNCTION IF EXISTS update_test_last_modified()')
    await db.execute('DROP FUNCTION IF EXISTS refresh_challenge_state_incremental()')

    await db.execute(
        'CREATE INDEX IF NOT EXISTS challenge_state_idx_chal_id ON challenge_state USING btree (chal_id);'
    )
    # The conflict filter uses IS DISTINCT FROM rather than != because != yields
    # NULL (treated as false) when either side is NULL, which would silently
    # skip the update and leave stale data in challenge_state.
    await db.execute(
        '''
        CREATE OR REPLACE FUNCTION update_challenge_state(p_chal_id INTEGER)
        RETURNS VOID AS $$
        BEGIN
            WITH challenge_summary AS (
                SELECT
                    t.chal_id,
                    MAX(t.state) AS max_state,
                    SUM(t.runtime) AS total_runtime,
                    SUM(t.memory) AS total_memory,
                    SUM(CASE WHEN t.state = 1 THEN tvr.rate ELSE 0 END) AS total_rate
                FROM test t
                LEFT JOIN test_valid_rate tvr ON t.pro_id = tvr.pro_id AND t.test_idx = tvr.test_idx
                WHERE t.chal_id = p_chal_id
                GROUP BY t.chal_id
            )
            INSERT INTO challenge_state (chal_id, state, runtime, memory, rate)
            SELECT
                chal_id,
                max_state,
                total_runtime,
                total_memory,
                total_rate
            FROM challenge_summary
            ON CONFLICT (chal_id) DO UPDATE
            SET
                state = EXCLUDED.state,
                runtime = EXCLUDED.runtime,
                memory = EXCLUDED.memory,
                rate = EXCLUDED.rate
            WHERE
                challenge_state.state IS DISTINCT FROM EXCLUDED.state OR
                challenge_state.runtime IS DISTINCT FROM EXCLUDED.runtime OR
                challenge_state.memory IS DISTINCT FROM EXCLUDED.memory OR
                challenge_state.rate IS DISTINCT FROM EXCLUDED.rate;
            RETURN;
        END;
        $$ LANGUAGE plpgsql;
        ''')
66 changes: 9 additions & 57 deletions src/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
MAX_WAIT_SECONDS_BEFORE_SHUTDOWN = 0


def sig_handler(server, db, rs, pool, cov, view_task, sig, frame):
def sig_handler(server, db, rs, pool, cov, sig, frame):
io_loop = tornado.ioloop.IOLoop.current()

def stop_loop(deadline):
Expand All @@ -35,16 +35,13 @@ def stop_loop(deadline):
print("Waiting for next tick")
io_loop.add_timeout(now + 1, stop_loop, deadline)
else:
view_task.kill()
for task in asyncio.all_tasks():
task.cancel()

io_loop.run_in_executor(func=db.close, executor=None)
io_loop.run_in_executor(func=rs.aclose, executor=None)
io_loop.run_in_executor(func=pool.aclose, executor=None)
io_loop.run_in_executor(
func=JudgeServerClusterService.inst.disconnect_all_server, executor=None
)
io_loop.add_callback(db.close)
io_loop.add_callback(rs.aclose)
io_loop.add_callback(pool.aclose)
io_loop.add_callback(JudgeServerClusterService.inst.disconnect_all_server)
io_loop.stop()

print("Shutdown finally")
Expand All @@ -60,35 +57,6 @@ def shutdown():
print("Caught signal: %s" % sig)
io_loop.add_callback_from_signal(shutdown)


async def materialized_view_task():
    """Refresh challenge_state whenever a newer request arrives over Redis pub/sub.

    Opens a Postgres connection and a Redis client on localhost using the
    ``TestConfig`` settings, subscribes to ``materialized_view_req``, performs
    one refresh up front, and then refreshes again for every message whose
    integer payload is greater than the ticket of the last refresh.
    """
    conn = await asyncpg.connect(
        database=TestConfig.DBNAME_OJ,
        user=TestConfig.DBUSER_OJ,
        password=TestConfig.DBPW_OJ,
        host="localhost",
    )
    redis = await aioredis.Redis(host="localhost", port=6379, db=TestConfig.REDIS_DB)
    pubsub = redis.pubsub()
    await pubsub.subscribe("materialized_view_req")

    async def _refresh():
        # Take the ticket (INCR result minus one) before running the refresh.
        ticket = await redis.incr("materialized_view_counter") - 1
        await conn.execute("SELECT refresh_challenge_state_incremental();")
        return ticket

    last_ticket = await _refresh()
    async for message in pubsub.listen():
        # Only real messages carry a payload; anything at or below the last
        # ticket is ignored.
        if message["type"] == "message" and int(message["data"]) > last_ticket:
            last_ticket = await _refresh()


testing_loop = asyncio.get_event_loop()
if not os.path.exists('db-inited'):
subprocess.run(
Expand Down Expand Up @@ -118,21 +86,7 @@ async def _update():
if __name__ == "__main__":
e = multiprocessing.Event()

def run_materialized_view_task():
    """Run ``materialized_view_task`` on a private event loop until signalled.

    Installs SIGINT/SIGTERM handlers that stop the loop, then drives the task;
    the loop is always stopped and closed on the way out.
    """
    # Create the loop before installing the handlers: the lambdas close over
    # `loop`, so a signal delivered before the assignment would raise NameError.
    loop = asyncio.new_event_loop()
    signal.signal(signal.SIGINT, lambda _, __: loop.stop())
    signal.signal(signal.SIGTERM, lambda _, __: loop.stop())
    try:
        loop.run_until_complete(materialized_view_task())
        loop.run_forever()
    finally:
        loop.stop()
        loop.close()

view_task_process = multiprocessing.Process(target=run_materialized_view_task)

def m(event, view_task):
def m(event):
asyncio.set_event_loop(asyncio.new_event_loop())
cov = coverage.Coverage(data_file=f".coverage.{os.getpid()}", branch=True)
cov.start()
Expand Down Expand Up @@ -167,11 +121,11 @@ def m(event, view_task):

signal.signal(
signal.SIGINT,
functools.partial(sig_handler, httpsrv, db2, rs2, pool2, cov, view_task),
functools.partial(sig_handler, httpsrv, db2, rs2, pool2, cov),
)
signal.signal(
signal.SIGTERM,
functools.partial(sig_handler, httpsrv, db2, rs2, pool2, cov, view_task),
functools.partial(sig_handler, httpsrv, db2, rs2, pool2, cov),
)

try:
Expand All @@ -181,13 +135,11 @@ def m(event, view_task):
pass

asyncio.get_event_loop().run_until_complete(rs.flushall())
view_task_process.start()
main_process = multiprocessing.Process(target=m, args=(e, view_task_process))
main_process = multiprocessing.Process(target=m, args=(e,))
main_process.start()

while e.wait():
services_init(db, rs)
test_main(testing_loop)
view_task_process.terminate()
main_process.terminate()
break
44 changes: 0 additions & 44 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import functools
import signal
import time
from multiprocessing import Process

import asyncpg
import tornado.httpserver
Expand Down Expand Up @@ -31,8 +30,6 @@ def stop_loop(deadline):
print('Waiting for next tick')
io_loop.add_timeout(now + 1, stop_loop, deadline)
else:
view_task_process.kill()

for task in asyncio.all_tasks():
task.cancel()

Expand All @@ -53,50 +50,9 @@ def shutdown():
print('Caught signal: %s' % sig)
io_loop.add_callback_from_signal(shutdown)


async def materialized_view_task():
    """Refresh challenge_state whenever a newer request arrives over Redis pub/sub.

    Opens a Postgres connection and a Redis client on localhost using the
    ``config`` settings, subscribes to ``materialized_view_req``, performs one
    refresh up front, and then refreshes again for every message whose integer
    payload is greater than the ticket of the last refresh.
    """
    conn = await asyncpg.connect(
        database=config.DBNAME_OJ,
        user=config.DBUSER_OJ,
        password=config.DBPW_OJ,
        host='localhost',
    )
    redis = await aioredis.Redis(host='localhost', port=6379, db=config.REDIS_DB)
    pubsub = redis.pubsub()
    await pubsub.subscribe('materialized_view_req')

    async def _refresh():
        # Take the ticket (INCR result minus one) before running the refresh.
        ticket = await redis.incr('materialized_view_counter') - 1
        await conn.execute('SELECT refresh_challenge_state_incremental();')
        return ticket

    last_ticket = await _refresh()
    async for message in pubsub.listen():
        # Only real messages carry a payload; anything at or below the last
        # ticket is ignored.
        if message['type'] == 'message' and int(message['data']) > last_ticket:
            last_ticket = await _refresh()


if __name__ == "__main__":
httpsock = tornado.netutil.bind_sockets(config.PORT)

def run_materialized_view_task():
    """Run ``materialized_view_task`` on a private event loop until signalled.

    Installs SIGINT/SIGTERM handlers that stop the loop, then drives the task;
    the loop is always stopped and closed on the way out.
    """
    # The loop must exist before the handlers are installed and before the
    # try block: the lambdas and the finally clause both reference `loop`, and
    # would raise NameError/UnboundLocalError if it were still unassigned.
    loop = asyncio.new_event_loop()
    signal.signal(signal.SIGINT, lambda _, __: loop.stop())
    signal.signal(signal.SIGTERM, lambda _, __: loop.stop())
    try:
        loop.run_until_complete(materialized_view_task())
        loop.run_forever()
    finally:
        loop.stop()
        loop.close()

view_task_process = Process(target=run_materialized_view_task)
view_task_process.start()

# tornado.process.fork_processes(4)
db: asyncpg.Pool = asyncio.get_event_loop().run_until_complete(
asyncpg.create_pool(database=config.DBNAME_OJ, user=config.DBUSER_OJ, password=config.DBPW_OJ, host='localhost')
Expand Down
12 changes: 7 additions & 5 deletions src/services/chal.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ async def reset_chal(self, chal_id):
async with self.db.acquire() as con:
await con.execute('DELETE FROM "test" WHERE "chal_id" = $1;', chal_id)

await self.rs.publish('materialized_view_req', (await self.rs.get('materialized_view_counter')))

await self.update_challenge_state(chal_id)
return None, None

async def get_chal_state(self, chal_id):
Expand Down Expand Up @@ -331,14 +330,14 @@ async def emit_chal(self, chal_id, pro_id, testm_conf, comp_type, pri: int):
'''
)

await self.rs.publish('materialized_view_req', (await self.rs.get('materialized_view_counter')))
await self.update_challenge_state(chal_id)

file_ext = ChalConst.FILE_EXTENSION[comp_type]

if not os.path.isfile(f"code/{chal_id}/main.{file_ext}"):
for test in testl:
await self.update_test(chal_id, test['test_idx'], ChalConst.STATE_ERR, 0, 0, '', refresh_db=False)
await self.rs.publish('materialized_view_req', (await self.rs.get('materialized_view_counter')))
await self.update_challenge_state(chal_id)
return None, None

chalmeta = testm_conf['chalmeta']
Expand Down Expand Up @@ -498,6 +497,9 @@ async def update_test(self, chal_id, test_idx, state, runtime, memory, response,
)

if refresh_db:
await self.rs.publish('materialized_view_req', (await self.rs.get('materialized_view_counter')))
await self.update_challenge_state(chal_id)

return None, None

async def update_challenge_state(self, chal_id: int):
    """Recompute the challenge_state row for one challenge.

    Calls the ``update_challenge_state`` SQL function through ``self.db``.

    Args:
        chal_id: Identifier of the challenge; coerced with ``int()`` so
            numeric strings keep working as they did with string formatting.
    """
    # Bind the id as a query parameter instead of formatting it into the SQL
    # text, so the value can never be interpreted as SQL.
    await self.db.execute('SELECT update_challenge_state($1);', int(chal_id))
2 changes: 1 addition & 1 deletion src/services/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def response_handle(self, ret):
)

self.running_chal_cnt -= 1
await self.rs.publish('materialized_view_req', (await self.rs.get('materialized_view_counter')))
await ChalService.inst.update_challenge_state(res['chal_id'])

await self.rs.publish('chalstatesub', res['chal_id'])
await self.rs.publish('challiststatesub', res['chal_id'])
Expand Down
1 change: 0 additions & 1 deletion src/services/pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ async def update_test_config(self, pro_id, testm_conf: dict):
await self.db.execute("REFRESH MATERIALIZED VIEW test_valid_rate;")
await self.rs.delete('rate')
await self.rs.hdel('pro_rate', pro_id)
await self.rs.publish('materialized_view_req', (await self.rs.get('materialized_view_counter')))

return None, None

Expand Down

0 comments on commit 5832a7c

Please sign in to comment.