Skip to content

Commit

Permalink
Added 3DPlot features, made tensor-grid default, solved multi-trace bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Simardeep27 committed Jul 16, 2021
1 parent ebf01e2 commit e21f15c
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 86 deletions.
2 changes: 2 additions & 0 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

183 changes: 98 additions & 85 deletions apps/Analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,24 @@
target="AP_button",
placement='right'
),
dbc.Col([dcc.Input(id="input_func", type="text", placeholder="Input Function...",
dbc.Col([
dbc.Row([
dbc.Col([dcc.Input(id="input_func", type="text", placeholder="Input Function...",
className='ip_field', debounce=True
, style={'width': '150px'}),
dbc.Col([
dbc.Spinner(html.Div(id='loading'), color="primary", show_initially=False,
spinner_style={'top': "-5rem"})
]),
, style={'width': '150px'})])
]),
dbc.Row([
dbc.Col(dbc.Alert(id='input-warning',color='danger',is_open=False),width='auto')
]),


dbc.Tooltip(
"The variables for input function should be of form x1,x2...",
target="input_func",
placement='right'
),
], width=4),
dbc.Col([dbc.Spinner(html.Div(id='param_added'),color='primary')]),
]),
html.Br(),
html.Br(),
Expand Down Expand Up @@ -165,7 +170,7 @@
dbc.Col([
dcc.Dropdown(
options=[
{'label': 'Univariate', 'value': 'univariate'},
# {'label': 'Univariate', 'value': 'univariate'},
{'label': 'Total-order', 'value': 'total-order'},
{'label': 'Tensor-grid', 'value': 'tensor-grid'},
{'label': 'Sparse-grid', 'value': 'sparse-grid'},
Expand All @@ -175,6 +180,7 @@
placeholder='Select Basis',
className="m-1", id='drop_basis',
optionHeight=45,
value='tensor-grid',
style={
"width": "165px",

Expand Down Expand Up @@ -406,16 +412,12 @@

dbc.Col([
dbc.Row([
dcc.Graph(id='plot_poly_3D', style={'width': '600px'}, figure=polyfig3D),
dbc.Spinner([
dcc.Graph(id='plot_poly_3D', style={'width': '600px'}, figure=polyfig3D)],color='primary',type='grow',
show_initially=False)
]),
dbc.Row([
dbc.Col([
dbc.Row([
dbc.Col(["1D"],width='auto',style={'margin-left':'180px'}),
dbc.Col(daq.ToggleSwitch(id='toggle',value=False,disabled=True),width='auto'),
dbc.Col(["2D"],width='auto'),

])
]),
])
], width=6),
Expand Down Expand Up @@ -484,10 +486,11 @@

@app.callback(
Output('AP_button', 'disabled'),
[Input('AP_button', 'n_clicks')]
[Input('AP_button', 'n_clicks'),
Input('CC_button','n_clicks')]
)
def check_param(n_clicks):
if n_clicks > 4:
def check_param(n_clicks,cn_clicks):
if n_clicks > 4 or cn_clicks>0:
return True
else:
return False
Expand All @@ -499,7 +502,7 @@ def check_param(n_clicks):

@app.callback(
Output('param_add', 'children'),
Output('loading', 'children'),
Output('param_added','children'),
[Input('AP_button', 'n_clicks'),
State('param_add', 'children')]
)
Expand Down Expand Up @@ -601,12 +604,12 @@ def addInputs(n_clicks, children):
],
no_gutters=True,
justify='start')
wait = time.sleep(1)

else:
add_card = dbc.Row()
children.append(add_card)
wait = None
return children, wait

return children,None

###################################################################
# Callback for disabling Cardinality Check button
Expand Down Expand Up @@ -904,10 +907,11 @@ def SetMethod(drop_basis):
)
def OutputCardinality(n_clicks, param_obj,ndims,params_click, basis_select, q_val, levels, growth_rule, solver_method):
if n_clicks != 0:
print(levels,growth_rule)
if basis_select is None:
return 'Error...',None,True,'No basis value selected'
elif basis_select=='sparse-grid' and (levels or growth_rule is None):
return 'ERROR...',None,True,'Enter the required values'
elif basis_select=='sparse-grid' and (levels or growth_rule) is None:
return 'ERROR...',None,True,'Enter the required values'
else:
param_data = jsonpickle.decode(param_obj)
basis_ord=[]
Expand Down Expand Up @@ -942,15 +946,15 @@ def PlotBasis(poly, n_clicks,ndims):
DOE = myPoly.get_points()
layout = {'margin': {'t': 0, 'r': 0, 'l': 0, 'b': 0},
'paper_bgcolor': 'white', 'plot_bgcolor': 'white', 'autosize': True,
"xaxis":{"title": r'x/C'}, "yaxis": {"title": r'y/C'}}
"xaxis":{"title": r'X'}, "yaxis": {"title": r'Y'}}

fig = go.Figure(layout=layout)
fig.update_xaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside')
fig.update_yaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside',
zerolinecolor='lightgrey')
fig.update_yaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside')
if ndims == 1:
fig.add_trace(go.Scatter(x=DOE, y=DOE, mode='markers',marker=dict(size=5, color="rgb(144, 238, 144)", opacity=0.6,
fig.add_trace(go.Scatter(x=DOE[:,0], y=np.zeros_like(DOE[:,0]), mode='markers',marker=dict(size=10, color="rgb(144, 238, 144)", opacity=1,
line=dict(color='rgb(0,0,0)', width=1))))
fig.update_yaxes(visible=False)
return fig
elif ndims == 2:
fig.add_trace(go.Scatter(x=DOE[:, 0], y=DOE[:, 1],mode='markers',marker=dict(size=5, color="rgb(144, 238, 144)", opacity=0.6,
Expand Down Expand Up @@ -1006,25 +1010,32 @@ def PlotBasis(poly, n_clicks,ndims):
Output('r2_score', 'value'),
Output('True_vals', 'data'),
Output('Sobol_plot','figure'),
Output('input-warning','is_open'),
Output('input-warning','children'),
],
[
Input('PolyObject', 'data'),
Input('input_func', 'value'),
Input('CU_button', 'n_clicks'),
Input('AP_button', 'n_clicks'),
Input('sobol_order', 'value'),
]
],
State('input_func', 'value'),
)
def SetModel(poly, expr, compute_button, n_clicks, order):

def SetModel(poly,compute_button,n_clicks, order,expr):
if compute_button != 0:
myPoly = jsonpickle.decode(poly)
x = [r"x{} = op[{}]".format(j, j - 1) for j in range(1, n_clicks + 1)]

x = [r"x{} = op[{}]".format(j, j - 1) for j in range(1, n_clicks + 1)]
def f(op):
for i in range(n_clicks):
exec(x[i])
return ne.evaluate(expr)
myPoly.set_model(f)
try:
myPoly.set_model(f)
except KeyError or ValueError:
return None,None,None,None,None,None,True,"Incorrect variable naming"
values = myPoly.get_mean_and_variance()
mean = values[0]
variance = values[1]
Expand All @@ -1041,8 +1052,7 @@ def f(op):
'paper_bgcolor': 'white', 'plot_bgcolor': 'white', 'autosize': True}
fig=go.Figure(layout=layout)
fig.update_xaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside')
fig.update_yaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside',
zerolinecolor='lightgrey')
fig.update_yaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside')
if order is not None:
ndims=myPoly.dimensions
sobol_indices=myPoly.get_sobol_indices(order=order)
Expand All @@ -1051,8 +1061,8 @@ def f(op):
fig=go.Figure(layout=layout)
if order==1:
fig.update_yaxes(title=r'$S_{i}$')
labels = [r'$X_%d$' % i for i in range(int(ndims))]
to_plot = [sobol_indices[(i,)] for i in range(int(ndims))]
labels = [r'$X_%d$' % i for i in range((ndims))]
to_plot = [sobol_indices[(i,)] for i in range((ndims))]
elif order==2:
fig.update_yaxes(title=r'$S_{ij}$')
labels = [r'$S_{%d%d}$' % (i, j) for i in range(int(ndims)) for j in range(i + 1, int(ndims))]
Expand All @@ -1070,7 +1080,7 @@ def f(op):
fig = go.Figure(layout=layout,data=data)
fig.update_layout(uniformtext_minsize=8, uniformtext_mode='hide', xaxis_tickangle=-30)

return jsonpickle.encode(myPoly), mean, variance, r2_score, jsonpickle.encode(y_true),fig ###
return jsonpickle.encode(myPoly), mean, variance, r2_score, jsonpickle.encode(y_true),fig,False,None ###
else:
raise PreventUpdate

Expand All @@ -1079,15 +1089,7 @@ def f(op):
# Disabling toggle for 1D/2D for polyfit plotting function
###################################################################

@app.callback(
Output('toggle','disabled'),
Input('CU_button','n_clicks')
)
def ToggleCheck(n_clicks):
if n_clicks>0:
return False
else:
return True


# @app.callback(
# Output('cu_tooltip','children'),
Expand All @@ -1114,22 +1116,23 @@ def ToggleCheck(n_clicks):

@app.callback(
Output('plot_poly_3D', 'figure'),
Output('plot_poly_3D','style'),
[
Input('ModelSet', 'data'),
Input('CU_button', 'n_clicks'),
Input('True_vals', 'data'),
Input('AP_button','n_clicks'),
Input('toggle','value'),
Input('ndims','data')
Input('ndims','data'),
],
State('plot_poly_3D', 'figure'),
prevent_initial_call=True
)
def Plot_poly_3D(ModelSet, n_clicks, true_vals, param_num, dims,ndims,fig):
def Plot_poly_3D(ModelSet, n_clicks, true_vals, param_num,ndims,fig):
hide={'display':'None'}
default={'width':'600px'}
if ModelSet is not None:
if dims:
if param_num==2:
layout = dict(margin={'t': 0, 'r': 0, 'l': 0, 'b': 0, 'pad': 0}, autosize=True,
if ndims==2:
layout = dict(margin={'t': 0, 'r': 0, 'l': 0, 'b': 0, 'pad': 0}, autosize=True,
scene=dict(
aspectmode='cube',
xaxis=dict(
Expand All @@ -1140,47 +1143,57 @@ def Plot_poly_3D(ModelSet, n_clicks, true_vals, param_num, dims,ndims,fig):
title=r'f(x)'),
),
)
myPoly = jsonpickle.decode(ModelSet)
y_true = jsonpickle.decode(true_vals)
myPolyFit = myPoly.get_polyfit
DOE = myPoly.get_points()
N = 20
s1_samples = np.linspace(DOE[0, 0], DOE[-1, 0], N)
s2_samples = np.linspace(DOE[0, 1], DOE[-1, 1], N)
[S1, S2] = np.meshgrid(s1_samples, s2_samples)
S1_vec = np.reshape(S1, (N * N, 1))
S2_vec = np.reshape(S2, (N * N, 1))
samples = np.hstack([S1_vec, S2_vec])
PolyDiscreet = myPolyFit(samples)
PolyDiscreet = np.reshape(PolyDiscreet, (N, N))

fig = go.Figure(fig)
fig.update_layout(layout, scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)))
fig.data = fig.data[0:2]
fig.plotly_restyle({'x': S1, 'y': S2, 'z': PolyDiscreet}, 0)
fig.plotly_restyle({'x': DOE[:, 0], 'y': DOE[:, 1], 'z': y_true.squeeze()}, 1)

