diff --git a/src/unity/python/turicreate/test/test_drawing_classifier.py b/src/unity/python/turicreate/test/test_drawing_classifier.py index c04a1a3d70..4dff9ba832 100644 --- a/src/unity/python/turicreate/test/test_drawing_classifier.py +++ b/src/unity/python/turicreate/test/test_drawing_classifier.py @@ -317,3 +317,9 @@ def setUpClass(self): super(DrawingClassifierFromScratchTest, self).setUpClass( warm_start=None) +class DrawingClassifierUsingQuickdraw245(DrawingClassifierTest): + @classmethod + def setUpClass(self): + super(DrawingClassifierUsingQuickdraw245, self).setUpClass( + warm_start="quickdraw_245_v0") + diff --git a/src/unity/python/turicreate/toolkits/_pre_trained_models.py b/src/unity/python/turicreate/toolkits/_pre_trained_models.py index 8eb71a9a2f..087f40868c 100644 --- a/src/unity/python/turicreate/toolkits/_pre_trained_models.py +++ b/src/unity/python/turicreate/toolkits/_pre_trained_models.py @@ -245,14 +245,11 @@ def get_model_path(self, format): class DrawingClassifierPreTrainedModel(object): def __init__(self, warm_start="auto"): self.model_to_filename = { + "auto": "drawing_classifier_pre_trained_model_245_classes_v0.params", "quickdraw_245_v0": "drawing_classifier_pre_trained_model_245_classes_v0.params" } - self.warm_start = "quickdraw_245_v0" if warm_start == "auto" else warm_start - self.source_url = (_urlparse.urljoin( - MODELS_URL_ROOT, self.model_to_filename[self.warm_start]) - if warm_start == 'auto' - else warm_start - ) + self.source_url = _urlparse.urljoin( + MODELS_URL_ROOT, self.model_to_filename[warm_start]) # @TODO: Think about how to bypass the md5 checksum if the user wants to # provide their own pretrained model. self.source_md5 = "71ba78e48a852f35fb22999650f0a655" diff --git a/src/unity/python/turicreate/toolkits/drawing_classifier/drawing_classifier.py b/src/unity/python/turicreate/toolkits/drawing_classifier/drawing_classifier.py index ed6b2866dd..f72ac3357d 100644 --- a/src/unity/python/turicreate/toolkits/drawing_classifier/drawing_classifier.py +++ b/src/unity/python/turicreate/toolkits/drawing_classifier/drawing_classifier.py @@ -85,11 +85,13 @@ def create(input_dataset, target, feature=None, validation_set='auto', warm_start : string optional A string to denote which pretrained model to use. Set to "auto" by default which uses a model trained on 245 of the 345 classes in the - Quick, Draw! dataset. Here is a list of all the pretrained models that + Quick, Draw! dataset. To disable warm start, pass in None to this + argument. Here is a list of all the pretrained models that can be passed in as this argument: "auto": Uses quickdraw_245_v0 "quickdraw_245_v0": Uses a model trained on 245 of the 345 classes in the Quick, Draw! dataset. + None: No Warm Start batch_size: int optional The number of drawings per training step. If not set, a default @@ -132,6 +134,7 @@ def create(input_dataset, target, feature=None, validation_set='auto', from .._mxnet import _mxnet_utils start_time = _time.time() + accepted_values_for_warm_start = ["auto", "quickdraw_245_v0", None] # @TODO: Should be able to automatically choose number of iterations # based on data size: Tracked in Github Issue #1576 @@ -226,6 +229,14 @@ def create(input_dataset, target, feature=None, validation_set='auto', model_params.initialize(_mx.init.Xavier(), ctx=ctx) if warm_start is not None: + if type(warm_start) is not str: + raise TypeError("'warm_start' must be a string or None. " + + "'warm_start' can take in the following values: " + + str(accepted_values_for_warm_start)) + if warm_start not in accepted_values_for_warm_start: + raise _ToolkitError("Unrecognized value for 'warm_start': " + + warm_start + ". 'warm_start' can take in the following " + + "values: " + str(accepted_values_for_warm_start)) pretrained_model = _pre_trained_models.DrawingClassifierPreTrainedModel( warm_start) pretrained_model_params_path = pretrained_model.get_model_path()