-
Notifications
You must be signed in to change notification settings - Fork 1
/
tools.py
190 lines (138 loc) · 5.07 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from __future__ import division
import numpy as np
from numpy import pi, sqrt, exp, power, log, log10
from scipy.interpolate import interp1d
from scipy.special import erf, lambertw
from scipy.optimize import brentq
#########################################
def interp_fn(array):
    """
    An interpolator for log-arrays spanning many orders of magnitude.

    Interpolation (and extrapolation) is linear in log10(x)-log10(y) space,
    i.e. piecewise power-law between the tabulated points.

    Parameters
    ----------
    array : array_like of shape (N, 2)
        Column 0 holds the abscissae x, column 1 the ordinates y. Entries
        below 1e-300 (including zeros and negatives) are raised to 1e-300 so
        that their logarithms are finite. The input itself is not modified.

    Returns
    -------
    fn : callable
        ``fn(x)`` returns the interpolated y at the (positive) point(s) x.
    """
    # work on a float copy so the caller's data is not regularized in place
    data = np.array(array, dtype=float)
    data[data < 1.e-300] = 1.e-300  # regularizing small numbers
    # build the log-log interpolator once, not on every evaluation of fn
    log_interp = interp1d(log10(data[:, 0]), log10(data[:, 1]),
                          fill_value='extrapolate')

    def fn(x):
        return 10**log_interp(log10(x))

    return fn
def zeros(fn, arr, *args):
    """
    Find where a function crosses 0. Returns the zeroes of the function.

    The function is sampled on the grid ``arr``. Every pair of neighbouring
    grid points with strictly opposite signs is refined with Brent's method,
    and grid points where the function is exactly 0 are added as well.
    Zeros that do not cause a sign change between neighbouring samples
    (e.g. a double root between two grid points) are not detected.

    Parameters
    ----------
    fn : function
        Called as ``fn(x, *args)``; must accept the whole grid (vectorized)
        as well as a scalar x (the latter is used by the root refinement).
    arr : array of arguments for function (the sampling grid)
    *args : any other arguments the function may have

    Returns
    -------
    ndarray
        The zero crossings sorted in increasing order (empty if none found).
    """
    # the reduced function, with only the argument to be solved for
    # (all other arguments fixed):
    def fn_reduced(x):
        return fn(x, *args)

    grid = np.asarray(arr)
    values = np.asarray(fn_reduced(grid))
    left, right = values[:-1], values[1:]
    # indices i where fn changes sign strictly between grid[i] and grid[i+1]
    sign_change_idx = np.flatnonzero(((left < 0.) & (right > 0.))
                                     | ((left > 0.) & (right < 0.)))
    # first, refining each bracketed sign change with Brent's method
    crossings = [brentq(fn_reduced, grid[i], grid[i + 1])
                 for i in sign_change_idx]
    # and then adding those grid points where the function is exactly 0
    crossings.extend(grid[values == 0.])
    # sorting the crossings in increasing order:
    return np.sort(np.array(crossings))
def treat_as_arr(arg):
    """
    Return ``arg`` as a numpy array together with a flag recording whether it
    was a scalar. Scalars (0-d inputs) are promoted to a temporary 1-element
    array; inputs that already have at least one dimension are returned as
    ``np.asarray(arg)``. Thanks to Chen!
    """
    as_array = np.asarray(arg)
    was_scalar = as_array.ndim == 0
    if was_scalar:
        # promote the 0-d array to shape (1,) so callers can iterate over it
        as_array = as_array.reshape(1)
    return as_array, was_scalar
def my_ceil(x, precision=0):
    """Round ``x`` up to ``precision`` decimal places.

    :param x: number or array_like to be rounded
    :param precision: number of decimal places to keep (Default: 0)
    :returns: an ndarray with the rounded values; a 0-d array if ``x`` is a scalar
    """
    values, scalar_input = treat_as_arr(x)
    scale = 10**precision
    rounded = np.true_divide(np.ceil(values*scale), scale)
    return np.squeeze(rounded) if scalar_input else rounded
