Skip to content

Commit

Permalink
Fix compatibility of object_detection for Python3 (tensorflow#1610)
Browse files Browse the repository at this point in the history
* make batcher compatible for py3

* make prefetcher and operations in ops compatible for py3

* use six.iteritem

* make all tests in  compatible with py3

* simplify usage of six and modify import order

* add back the space line
  • Loading branch information
leVirve authored and sguada committed Jun 21, 2017
1 parent 477ed41 commit 3f9382a
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
6 changes: 3 additions & 3 deletions object_detection/core/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def __init__(self, tensor_dict, batch_size, batch_queue_capacity,
"""
# Remember static shapes to set shapes of batched tensors.
static_shapes = collections.OrderedDict(
{key: tensor.get_shape() for key, tensor in tensor_dict.iteritems()})
{key: tensor.get_shape() for key, tensor in tensor_dict.items()})
# Remember runtime shapes to unpad tensors after batching.
runtime_shapes = collections.OrderedDict(
{(key, 'runtime_shapes'): tf.shape(tensor)
for key, tensor in tensor_dict.iteritems()})
for key, tensor in tensor_dict.items()})
all_tensors = tensor_dict
all_tensors.update(runtime_shapes)
batched_tensors = tf.train.batch(
Expand All @@ -109,7 +109,7 @@ def dequeue(self):
# Separate input tensors from tensors containing their runtime shapes.
tensors = {}
shapes = {}
for key, batched_tensor in batched_tensors.iteritems():
for key, batched_tensor in batched_tensors.items():
unbatched_tensor_list = tf.unstack(batched_tensor)
for i, unbatched_tensor in enumerate(unbatched_tensor_list):
if isinstance(key, tuple) and key[1] == 'runtime_shapes':
Expand Down
2 changes: 1 addition & 1 deletion object_detection/core/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def multiclass_non_max_suppression(boxes,
boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
per_class_masks)
if additional_fields is not None:
for key, tensor in additional_fields.iteritems():
for key, tensor in additional_fields.items():
boxlist_and_class_scores.add_field(key, tensor)
boxlist_filtered = box_list_ops.filter_greater_than(
boxlist_and_class_scores, score_thresh)
Expand Down
2 changes: 1 addition & 1 deletion object_detection/core/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def prefetch(tensor_dict, capacity):
Returns:
a FIFO prefetcher queue
"""
names = tensor_dict.keys()
names = list(tensor_dict.keys())
dtypes = [t.dtype for t in tensor_dict.values()]
shapes = [t.get_shape() for t in tensor_dict.values()]
prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes,
Expand Down
7 changes: 6 additions & 1 deletion object_detection/core/preprocessor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@

"""Tests for object_detection.core.preprocessor."""

import mock
import numpy as np
import six

import tensorflow as tf

from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields

if six.PY2:
import mock # pylint: disable=g-import-not-at-top
else:
from unittest import mock # pylint: disable=g-import-not-at-top


class PreprocessorTest(tf.test.TestCase):

Expand Down
7 changes: 4 additions & 3 deletions object_detection/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""A module for helper tensorflow ops."""
import math
import six

import tensorflow as tf

Expand Down Expand Up @@ -197,9 +198,9 @@ def padded_one_hot_encoding(indices, depth, left_pad):
TODO: add runtime checks for depth and indices.
"""
if depth < 0 or not isinstance(depth, (int, long)):
if depth < 0 or not isinstance(depth, (int, long) if six.PY2 else int):
raise ValueError('`depth` must be a non-negative integer.')
if left_pad < 0 or not isinstance(left_pad, (int, long)):
if left_pad < 0 or not isinstance(left_pad, (int, long) if six.PY2 else int):
raise ValueError('`left_pad` must be a non-negative integer.')
if depth == 0:
return None
Expand Down Expand Up @@ -548,7 +549,7 @@ def position_sensitive_crop_regions(image,
raise ValueError('crop_size should be divisible by num_spatial_bins')

total_bins *= num_bins
bin_crop_size.append(crop_dim / num_bins)
bin_crop_size.append(crop_dim // num_bins)

if not global_pool and bin_crop_size[0] != bin_crop_size[1]:
raise ValueError('Only support square bin crop size for now.')
Expand Down

0 comments on commit 3f9382a

Please sign in to comment.