-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_imbalance.py
43 lines (35 loc) · 1.86 KB
/
main_imbalance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import argparse
import copy, pickle
import numpy as np
from utils.utils import set_seed
import logging
from datasets.process import get_UCRArchive_2018_datasets_names
if __name__ == '__main__':
# torch.cuda.empty_cache()
logging.getLogger().setLevel(logging.INFO)
parser = argparse.ArgumentParser()
# General settings
parser.add_argument('--dataset', type=str, default='SwedishLeaf', help='dataset name',
choices=get_UCRArchive_2018_datasets_names())
parser.add_argument('--lr', type=float, default=0.03, help='learning rate')
parser.add_argument('--batch_size', type=int, default=1000, help='batch size')
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs')
parser.add_argument('--seeds', type=str, default='0:1:2:3', help='random seed')
parser.add_argument('--k_shot', type=int, default=10, help='number of k-shot')
parser.add_argument('-k_additional', type=int, default=2, help='number of additional k-shot')
parser.add_argument('--scaling_rate', type=float, default=0.6, help='scaling rate')
args = parser.parse_args()
seeds = [int(s) for s in args.seeds.split(':')]
for seed in seeds:
torch.cuda.empty_cache()
logging.log(logging.INFO, 'Iterate Seed: {}'.format(seed))
args.seed = seed
set_seed(seed)
from datasets.process import read_dataset_for_imbalance
from models.handlers import DTW_RNN_imbalance_handler
X_train, y_train, X_test, y_test, n_pos = read_dataset_for_imbalance(args.dataset)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, n_pos)
print(len(np.where(y_train == 1)[0]), len(np.where(y_test == 1)[0]))
print(len(np.where(y_train == 0)[0]), len(np.where(y_test == 0)[0]))
DTW_RNN_imbalance_handler(X_train, y_train, X_test, y_test, n_pos, args)