Skip to content

Commit

Permalink
RF: Refactor keyword-only code and Add keyword-only arguments to func…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
WassCodeur committed Jun 5, 2024
1 parent 8e5c0b0 commit 0126a15
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 89 deletions.
156 changes: 76 additions & 80 deletions fury/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,107 +51,103 @@ def doctest_skip_parser(func):


def keyword_only(func):
"""A decorator to enforce keyword-only arguments.
"""Decorator to enforce keyword-only arguments.
This decorator is used to enforce that certain arguments of a function
are passed as keyword arguments. This is useful to prevent users from
passing arguments in the wrong order.
This decorator enforces that all arguments after the first one are
keyword-only arguments. It also checks that all keyword arguments are
expected by the function.
Parameters
----------
func : callable
The function to decorate.
Parameters:
-----------
func: function
Function to be decorated.
Returns
-------
callable
The decorated function.
Examples
Returns:
--------
wrapper: function
Decorated function.
Examples:
---------
>>> @keyword_only
... def add(*, a, b):
... return a + b
>>> add(a=1, b=2)
3
>>> add(b=2, a=1, c=3)
... def f(a, b, *, c, d=1, e=1):
... return a + b + c + d + e
>>> f(1, 2, 3, 4, 5)
15
>>> f(1, 2, c=3, d=4, e=5)
15
>>> f(1, 2, 2, 4, e=5)
14
>>> f(1, 2, c=3, d=4)
11
>>> f(1, 2, d=3, e=5)
Traceback (most recent call last):
...
TypeError: f() missing 1 required keyword-only argument: 'c'
>>> f(1, 2, c=3, d=4, e=5, f=6)
Traceback (most recent call last):
...
TypeError: add() got an unexpected keyword arguments: c
Usage: add(a=[your_value], b=[your_value])
Please Provide keyword-only arguments: a=[your_value], b=[your_value]
>>> add(1, 2)
TypeError: f() got an unexpected keyword argument 'f'
>>> f(1, c=3, d=4, e=5)
Traceback (most recent call last):
...
TypeError: add() takes 0 positional arguments but 2 were given
Usage: add(a=[your_value], b=[your_value])
Please Provide keyword-only arguments: a=[your_value], b=[your_value]
>>> add(a=1)
TypeError: f() missing 1 required positional argument: 'b'
>>> f(1, 2, 3, 4, 5, 6)
Traceback (most recent call last):
...
TypeError: add() missing 1 required keyword-only arguments: b
Usage: add(a=[your_value], b=[your_value])
Please Provide keyword-only arguments: a=[your_value], b=[your_value]
TypeError: f() takes 2 positional arguments but 6 were given
"""

@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
params = sig.parameters
missing_params = [
# args_names = [param.name for param in params.values()]
KEYWORD_ONLY_ARGS = [
arg.name for arg in params.values() if arg.kind == arg.KEYWORD_ONLY
]
POSITIONAL_ARGS = [
arg.name
for arg in params.values()
if arg.name not in kwargs and arg.kind == arg.KEYWORD_ONLY
if arg.kind in (arg.POSITIONAL_OR_KEYWORD, arg.POSITIONAL_ONLY)
]
params_sample = [
f"{arg}=[your_value]"
for arg in params.values()
if arg.kind == arg.KEYWORD_ONLY
missing_kwargs = [
arg
for arg in KEYWORD_ONLY_ARGS
if arg not in kwargs and params[arg].default == params[arg].empty
]
params_sample_str = ", ".join(params_sample)
unexpected_params_list = [arg for arg in kwargs if arg not in params]
unexpected_params = ", ".join(unexpected_params_list)
if args:
raise TypeError(
(
"{}() takes 0 positional arguments but {} were given\n"
"Usage: {}({})\n"
"Please Provide keyword-only arguments: {}"
).format(
func.__name__,
len(args),
func.__name__,
params_sample_str,
params_sample_str,
)
)
else:
if unexpected_params:
raise TypeError(
"{}() got an unexpected keyword arguments: {}\n"
"Usage: {}({})\n"
"Please Provide keyword-only arguments: {}".format(
func.__name__,
unexpected_params,
func.__name__,
params_sample_str,
params_sample_str,
)
)

elif missing_params:
raise TypeError(
"{}() missing {} required keyword-only arguments: {}\n"
"Usage: {}({})\n"
"Please Provide keyword-only arguments: {}".format(
func.__name__,
len(missing_params),
", ".join(missing_params),
func.__name__,
params_sample_str,
params_sample_str,
)
)
ARG_DEFAULT = [
arg
for arg in KEYWORD_ONLY_ARGS
if arg not in kwargs and params[arg].default != params[arg].empty
]
func_params_sample = []
for arg in params.values():
if arg.kind in (arg.POSITIONAL_OR_KEYWORD, arg.POSITIONAL_ONLY):
func_params_sample.append(f"{arg.name}_value")
elif arg.kind == arg.KEYWORD_ONLY:
func_params_sample.append(f"{arg.name}='value'")
func_params_sample = ", ".join(func_params_sample)
args_kwargs_len = len(args) + len(kwargs)
params_len = len(params)
try:
return func(*args, **kwargs)
except Exception:
if ARG_DEFAULT:
missing_kwargs += ARG_DEFAULT
if missing_kwargs and params_len >= args_kwargs_len:
positional_args_len = len(POSITIONAL_ARGS)
args_k = list(args[positional_args_len:])
args = list(args[:positional_args_len])
kwargs.update(dict(zip(missing_kwargs, args_k)))
warn(
"Here's how to call the Function {}: {}({})".format(
func.__name__, func.__name__, func_params_sample
),
UserWarning,
stacklevel=3,
)
result = func(*args, **kwargs)
return result

return wrapper
27 changes: 25 additions & 2 deletions fury/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from fury.decorators import keyword_only
from fury.decorators import keyword_only
from fury.lib import (
Command,
Expand Down Expand Up @@ -118,6 +119,12 @@ def get_prop_at_event_position(self):
# TODO: return a list of items (i.e. each level of the assembly path).
event_pos = self.GetInteractor().GetEventPosition()

self.picker.Pick(
event_pos[0],
event_pos[1],
0,
self.GetCurrentRenderer(),
)
self.picker.Pick(
event_pos[0],
event_pos[1],
Expand Down Expand Up @@ -179,6 +186,8 @@ def _process_event(self, obj, evt):

self.event.reset() # Event fully processed.

@keyword_only
def _button_clicked(self, button, *, last_event=-1, before_last_event=-2):
@keyword_only
def _button_clicked(self, button, *, last_event=-1, before_last_event=-2):
if len(self.history) < abs(before_last_event):
Expand All @@ -187,6 +196,7 @@ def _button_clicked(self, button, *, last_event=-1, before_last_event=-2):
if self.history[last_event]["event"] != button + "ButtonReleaseEvent":
return False

if (self.history[before_last_event]["event"]) != (button + "ButtonPressEvent"): # noqa
if (self.history[before_last_event]["event"]) != (button + "ButtonPressEvent"): # noqa
return False

Expand All @@ -198,8 +208,8 @@ def _button_double_clicked(self, button):
and (
self._button_clicked(
button,
last_event=-3,
before_last_event=-4,
-3,
-4,
)
)
):
Expand Down Expand Up @@ -399,6 +409,16 @@ def force_render(self):
"""Causes the scene to refresh."""
self.GetInteractor().GetRenderWindow().Render()

@keyword_only
def add_callback(
self,
prop,
event_type,
callback,
*,
priority=0,
args=None,
):
@keyword_only
def add_callback(
self,
Expand Down Expand Up @@ -440,6 +460,9 @@ def _callback(_obj, event_name):
self.event2id[event_type] = (Command.UserEvent) + (
len(self.event2id) + 1
)
self.event2id[event_type] = (Command.UserEvent) + (
len(self.event2id) + 1
)

event_type = self.event2id[event_type]

Expand Down
30 changes: 24 additions & 6 deletions fury/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def load_image(filename, *, as_vtktype=False, use_pillow=True):
desired image array
"""
is_url = filename.lower().startswith("http://") or filename.lower().startswith(
"https://"
is_url = (filename.lower().startswith("http://")) or (
filename.lower().startswith("https://")
)

if is_url:
Expand Down Expand Up @@ -138,7 +138,14 @@ def load_image(filename, *, as_vtktype=False, use_pillow=True):

# width, height
vtk_image.SetDimensions(image.shape[1], image.shape[0], depth)
vtk_image.SetExtent(0, image.shape[1] - 1, 0, image.shape[0] - 1, 0, 0)
vtk_image.SetExtent(
0,
image.shape[1] - 1,
0,
image.shape[0] - 1,
0,
0,
)
vtk_image.SetSpacing(1.0, 1.0, 1.0)
vtk_image.SetOrigin(0.0, 0.0, 0.0)

Expand Down Expand Up @@ -312,7 +319,13 @@ def save_image(
writer.SetQuality(compression_quality)
if extension.lower() in [".tif", ".tiff"]:
compression_type = compression_type or "nocompression"
l_compression = ["nocompression", "packbits", "jpeg", "deflate", "lzw"]
l_compression = [
"nocompression",
"packbits",
"jpeg",
"deflate",
"lzw",
]

if compression_type.lower() in l_compression:
comp_id = l_compression.index(compression_type.lower())
Expand Down Expand Up @@ -457,13 +470,18 @@ def load_sprite_sheet(sheet_path, nb_rows, nb_cols, *, as_vtktype=False):
nxt_col * sprite_size_y,
)

sprite_arr = sprite_sheet[box[0] : box[2], box[1] : box[3]]
sprite_arr = sprite_sheet[
box[0] : box[2], box[1] : box[3] # noqa: E203
]
if as_vtktype:
with InTemporaryDirectory() as tdir:
tmp_img_path = os.path.join(tdir, f"{row}{col}.png")
save_image(sprite_arr, tmp_img_path, compression_quality=100)

sprite_dicts[(row, col)] = load_image(tmp_img_path, as_vtktype=True)
sprite_dicts[(row, col)] = load_image(
tmp_img_path,
as_vtktype=True,
)
else:
sprite_dicts[(row, col)] = sprite_arr

Expand Down
2 changes: 1 addition & 1 deletion fury/molecular.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ def ribbon(molecule):

for i in range(num_total_atoms):
radii[i] = np.repeat(
table.atomic_radius(all_atomic_numbers[i], radius_type="VDW"),
table.atomic_radius(all_atomic_numbers[i], "VDW"),
3,
)
rgb[i] = table.atom_color(all_atomic_numbers[i])
Expand Down

0 comments on commit 0126a15

Please sign in to comment.