diff --git a/prysm/_phase.py b/prysm/_phase.py index 029a214b..f340ddbb 100644 --- a/prysm/_phase.py +++ b/prysm/_phase.py @@ -100,17 +100,38 @@ def __init__(self, unit_x, unit_y, phase, phase_unit, spatial_unit, wavelength=N self.unit_y = unit_y self.phase = phase self.wavelength = wavelength - pul, sul = phase_unit.lower(), spatial_unit.lower() - if pul not in self.units: - raise ValueError(f'{pul} not a valid unit, must be in {set(self.units.keys())}') - if sul not in self.units: - raise ValueError(f'{sul} not a valid unit, must be in {set(self.units.keys())}') - self.phase_unit = self.units[phase_unit.lower()] - self.spatial_unit = self.units[spatial_unit.lower()] + self.phase_unit = phase_unit + self.spatial_unit = spatial_unit self.center_x = len(self.unit_x) // 2 self.center_y = len(self.unit_y) // 2 self.sample_spacing = unit_x[1] - unit_x[0] + @property + def phase_unit(self): + return self._phase_unit + + @phase_unit.setter + def phase_unit(self, unit): + unit = unit.lower() + if unit == 'å': + self._phase_unit = unit.upper() + else: + if unit not in self.units: + raise ValueError(f'{unit} not a valid unit, must be in {set(self.units.keys())}') + self._phase_unit = self.units[unit] + + @property + def spatial_unit(self): + return self._spatial_unit + + @spatial_unit.setter + def spatial_unit(self, unit): + unit = unit.lower() + if unit not in self.units: + raise ValueError(f'{unit} not a valid unit, must be in {set(self.units.keys())}') + + self._spatial_unit = self.units[unit] + @property def slice_x(self): """Retrieve a slice through the X axis of the phase. @@ -171,7 +192,7 @@ def change_phase_unit(self, to, inplace=True): new_phase = self.phase / fctr if inplace: self.phase = new_phase - self.phase_unit = self.units[to.lower()] + self.phase_unit = to return self else: return new_phase @@ -204,7 +225,7 @@ def change_spatial_unit(self, to, inplace=True): if inplace: self.unit_x = new_ux self.unit_y = new_uy - self.spatial_unit = self.units[to.lower()] + self.spatial_unit = to self.sample_spacing /= fctr return self else: diff --git a/prysm/psf.py b/prysm/psf.py index 8e8b84bc..eec33d14 100644 --- a/prysm/psf.py +++ b/prysm/psf.py @@ -258,7 +258,7 @@ def plot_encircled_energy(self, axlim=None, npts=50, fig=None, ax=None): elif axlim is 0: raise ValueError('computing from 0 to 0 is stupid') else: - xx = m.linspace(0, axlim, npts) + xx = m.linspace(1e-5, axlim, npts) yy = self.encircled_energy(xx) fig, ax = share_fig_ax(fig, ax)