return fig
else:
raise PreventUpdate
else:
layout = {"xaxis": {"title": r'X1'}, "yaxis": {"title": r'X2'},
fig = go.Figure(fig)
fig.update_layout(layout, scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)))
fig.data = fig.data[0:2]
myPoly = jsonpickle.decode(ModelSet)
y_true = jsonpickle.decode(true_vals)
myPolyFit = myPoly.get_polyfit
DOE = myPoly.get_points()
N = 20
s1_samples = np.linspace(DOE[0, 0], DOE[-1, 0], N)
s2_samples = np.linspace(DOE[0, 1], DOE[-1, 1], N)
[S1, S2] = np.meshgrid(s1_samples, s2_samples)
S1_vec = np.reshape(S1, (N * N, 1))
S2_vec = np.reshape(S2, (N * N, 1))
samples = np.hstack([S1_vec, S2_vec])
PolyDiscreet = myPolyFit(samples)
PolyDiscreet = np.reshape(PolyDiscreet, (N, N))
fig.plotly_restyle({'x': S1, 'y': S2, 'z': PolyDiscreet}, 0)
fig.plotly_restyle({'x': DOE[:, 0], 'y': DOE[:, 1], 'z': y_true.squeeze()}, 1)
return fig,default
elif ndims==1:
layout = {"xaxis": {"title": r'X1'}, "yaxis": {"title": r'f(X)'},
'margin': {'t': 0, 'r': 0, 'l': 0, 'b': 60},
'paper_bgcolor': 'white', 'plot_bgcolor': 'white', 'autosize': True}

