Skip to content

Commit

Permalink
Introduce 'input_selection' trait to interact with editors
Browse files Browse the repository at this point in the history
  • Loading branch information
yakutovicha committed Jan 5, 2023
1 parent 9b6fc31 commit f2d5f19
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 30 deletions.
43 changes: 27 additions & 16 deletions aiidalab_widgets_base/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,19 @@ def _structure_editors(self, editors):
"""Preparing structure editors."""
if editors and len(editors) == 1:
link((editors[0], "structure"), (self, "structure"))

if editors[0].has_trait("input_selection"):
dlink((editors[0], "input_selection"), (self.viewer, "input_selection"))

if editors[0].has_trait("selection"):
link((editors[0], "selection"), (self.viewer, "selection"))
dlink((self.viewer, "selection"), (editors[0], "selection"))

if editors[0].has_trait("camera_orientation"):
dlink(
(self.viewer._viewer, "_camera_orientation"),
(editors[0], "camera_orientation"),
) # pylint: disable=protected-access

return editors[0]

# If more than one editor was defined.
Expand Down Expand Up @@ -942,6 +948,7 @@ class BasicStructureEditor(ipw.VBox): # pylint: disable=too-many-instance-attri
position of periodic structure in cell) editing."""

structure = Instance(Atoms, allow_none=True)
input_selection = List(Int, allow_none=True)
selection = List(Int)
camera_orientation = List()

Expand Down Expand Up @@ -1206,13 +1213,15 @@ def def_point(self, _=None):
"""Define the action point."""
self.point.value = self.vec2str(self.sel2com())
if self.autoclear_selection.value:
self.selection = []
self.input_selection = None
self.input_selection = []

def def_axis_p1(self, _=None):
"""Define the first point of axis."""
self.axis_p1.value = self.vec2str(self.sel2com())
if self.autoclear_selection.value:
self.selection = []
self.input_selection = None
self.input_selection = []

def def_axis_p2(self, _=None):
"""Define the second point of axis."""
Expand All @@ -1230,7 +1239,8 @@ def def_axis_p2(self, _=None):
)
self.axis_p2.value = self.vec2str(com)
if self.autoclear_selection.value:
self.selection = []
self.input_selection = None
self.input_selection = []

def def_perpendicular_to_screen(self, _=None):
"""Define a normalized vector perpendicular to the screen."""
Expand All @@ -1251,7 +1261,7 @@ def translate_dr(self, _=None, atoms=None, selection=None):
self.action_vector * self.displacement.value
)

self.structure, self.selection = atoms, selection
self.structure, self.input_selection = atoms, selection

@_register_structure
@_register_selection
Expand All @@ -1261,7 +1271,7 @@ def translate_dxdydz(self, _=None, atoms=None, selection=None):
# The action.
atoms.positions[self.selection] += np.array(self.str2vec(self.dxyz.value))

self.structure, self.selection = atoms, selection
self.structure, self.input_selection = atoms, selection

@_register_structure
@_register_selection
Expand All @@ -1271,7 +1281,7 @@ def translate_to_xyz(self, _=None, atoms=None, selection=None):
geo_center = np.average(self.structure[self.selection].get_positions(), axis=0)
atoms.positions[self.selection] += self.str2vec(self.dxyz.value) - geo_center

self.structure, self.selection = atoms, selection
self.structure, self.input_selection = atoms, selection

@_register_structure
@_register_selection
Expand All @@ -1283,9 +1293,9 @@ def rotate(self, _=None, atoms=None, selection=None):
vec = self.str2vec(self.vec2str(self.action_vector))
center = self.str2vec(self.point.value)
rotated_subset.rotate(self.phi.value, v=vec, center=center, rotate_cell=False)
atoms.positions[list(self.selection)] = rotated_subset.positions
atoms.positions[self.selection] = rotated_subset.positions

self.structure, self.selection = atoms, selection
self.structure, self.input_selection = atoms, selection

@_register_structure
@_register_selection
Expand Down Expand Up @@ -1318,7 +1328,7 @@ def mirror(self, _=None, norm=None, point=None, atoms=None, selection=None):
# Mirror atoms.
atoms.positions[selection] -= 2 * projections

self.structure, self.selection = atoms, selection
self.structure, self.input_selection = atoms, selection

def mirror_3p(self, _=None):
"""Mirror atoms on the plane containing action vector and action point."""
Expand All @@ -1342,7 +1352,7 @@ def align(self, _=None, atoms=None, selection=None):
subset.rotate(self.action_vector, self.str2vec(self.dxyz.value), center=center)
atoms.positions[selection] = subset.positions

self.structure, self.selection = atoms, selection
self.structure, self.input_selection = atoms, selection

@_register_structure
@_register_selection
Expand Down Expand Up @@ -1373,7 +1383,7 @@ def mod_element(self, _=None, atoms=None, selection=None):
range(last_atom, last_atom + len(selection) * len(lgnd))
)

self.structure, self.selection = atoms, new_selection
self.structure, self.input_selection = atoms, new_selection

@_register_structure
@_register_selection
Expand All @@ -1387,8 +1397,7 @@ def copy_sel(self, _=None, atoms=None, selection=None):
atoms += add_atoms

new_selection = list(range(last_atom, last_atom + len(selection)))

self.structure, self.selection = atoms, new_selection
self.structure, self.input_selection = atoms, new_selection

@_register_structure
@_register_selection
Expand Down Expand Up @@ -1421,12 +1430,14 @@ def add(self, _=None, atoms=None, selection=None):

new_selection = list(range(last_atom, last_atom + len(selection) * len(lgnd)))

self.structure, self.selection = atoms, new_selection
self.structure, self.input_selection = atoms, new_selection

@_register_structure
@_register_selection
def remove(self, _, atoms=None, selection=None):
"""Remove selected atoms."""
del [atoms[selection]]

self.structure, self.selection = atoms, []
self.structure = atoms
self.input_selection = None
self.input_selection = []
35 changes: 21 additions & 14 deletions aiidalab_widgets_base/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class _StructureDataBaseViewer(ipw.VBox):
"""

