Skip to content

Commit

Permalink
updated cnn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vivamoto committed Aug 25, 2020
1 parent 881ea45 commit 0923443
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion cnn/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,16 @@ def dropout(y, dy, idx, tx):
return y, dy

def init_params(cnn,opts):
"""
Setup CNN based on network parameters.
Args:
cnn: object with cnn parameters
opts: options object
Returns:
cnn object with initial setup
"""
numFiltros1 = opts.imageCanal #number of channels - number of first layer input
dimInputX = opts.imageDimX #original image X dimension
dimInputY = opts.imageDimY #original image Y dimension
Expand Down Expand Up @@ -539,6 +549,18 @@ def init_params(cnn,opts):
return cnn

def train_cnn(cnn,images,labels, pdir = ''):
"""
Train convolutional neural network.
Args:
cnn: cnn object with initial setup and parameters
images: array with images to process
labels: vector with image labels
pdir: directory to save plot if pdir is not empty
Returns:
create 'Total Cost vs Epoch' plot and return cnn object with weights
"""
it = 0 # number of iterations
plotData = pd.DataFrame(columns=['Epoch', 'Iteration', 'Total Cost', 'Cost'])

Expand Down Expand Up @@ -915,6 +937,7 @@ def loadMNISTLabels(filename):
# Demo with MNIST dataset
#=====================================
if __name__ == '__main__':

opts = obj()
opts.alpha = 1e-1 # learning rate
opts.batchsize = 50 # training set size = 150
Expand Down Expand Up @@ -982,7 +1005,7 @@ def loadMNISTLabels(filename):
# cnn.layers[3].dimPool = 2 # filter size
# cnn.layers[3].criterio = 'max' # max/mean

# full connected layer
# full connected layer: last layer setup is optional
# cnn.layers[2] = obj() #
# cnn.layers[2].type = 'f' # f = full connected
# cnn.layers[2].fativ = 'sig' # sig/relu: activation funciton
Expand Down

0 comments on commit 0923443

Please sign in to comment.