Skip to content

Commit

Permalink
Added semi approach for creating parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Simardeep27 committed Aug 16, 2021
1 parent 8d28125 commit 5495061
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 59 deletions.
2 changes: 2 additions & 0 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions apps/Analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,6 @@ def addInputs(n_clicks, children):
{'label': 'Uniform', 'value': 'uniform'},
{'label': 'Gaussian', 'value': 'gaussian'},
{'label': 'Truncated Gaussian', 'value': 'truncated-gaussian'},
# {'label': 'Chi', 'value': 'chi'},
# {'label': 'Cauchy', 'value': 'cauchy'},
{'label': 'LogNormal', 'value': 'lognormal'},
{'label': 'Beta', 'value': 'beta'}
],
Expand Down
153 changes: 98 additions & 55 deletions apps/Data_driven.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from dash.exceptions import PreventUpdate
import plotly.graph_objs as go
import pandas as pd
import scipy
import base64
import dash_table
from dash_table.Format import Format, Scheme, Trim
from navbar import navbar



import numpy as np
import equadratures as eq
import equadratures.distributions as db
Expand Down Expand Up @@ -63,6 +65,11 @@
# Parameter Definition Card
###################################################################

MEAN_VAR_DIST = ["gaussian"]
LOWER_UPPER_DIST = ["uniform"]
SHAPE_PARAM_DIST = ["lognormal"]
ALL_4 = ["beta", "truncated-gaussian"]


Upload_dataset=html.Div([
dcc.Upload(
Expand Down Expand Up @@ -118,21 +125,21 @@
[
dbc.Label('Parameter Definition:', html_for="mode-select",width=6),
dbc.Col(dcc.Dropdown(id="mode-select",options=[
{'label': 'Manual', 'value': 'manual'},
{'label':'Semi','value':'semi'},
{'label': 'Automatic', 'value':'auto'},
],searchable=False),width=4)
], row=True
),
dbc.FormGroup(
[
dbc.Label('Parameter Distribution:', html_for="distribution-select", width=6),
dbc.Col(dcc.Dropdown(id="distribution-select",options=[
{'label': 'Uniform', 'value': 'uniform'},
{'label': 'Gaussian', 'value': 'gaussian'},
{'label': 'Truncated Gaussian', 'value': 'truncated-gaussian'},
{'label': 'LogNormal', 'value': 'lognormal'},
{'label': 'Beta', 'value': 'beta'}
], searchable=False,disabled=True),width=4)
dbc.Label('Parameter Order:', html_for="order-select", width=6),
dbc.Col(dcc.Dropdown(id="order-select",options=[
{'label': '1', 'value': 1},
{'label': '2', 'value': 2},
{'label': '3', 'value': 3},
{'label': '4', 'value': 4},
{'label': '5', 'value': 5}
], searchable=False,disabled=False),width=4)
], row=True
),
]
Expand Down Expand Up @@ -181,7 +188,10 @@
[
dbc.Col(data_inputs, width=12),
]
)
),
dbc.Row([
dbc.Col(html.Div(id='param_add_datadriven',children=[]),width=12)
])

],
className='top_card',
Expand Down Expand Up @@ -458,76 +468,61 @@ def InputVars(columns,select):
raise PreventUpdate


@app.callback(
Output('distribution-select','disabled'),
Input('mode-select','value'),
prevent_intial_call=True
)
def DisabledParam(mode):
if mode=='manual':
return False
elif mode=='auto':
return True
else:
raise PreventUpdate



def CreateParam(data,columns,distribution):
if data is not None:
if distribution is not None:
dist=distribution
else:
pass
param_objs=[]
values = []
lower = 0
upper = 0
options = []
for index,i in enumerate(data[0].keys()):
values = ([vals['{}'.format(i)] for vals in data])
values = [vals for vals in values if values != 'nan']
try:
lower = min(values)
upper = max(values)
except NameError:
print('Incorrect data')

param_objs.append(eq.Parameter(distribution=dist,lower=lower,upper=upper,order=3))
return param_objs

def CreateParamWeights(data,columns,distribution):
def CreateParamWeights(data,columns,order):
if data is not None:
param_objs=[]
for index, i in enumerate(data[0].keys()):
values = ([vals['{}'.format(i)] for vals in data])
values = np.array(values)
weight_obj=eq.Weight(values)
param_objs.append(eq.Parameter(distribution='data',weight_function=weight_obj,order=3))
param_objs.append(eq.Parameter(distribution='data',weight_function=weight_obj,order=order))
return param_objs