myPoly = jsonpickle.decode(ModelSet)
y_true = jsonpickle.decode(true_vals)
myPolyFit = myPoly.get_polyfit
DOE = myPoly.get_points()

fig = go.Figure(fig)
N = 20
s1_samples = np.linspace(DOE[0, 0], DOE[-1, -1], N)
[S1] = np.meshgrid(s1_samples)
S1_vec = np.reshape(S1, (N , 1))
samples = np.hstack([S1_vec])
PolyDiscreet = myPolyFit(samples)
PolyDiscreet = np.reshape(PolyDiscreet, (N))
fig = go.Figure(layout=layout)

fig.update_xaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside')
fig.update_yaxes(color='black', linecolor='black', showline=True, tickcolor='black', ticks='outside')
fig.update_layout(layout)
fig.plotly_restyle({'x': [[]], 'y': [[]], 'z': [[]]}, 0)
fig.plotly_restyle({'x': [[]], 'y': [[]], 'z': [[]]}, 1)
fig.add_trace(go.Scatter(x=DOE[:,0], y=y_true.flatten(), mode='markers', name='Training samples',
marker=dict(color='rgb(135,206,250)', size=15, opacity=0.5,
line=dict(color='rgb(0,0,0)', width=1))))
if len(fig.data)==3:
fig.plotly_restyle({'x':DOE[:,0],'y':y_true.flatten()},2)
else:
fig.add_trace(go.Scatter(x=DOE[:, 0], y=y_true.flatten(), mode='markers', name='Training samples',
marker=dict(color='rgb(135,206,250)', size=15, opacity=0.5,
line=dict(color='rgb(0,0,0)', width=1))))
fig.add_trace(go.Scatter(x=S1,y=PolyDiscreet,mode='lines',name='f(x)',line_color='rgb(178,34,34)'))

return fig
return fig,default
else:
return {},hide

else:
raise PreventUpdate
2 changes: 1 addition & 1 deletion assets/style.css
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@import url('https://fonts.googleapis.com/css2?family=Raleway:wght@100;300;400;700&display=swap');

body{
background-color: #EBEFF2;
background-color: #FFFFFF;
font-family:"Raleway";
}

Expand Down

0 comments on commit e21f15c

Please sign in to comment.