Skip to content

Commit

Permalink
random rotation of rectangle masks
Browse files Browse the repository at this point in the history
  • Loading branch information
deeppomf committed Feb 17, 2018
1 parent e268490 commit 296bb71
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ $ python train.py

# To do
- ~~Add Python 3 compatibility~~
- Add random rotations in cropping rectangles
- Retrain for arbitrary shape censors
- Add a user interface
- Incorporate GAN loss into training
- Update the model to the new version

Contributions are welcome!

Expand Down
14 changes: 11 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tqdm
from model import Model
import load
import scipy.ndimage

IMAGE_SIZE = 128
LOCAL_SIZE = 64
Expand All @@ -12,6 +13,8 @@
LEARNING_RATE = 1e-3
BATCH_SIZE = 16
PRETRAIN_EPOCH = 100
#the chance the rectangle crop will be rotated
ROTATE_CHANCE = 0.5

def train():
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
Expand Down Expand Up @@ -129,12 +132,17 @@ def get_points():

m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
m[q1:q2 + 1, p1:p2 + 1] = 1
mask.append(m)

if (np.random.random() < ROTATE_CHANCE):
#rotate random amount between 0 and 90 degrees
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
#set all elements greater than 0 to 1
m[m > 0] = 1

mask.append(m)

return np.array(points), np.array(mask)


if __name__ == '__main__':
train()

train()

0 comments on commit 296bb71

Please sign in to comment.