Skip to content

Commit

Permalink
Replace 'state_name' dict with lambda function in ddpesgeo
Browse files Browse the repository at this point in the history
Should save a bit of memory
  • Loading branch information
wenzhaojia2000 committed Nov 7, 2023
1 parent ac27d41 commit e940d74
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions source/analyse/analysis_gui/analysis/direct_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,11 @@ def _ddpesgeoV4(self, con:sqlite3.Connection, cur:sqlite3.Cursor):

if self.ddpesgeo_type.currentIndex() == 0:
table = 'pes'
# dictionary mapping states to column names
state_name = {s: f'eng_{s}_{s}' for s in range(1, nroot+1)}
# lambda mapping states to column names
state_name = lambda s: f'eng_{s}_{s}'
else:
table = 'apes'
state_name = {s: f'eng_{s}' for s in range(1, nroot+1)}
state_name = lambda s: f'eng_{s}'

if self.ddpesgeo_task[0].isChecked():
# task is find energies between interval
Expand All @@ -300,7 +300,7 @@ def _ddpesgeoV4(self, con:sqlite3.Connection, cur:sqlite3.Cursor):
# retrieve matching id + energies
for s in range(1, nroot+1):
query = (f'SELECT * FROM {table} LEFT JOIN geo USING(id) '
f'WHERE {state_name[s]} BETWEEN {emin} AND {emax};')
f'WHERE {state_name(s)} BETWEEN {emin} AND {emax};')
res = cur.execute(query).fetchall()
# add id, energies, geo. split geo into geo_length subarrays
# so there are 3 columns
Expand All @@ -322,7 +322,7 @@ def _ddpesgeoV4(self, con:sqlite3.Connection, cur:sqlite3.Cursor):
continue
else:
query = (f'SELECT * FROM {table} LEFT JOIN geo USING(id)'
f'WHERE ABS({state_name[s2]} - {state_name[s1]}) <= {tol};')
f'WHERE ABS({state_name(s2)} - {state_name(s1)}) <= {tol};')
res = cur.execute(query).fetchall()
# add id, energies, geo. split geo into geo_length subarrays
# so there are 3 columns
Expand All @@ -348,7 +348,7 @@ def _ddpesgeoV4(self, con:sqlite3.Connection, cur:sqlite3.Cursor):
# format col_names and energies into a table
# if relevant state, make energy header and value bold
for i, name in enumerate(col_names):
if name in [state_name[state] for state in states]:
if name in [state_name(state) for state in states]:
header.append(f'<b>{name:>15}</b>')
values.append('<b>' + '{: .8e}'.format(energies[i]) + '</b>')
else:
Expand Down

0 comments on commit e940d74

Please sign in to comment.