Skip to content

Commit

Permalink
ENH: log_breaks
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Oct 10, 2018
1 parent 5dae986 commit 385b9e8
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 15 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Enhancements
- Better support for handling missing values when training discrete
scales.

- Changed the algorithm for :class:`~mizani.breaks.log_breaks`; it can
now return breaks that do not fall on the integer powers of the base.

v0.4.6
------
*(2018-03-20)*
Expand Down
108 changes: 95 additions & 13 deletions mizani/breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from matplotlib.ticker import MaxNLocator

from .utils import min_max, SECONDS, NANOSECONDS
from .utils import same_log10_order_of_magnitude


__all__ = ['mpl_breaks', 'log_breaks', 'minor_breaks',
Expand Down Expand Up @@ -105,10 +104,10 @@ class log_breaks:
Examples
--------
>>> x = np.logspace(3, 7)
>>> x = np.logspace(3, 6)
>>> limits = min(x), max(x)
>>> log_breaks()(limits)
array([ 100, 10000, 1000000])
array([ 100, 1000, 10000, 100000, 1000000])
>>> log_breaks(2)(limits)
array([ 100, 100000])
"""
Expand All @@ -131,26 +130,109 @@ def __call__(self, limits):
out : array_like
Sequence of breaks points
"""
n = self.n
base = self.base

if any(np.isinf(limits)):
return []

n = self.n
base = self.base
rng = np.log(limits)/np.log(base)

if base == 10 and same_log10_order_of_magnitude(rng):
return extended_breaks(n=4)(limits)

_min = int(np.floor(rng[0]))
_max = int(np.ceil(rng[1]))

if _max == _min:
return base ** _min

step = (_max-_min)//n + 1
dtype = float if (_min < 0) else int
return base ** np.arange(_min, _max+1, step, dtype=dtype)
# Try getting breaks at the integer powers of the base
# e.g [1, 100, 10000, 1000000]
# If there are too few breaks, try other points using the
# _log_sub_breaks
by = int(np.floor((_max-_min)/n)) + 1
for step in range(by, 0, -1):
breaks = base ** np.arange(_min, _max+1, step=step)
relevant_breaks = (
(limits[0] <= breaks) &
(breaks <= limits[1])
)
if np.sum(relevant_breaks) >= n-2:
return breaks

return _log_sub_breaks(n=n, base=base)(limits)


class _log_sub_breaks:
"""
Breaks for log transformed scales
Calculate breaks that do not fall on integer powers of
the base.
Parameters
----------
n : int
Desired number of breaks
base : int
Base of logarithm
Notes
-----
Credit: Thierry Onkelinx ([email protected]) for the original
algorithm in the r-scales package.
"""

def __init__(self, n=5, base=10):
self.n = n
self.base = base

def __call__(self, limits):
base = self.base
n = self.n
rng = np.log(limits)/np.log(base)
_min = int(np.floor(rng[0]))
_max = int(np.ceil(rng[1]))
steps = [1]

def delta(x):
"""
Calculates the smallest distance in the log scale between the
currectly selected breaks and a new candidate 'x'
"""
arr = np.sort(np.hstack([x, steps, base]))
if base == 10:
log_arr = np.log10(arr)
else:
log_arr = np.log(arr) / np.log(base)
return np.min(np.diff(log_arr))

if self.base == 2:
return base ** np.arange(_min, _max+1)

candidate = np.arange(base+1)
candidate = np.compress(
(1 < candidate) & (candidate < base), candidate)

while len(candidate):
best = np.argmax([delta(x) for x in candidate])
steps.append(candidate[best])
candidate = np.delete(candidate, best)

breaks = np.outer(
base ** np.arange(_min, _max+1), steps).ravel()
relevant_breaks = (
(limits[0] <= breaks) & (breaks <= limits[1]))

if np.sum(relevant_breaks) >= n-2:
breaks = np.sort(breaks)
lower_end = np.max([
np.min(np.where(limits[0] <= breaks))-1,
0
])
upper_end = np.min([
np.max(np.where(breaks <= limits[1]))+1,
len(breaks)
])
return breaks[lower_end:upper_end+1]
else:
return extended_breaks(n=n)(limits)


class minor_breaks:
Expand Down
12 changes: 10 additions & 2 deletions mizani/tests/test_breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,16 @@ def test_log_breaks():
assert all([1 < b < 100 for b in breaks])

breaks = log_breaks()([200, 800])
assert len(breaks) > 0
assert all([10 < b < 1000 for b in breaks])
npt.assert_array_equal(breaks, [100, 200, 300, 500, 1000])

breaks = log_breaks()((1664, 14008))
npt.assert_array_equal(breaks, [1000, 3000, 5000, 10000, 30000])

breaks = log_breaks()([407, 3430])
npt.assert_array_equal(breaks, [300, 500, 1000, 3000, 5000])

breaks = log_breaks()([1761, 8557])
npt.assert_array_equal(breaks, [1000, 2000, 3000, 5000, 10000])


def test_minor_breaks():
Expand Down

0 comments on commit 385b9e8

Please sign in to comment.