-
Notifications
You must be signed in to change notification settings - Fork 95
/
eval_script.py
97 lines (82 loc) · 4.31 KB
/
eval_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3
import os
import json
import kecam
from keras_cv_attention_models.imagenet import evaluation
def _cuda_device_requested(visible_devices):
    """Return whether a ``CUDA_VISIBLE_DEVICES`` value selects at least one GPU.

    Only the first comma-separated entry is inspected, because ``cuda:0`` is the
    first *visible* device. An empty value or a negative index (e.g. ``"-1"``)
    means "no GPU". A non-integer entry (e.g. a ``GPU-<uuid>`` identifier) is
    treated as selecting a device. Unlike a bare ``int(...)``, this never raises
    on multi-device lists such as ``"0,1"``.

    Args:
        visible_devices: raw string value of ``CUDA_VISIBLE_DEVICES``.

    Returns:
        bool: True if a GPU is requested, False otherwise.
    """
    first_entry = visible_devices.split(",")[0].strip()
    try:
        return int(first_entry) >= 0
    except ValueError:
        return bool(first_entry)


if kecam.backend.is_torch_backend:  # os.environ["KECAM_BACKEND"] = "torch"
    import torch

    # Always "cuda:0", no matter CUDA_VISIBLE_DEVICES: index 0 is the first *visible* GPU.
    # An unset variable defaults to "0", i.e. GPU requested.
    cuda_requested = _cuda_device_requested(os.environ.get("CUDA_VISIBLE_DEVICES", "0"))
    global_device = torch.device("cuda:0") if torch.cuda.is_available() and cuda_requested else torch.device("cpu")
else:
    import tensorflow as tf

    # Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
    for gpu in tf.config.get_visible_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)

    try:
        # Optional dependency. Presumably imported so that .h5 models using tfa
        # layers/optimizers can be deserialized. TODO confirm.
        import tensorflow_addons as tfa  # noqa: F401
    except Exception:
        # Best-effort: a missing or broken tensorflow_addons install must not stop
        # evaluation. Catching Exception (not a bare except) still lets
        # KeyboardInterrupt / SystemExit propagate.
        pass
def parse_arguments(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: list of argument strings, typically ``sys.argv[1:]``. When None,
            argparse reads ``sys.argv[1:]`` itself.

    Returns:
        argparse.Namespace: parsed options. Unrecognized arguments are dropped
        rather than raising an error, because ``parse_known_args`` is used.
    """
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        required=True,
        help="Could be: 1. Saved h5 / tflite / onnx model path. 2. Model name defined in this repo, format [sub_dir].[model_name] like regnet.RegNetZD8. 3. timm model like timm.models.resmlp_12_224",
    )
    parser.add_argument("-i", "--input_shape", type=int, default=-1, help="Model input shape, Set -1 for using model.input_shape")
    parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("-d", "--data_name", type=str, default="imagenet2012", help="Dataset name from tensorflow_datasets like imagenet2012 cifar10")
    parser.add_argument(
        "--rescale_mode", type=str, default="auto", help="Rescale mode, one of [tf, torch, raw, raw01, tf128]. Default `auto` means using model preset"
    )
    parser.add_argument("--central_crop", type=float, default=0.95, help="Central crop fraction. Set -1 to disable")
    parser.add_argument("--resize_method", type=str, default="bicubic", help="Resize method from tf.image.resize, like [bilinear, bicubic]")
    parser.add_argument("--disable_antialias", action="store_true", help="Set use antialias=False for tf.image.resize")
    parser.add_argument("--num_classes", type=int, default=None, help="num_classes if not inited from h5 file. None for model.num_classes")
    parser.add_argument(
        "--pretrained",
        type=str,
        default=None,
        help="Pretrained weights if not from h5. Could be [imagenet, noisy_student, imagenet21k, imagenet21k-ft1k, imagenet_sam], None for model.pretrained",
    )
    parser.add_argument(
        "--additional_model_kwargs", type=str, default=None, help="Json format model kwargs like '{\"drop_connect_rate\": 0.05}'. Note all quote marks"
    )
    # parse_known_args: unknown flags are ignored instead of aborting the run.
    args, _unknown_args = parser.parse_known_args(argv)
    return args
if __name__ == "__main__":
    import sys

    args = parse_arguments(sys.argv[1:])
    # --input_shape -1 means "use the model's own input shape"; otherwise a square 3-channel input.
    input_shape = None if args.input_shape == -1 else (args.input_shape, args.input_shape, 3)

    if args.model_path.startswith("timm."):  # model_path like: timm.models.resmlp_12_224
        import timm

        # Everything after "timm.<submodule>." is the timm factory function name.
        timm_model_name = ".".join(args.model_path.split(".")[2:])
        model = getattr(timm.models, timm_model_name)(pretrained=True)
    elif args.model_path.endswith(".h5"):
        # The module-level `tf` is only bound under the TensorFlow backend. Import it
        # here so an .h5 path under the torch backend does not fail with NameError.
        import tensorflow as tf

        model = tf.keras.models.load_model(args.model_path, compile=False)
    elif args.model_path.endswith((".tflite", ".onnx")):
        # Passed through as a plain path string; loading is left to `evaluation`.
        model = args.model_path
    else:  # model_path like: volo.VOLO_d1
        model_name = args.model_path.strip().split(".")
        if len(model_name) == 1:
            model_class = getattr(kecam.models, model_name[0])
        else:  # [sub_dir].[model_name]; any further dotted parts are ignored
            model_class = getattr(getattr(kecam, model_name[0]), model_name[1])

        model_kwargs = json.loads(args.additional_model_kwargs) if args.additional_model_kwargs else {}
        if not isinstance(model_kwargs, dict):
            # json.loads may also yield a list / number / string; only an object can become **kwargs.
            raise ValueError(f"--additional_model_kwargs must be a JSON object, got: {args.additional_model_kwargs}")
        # Explicit CLI options override anything given in --additional_model_kwargs.
        if input_shape:
            model_kwargs.update({"input_shape": input_shape})
        if args.num_classes:
            model_kwargs.update({"num_classes": args.num_classes})
        if args.pretrained:
            model_kwargs.update({"pretrained": args.pretrained})
        print(">>>> model_kwargs:", model_kwargs)
        model = model_class(**model_kwargs)
        if kecam.backend.is_torch_backend:
            model.to(device=global_device)

    antialias = not args.disable_antialias
    evaluation(model, args.data_name, input_shape, args.batch_size, args.central_crop, args.resize_method, antialias, args.rescale_mode)