input_selection = List(Int, allow_none=True)
selection = List(Int)
displayed_selection = List(Int)
supercell = List(Int)
Expand Down Expand Up @@ -663,13 +664,21 @@ def highlight_atoms(
def _default_supercell(self):
return [1, 1, 1]

@default("selection")
def _default_selection(self):
return []
@observe("input_selection")
def _observe_input_selection(self, value):
if value["new"] is None:
return

# Exclude everything that is beyond the total number of atoms.
selection_list = [x for x in value["new"] if x < self.natom]

# In the case of a super cell, we need to multiply the selection as well
multiplier = sum(self.supercell) - 2
selection_list = [
x + self.natom * i for x in selection_list for i in range(multiplier)
]

@validate("selection")
def _validate_selection(self, provided):
return list(provided["value"])
self.displayed_selection = selection_list

@observe("displayed_selection")
def _observe_displayed_selection(self, _=None):
Expand All @@ -684,15 +693,16 @@ def apply_displayed_selection(self, _=None):
self._selected_atoms.value, shift=-1
)
if not syntax_ok:
# advance slection
try:
sel = self.parse_advanced_sel(condition=self._selected_atoms.value)
sel = self._parse_advanced_selection(
condition=self._selected_atoms.value
)
sel = list_to_string_range(sel, shift=1)
expanded_selection, syntax_ok = string_range_to_list(sel, shift=-1)
except (IndexError, TypeError, AttributeError):
syntax_ok = False
self.wrong_syntax.layout.visibility = "visible"
# self.wrong_syntax.layout.visibility = 'hidden' if syntax_ok else 'visible'

if syntax_ok:
self.wrong_syntax.layout.visibility = "hidden"
self.displayed_selection = expanded_selection
Expand Down Expand Up @@ -764,7 +774,6 @@ def __init__(self, structure=None, **kwargs):
super().__init__(**kwargs)
self.structure = structure
self.natom = len(self.structure) if self.structure is not None else 0
# self.supercell.observe(self.repeat, names='value')

@observe("supercell")
def repeat(self, _=None):
Expand Down Expand Up @@ -810,9 +819,7 @@ def _update_structure_viewer(self, change):
comp_id
) in self._viewer._ngl_component_ids: # pylint: disable=protected-access
self._viewer.remove_component(comp_id)
self.displayed_selection = (
[]
) # self.selection will be updated automatically
self.displayed_selection = []
if change["new"] is not None:
self._viewer.add_component(nglview.ASEStructure(change["new"]))
self._viewer.clear()
Expand Down Expand Up @@ -848,7 +855,7 @@ def not_operator(self, operand):
+ "]"
)

def parse_advanced_sel(self, condition=None):
def _parse_advanced_selection(self, condition=None):
"""Apply advanced selection specified in the text field."""

def addition(opa, opb):
Expand Down

0 comments on commit f2d5f19

Please sign in to comment.