-
Notifications
You must be signed in to change notification settings - Fork 5
/
models.py
24 lines (23 loc) · 832 Bytes
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def _require_dataset_mode(opt, expected):
    """Raise ValueError unless ``opt.dataset_mode`` equals *expected*.

    This replaces a bare ``assert``. Python strips ``assert`` statements when
    it runs with ``-O``, so an incompatible dataset mode would otherwise pass
    unnoticed.
    """
    if opt.dataset_mode != expected:
        raise ValueError(
            "Model [%s] requires dataset_mode [%s], got [%s]."
            % (opt.model, expected, opt.dataset_mode)
        )


def create_model(opt):
    """Instantiate and initialize the model selected by ``opt.model``.

    Each model class is imported inside its own branch. Only the module of the
    selected model is loaded.

    Args:
        opt: Options object. This function reads ``opt.model`` and, for the
            ``cycle_gan``, ``pix2pix`` and ``test`` models, also
            ``opt.dataset_mode``. The whole object is then passed to
            ``model.initialize``.

    Returns:
        The model instance, after ``initialize(opt)`` has been called on it.

    Raises:
        ValueError: If ``opt.model`` is not one of ``cycle_gan``,
            ``sdrm_pix2pix``, ``pix2pix`` or ``test``. Also raised if the
            selected model requires a specific ``opt.dataset_mode`` and a
            different one was given.
    """
    # Echo the requested model name (kept for existing log output).
    print(opt.model)
    if opt.model == 'cycle_gan':
        _require_dataset_mode(opt, 'unaligned')
        from .cycle_gan_model import CycleGANModel
        model = CycleGANModel()
    elif opt.model == 'sdrm_pix2pix':
        # No dataset_mode requirement is enforced for this model.
        from .sdrm_pix2pix_model import SDRMPix2PixModel
        model = SDRMPix2PixModel()
    elif opt.model == 'pix2pix':
        _require_dataset_mode(opt, 'aligned')
        from .pix2pix_model import Pix2PixModel
        model = Pix2PixModel()
    elif opt.model == 'test':
        _require_dataset_mode(opt, 'single')
        from .test_model import TestModel
        model = TestModel()
    else:
        raise ValueError("Model [%s] not recognized." % opt.model)
    model.initialize(opt)
    print("model [%s] was created" % (model.name()))
    return model