def my_floor(x, precision=0):
    """Round ``x`` down to ``precision`` decimal places.

    :param x: number or array_like to be rounded
    :param precision: number of decimal places to keep (Default: 0)
    :returns: an ndarray with the rounded values; a 0-d array if ``x`` is a scalar
    """
    x_arr, is_scalar = treat_as_arr(x)
    # scale the array form, not the raw input: for a list, x*10**precision
    # would be list repetition (or a TypeError) instead of elementwise scaling
    res = np.true_divide(np.floor(x_arr*10**precision), 10**precision)
    if is_scalar:
        res = np.squeeze(res)
    return res
def load_dct(dct, key):
    """Look up ``key`` in ``dct`` and report whether the lookup succeeded.

    :param dct: the dictionary to be interrogated
    :param key: the key to be tried
    :returns: ``(value, True)`` if ``dct[key]`` succeeds, else ``(None, False)``
        when the lookup raises KeyError
    """
    try:
        value = dct[key]
    except KeyError:
        return None, False
    return value, True
def scientific(val, output='string'):
    r"""Convert positive number(s) to scientific form, factor x 10^exponent.

    :param val: positive number(s) to be converted (scalar or array_like)
    :param output: 'string' for LaTeX strings such as ``$5 \times 10^{-2}$``,
        or 'number' for a ``(factor, exponent)`` tuple. (Default: 'string')
    :returns: the LaTeX string(s), or the tuple ``(factor, exponent)`` with
        1 <= factor < 10; 0-d arrays when ``val`` is a scalar
    :raises ValueError: if ``output`` is neither 'string' nor 'number'
    """
    if output not in ('string', 'number'):
        raise ValueError(
            "output must be 'string' or 'number', got {!r}".format(output))
    val, is_scalar = treat_as_arr(val)
    exponent, factor = [], []
    string = []
    for vali in val:
        # floor, not int(): int() truncates toward zero, which for values
        # below 1 gave a factor below 1 (e.g. 0.05 -> 0.5 x 10^-1)
        expi = int(np.floor(np.log10(vali)))
        faci = vali / 10**expi
        # guard against log10 round-off leaving the factor outside [1, 10)
        if faci >= 10.:
            expi += 1
            faci = vali / 10**expi
        elif faci < 1.:
            expi -= 1
            faci = vali / 10**expi
        # save it
        exponent.append(expi)
        factor.append(faci)
        # the string shows the factor with no decimals; a factor that rounds
        # up to 10 (e.g. 9.7) is carried into the exponent instead of
        # printing "10 x 10^e"
        shown_fac, shown_exp = round(faci), expi
        if shown_fac == 10:
            shown_fac, shown_exp = 1, expi + 1
        if shown_fac == 1:
            string.append(r"$10^{{{:.0f}}}$".format(shown_exp))
        else:
            string.append(
                r"${{{:.0f}}} \times 10^{{{:.0f}}}$".format(shown_fac,
                                                            shown_exp))
    exponent = np.array(exponent)
    factor = np.array(factor)
    string = np.array(string)
    if is_scalar:
        exponent = np.squeeze(exponent)
        factor = np.squeeze(factor)
        string = np.squeeze(string)
    if output == 'string':
        res = string
    else:  # output == 'number', validated above
        res = (factor, exponent)
    return res
def flatten_tuples(t):
    """Lazily yield the leaves of arbitrarily nested tuples, in order.

    Only ``tuple`` instances (and subclasses) are descended into; any other
    element, including lists, is yielded as-is.

    :param t: nested tuples
    :returns: a generator over the flattened elements
    """
    for element in t:
        if not isinstance(element, tuple):
            yield element
            continue
        yield from flatten_tuples(element)
def smooth_step(x, h0, h1, delx, x0):
    """A smoothed step function built from the hyperbolic tangent.

    Evaluates ``h0 + (h1 - h0) * (tanh((x - x0) / delx) + 1) / 2``.

    :param x: the variable (scalar or array)
    :param h0: the function value far to the left of the step (x -> -infinity, for delx > 0)
    :param h1: the function value far to the right of the step (x -> +infinity, for delx > 0)
    :param delx: width of the step; the transition happens within a few delx of x0
    :param x0: the step location; the function equals (h0 + h1) / 2 at x = x0
    """
    half_height = (h1 - h0) / 2.
    profile = np.tanh((x - x0) / delx) + 1
    return h0 + half_height * profile