def CreateParamSemi(data,columns,distributions,output,order):
labels = []
for i in columns:
if i != output:
labels.append(i)
else:
pass
x_data = [[None for y in range(len(data[0].keys()))]
for x in range(len(data))]
for i in range(len(data)):
for ind, j in enumerate(labels):
x_data[i][ind] = data[i][j]
values=np.array(x_data)
param_list=[]
for i in range(len(labels)):
param_list.append(eq.Parameter(distribution=distributions[i],data=values,order=order))
return param_list


@app.callback(
ServersideOutput('ParamData','data'),
ServersideOutput('BasisObj','data'),
Input('upload-data-table','data'),
Input('upload-data-table', 'columns'),
Input('output-select','value'),
Input('mode-select','value'),
Input('distribution-select','value'),
Input('order-select','value'),
Input('CU_button_datadriven','n_clicks'),
Input({'type':'drop_vals','index':dash.dependencies.ALL},'value'),
Input('column-headers','data'),
prevent_initial_call=True
)
def ComputeParams(data,columns,output,mode,distribution,n_clicks):
def ComputeParams(data,columns,output,mode,order,n_clicks,distribution_semi,col_list):
changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]
if 'CU_button_datadriven' in changed_id:
for i in range(len(data)):
data[i].pop('{}'.format(output))
if mode=='manual':
param_objs=CreateParam(data,columns,distribution)
if mode=='auto':
param_objs=CreateParamWeights(data,columns,order)
mybasis=Set_Basis()
return param_objs,mybasis
else:
param_objs=CreateParamWeights(data,columns,distribution)
param_objs=CreateParamSemi(data,col_list,distribution_semi,output,order)
mybasis=Set_Basis()
return param_objs,mybasis
else:
Expand Down Expand Up @@ -587,7 +582,6 @@ def SetModel(params,mybasis,data,cols,method,y):
# return None, None, None, None, False, True, "Incorrect Model evaluations"
mean, var = mypoly.get_mean_and_variance()
DOE=mypoly.get_points()
print(mean,var)
y_pred = mypoly.get_polyfit(np.array(x_data))
r2_score = eq.datasets.score(np.array(y_data), y_pred, metric='r2')
return mypoly, mean, var, r2_score ###
Expand Down Expand Up @@ -775,3 +769,52 @@ def Plot_poly_3D(mypoly, cols,fig):
raise PreventUpdate
else:
raise PreventUpdate



@app.callback(
Output('param_add_datadriven','children'),
Input('mode-select', 'value'),
Input('column-headers','data'),
Input('output-select', 'value'),
State('param_add_datadriven', 'children'),
prevent_intial_call=True
)
def AddParams(mode,cols,output,children):
changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]
if 'mode-select' in changed_id:
if mode=='semi':
children=[]
labels=[]
for i in cols:
if i!=output:
labels.append(i)
else:
pass
added_elements=[]
for ind,vals in enumerate(labels):
elements=dbc.Form([
dbc.Label('{}'.format(vals),html_for={'type':'drop_vals','index':ind}),
dbc.Row([
dbc.Col(dcc.Dropdown(
options=[
{'label': 'Uniform', 'value': 'uniform'},
{'label': 'Gaussian', 'value': 'gaussian'},
{'label':'Beta','value':'beta'},
{'label':'lognormal','value':'lognormal'},
{'label':'exponential','value':'exponential'}
],placeholder='Select a distribution', value='uniform', clearable=False,
className="m-1", id={'type':'drop_vals','index':ind}))
])
])
added_elements.append(dbc.Col([elements],width=2))
add_card=dbc.Row([*added_elements])
children.append(add_card)
return children
else:
children=None
return children
else:
raise PreventUpdate


5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ dash-extensions
flask_caching
plotly
numpy
git+git://github.com/equadratures/equadratures@17f15fa#egg=equadratures
git+https://github.com/Simardeep27/equadratures-1.git@feature_fitparam
gunicorn
pylibmc
func-timeout
whitenoise
numexpr
pandas
base64
base64
github.com/Simardeep27/equadratures-1.git

0 comments on commit 5495061

Please sign in to comment.