Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predict function for logistic regression. #92

Merged
merged 4 commits into from
Jul 2, 2015
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions code/logistic_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

import theano
import theano.tensor as T
from theano.gof import graph


class LogisticRegression(object):
Expand Down Expand Up @@ -415,6 +416,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
)
)

# save the best model
with open('best_model.pkl', 'w') as f:
cPickle.dump(classifier, f)

if patience <= iter:
done_looping = True
break
Expand All @@ -433,5 +438,37 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
os.path.split(__file__)[1] +
' ran for %.1fs' % ((end_time - start_time)))


def predict():
"""
An example of how to load a trained model and use it
to predict labels.
"""

# load the saved model
classifier = cPickle.load(open('best_model.pkl'))
y_pred = classifier.y_pred

# find the input to theano graph
inputs = graph.inputs([y_pred])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we told we don't use that as this is for the DLT. In the __init__, keep a reference to input. Do that for each example to be consistent, but you don't need to add a predict() function for each example

# select only x
inputs = [item for item in inputs if item.name == 'x']

# compile a predictor function
predict_model = theano.function(
inputs=inputs,
outputs=y_pred)

# We can test it on some examples from test test
dataset='mnist.pkl.gz'
datasets = load_data(dataset)
test_set_x, test_set_y = datasets[2]
test_set_x = test_set_x.get_value()

predicted_values = predict_model(test_set_x[:10])
print ("Predicted values for the first 10 examples in test set:")
print predicted_values


if __name__ == '__main__':
sgd_optimization_mnist()
14 changes: 14 additions & 0 deletions doc/logreg.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,20 @@ approximately 1.936 epochs/sec and it took 75 epochs to reach a test
error of 7.489%. On the GPU the code does almost 10.0 epochs/sec. For this
instance we used a batch size of 600.


Prediction Using a Trained Model
+++++++++++++++++++++++++++++++

``sgd_optimization_mnist`` serialize and pickle the model each time new
lowest validation error is reached. We can reload this model and predict
labels of new data. ``predict`` function shows an example of how
this could be done.

.. literalinclude:: ../code/logistic_sgd.py
:start-after: ' ran for %.1fs' % ((end_time - start_time)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a way to specify a function I think. This is more robust. Can you check that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find any here: http://sphinx-doc.org/markup/code.html

The rest of tutorial use them syntax too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use :pyobject: predict, like explained in http://sphinx-doc.org/markup/code.html#includes:

If it is a Python module, you can select a class, function or method to include using the pyobject option

It is used in particular at https://github.com/memimo/DeepLearningTutorials/blob/predict/doc/logreg.txt#L219 in this file.
Otherwise, if it does not work, I would prefer to start and stop on explicit comments, as it is less likely to break later. For instance, if we add a new function after predict, then everything in that function until if __name__ ... wil get included as well, which is not what we want.

:end-before: if __name__ == '__main__':


.. rubric:: Footnotes

.. [#f1] For smaller datasets and simpler models, more sophisticated descent
Expand Down