Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Error handling for warm_start + fix for warm_start="quickdraw_245_v0" (
Browse files Browse the repository at this point in the history
  • Loading branch information
shantanuchhabra authored and Zach Nation committed Apr 29, 2019
1 parent c50c0f2 commit a903e3c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/unity/python/turicreate/test/test_drawing_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,9 @@ def setUpClass(self):
super(DrawingClassifierFromScratchTest, self).setUpClass(
warm_start=None)

class DrawingClassifierUsingQuickdraw245(DrawingClassifierTest):
@classmethod
def setUpClass(self):
super(DrawingClassifierUsingQuickdraw245, self).setUpClass(
warm_start="quickdraw_245_v0")

9 changes: 3 additions & 6 deletions src/unity/python/turicreate/toolkits/_pre_trained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,11 @@ def get_model_path(self, format):
class DrawingClassifierPreTrainedModel(object):
def __init__(self, warm_start="auto"):
self.model_to_filename = {
"auto": "drawing_classifier_pre_trained_model_245_classes_v0.params",
"quickdraw_245_v0": "drawing_classifier_pre_trained_model_245_classes_v0.params"
}
self.warm_start = "quickdraw_245_v0" if warm_start == "auto" else warm_start
self.source_url = (_urlparse.urljoin(
MODELS_URL_ROOT, self.model_to_filename[self.warm_start])
if warm_start == 'auto'
else warm_start
)
self.source_url = _urlparse.urljoin(
MODELS_URL_ROOT, self.model_to_filename[warm_start])
# @TODO: Think about how to bypass the md5 checksum if the user wants to
# provide their own pretrained model.
self.source_md5 = "71ba78e48a852f35fb22999650f0a655"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ def create(input_dataset, target, feature=None, validation_set='auto',
warm_start : string optional
A string to denote which pretrained model to use. Set to "auto"
by default which uses a model trained on 245 of the 345 classes in the
Quick, Draw! dataset. Here is a list of all the pretrained models that
Quick, Draw! dataset. To disable warm start, pass in None to this
argument. Here is a list of all the pretrained models that
can be passed in as this argument:
"auto": Uses quickdraw_245_v0
"quickdraw_245_v0": Uses a model trained on 245 of the 345 classes in the
Quick, Draw! dataset.
None: No Warm Start
batch_size: int optional
The number of drawings per training step. If not set, a default
Expand Down Expand Up @@ -132,6 +134,7 @@ def create(input_dataset, target, feature=None, validation_set='auto',
from .._mxnet import _mxnet_utils

start_time = _time.time()
accepted_values_for_warm_start = ["auto", "quickdraw_245_v0", None]

# @TODO: Should be able to automatically choose number of iterations
# based on data size: Tracked in Github Issue #1576
Expand Down Expand Up @@ -226,6 +229,14 @@ def create(input_dataset, target, feature=None, validation_set='auto',
model_params.initialize(_mx.init.Xavier(), ctx=ctx)

if warm_start is not None:
if type(warm_start) is not str:
raise TypeError("'warm_start' must be a string or None. "
+ "'warm_start' can take in the following values: "
+ str(accepted_values_for_warm_start))
if warm_start not in accepted_values_for_warm_start:
raise _ToolkitError("Unrecognized value for 'warm_start': "
+ warm_start + ". 'warm_start' can take in the following "
+ "values: " + str(accepted_values_for_warm_start))
pretrained_model = _pre_trained_models.DrawingClassifierPreTrainedModel(
warm_start)
pretrained_model_params_path = pretrained_model.get_model_path()
Expand Down

0 comments on commit a903e3c

Please sign in to comment.