-
Notifications
You must be signed in to change notification settings - Fork 0
/
params.py
85 lines (72 loc) · 1.93 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 10 14:34:13 2023
@author: hawkiyc
"""
#%%
'Import Libraries'
import ast
from datetime import datetime
import itertools
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import neurokit2 as nk
import numpy as np
import os
import pandas as pd
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, Normalizer
from tslearn.preprocessing import TimeSeriesResampler as tsr
import torch
import torch.nn as nn
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import wfdb
#%%
# Parameters and Hyper-Parameters

# --- Data source and preprocessing switches ---
path = '../../DATA/PTB-XL'  # dataset root, resolved relative to the cwd
sr = 100  # sampling-rate setting; NOTE(review): presumably Hz — confirm against the loader
add_noise = False
check_data = True
scaler = False

# --- Lead and label configuration ---
# NOTE(review): the standard 12-lead ordering is usually I, II, III, aVR,
# aVL, aVF, V1-V6; here AVL precedes AVR. Confirm that whatever consumes
# LEADs does not assume it matches the channel order of the loaded signals.
LEADs = ["DI", "DII", "DIII", "AVL", "AVR", "AVF"] + [f"V{k}" for k in range(1, 7)]
diagnose_label = ['CD', 'HYP', 'MI', "NORM", 'STTC']
model_out = len(diagnose_label)  # one model output per diagnostic label

# --- Model / training hyper-parameters ---
d_input = (len(LEADs), 1)  # presumably (n_leads, channels per lead) — TODO confirm
emb_size = 512
seq_length = 500
max_rr_seq = 20
batch_size = 24
n_epochs = 50

# BCEWithLogitsLoss applies the sigmoid internally, so it expects raw logits;
# out_activation is presumably applied separately when probabilities are needed.
loss_fn = nn.BCEWithLogitsLoss()
out_activation = nn.Sigmoid()
#%%
"Setting GPU"
# Select the compute device: forced CPU, else CUDA, else Apple MPS, else CPU.
use_cpu = False
m_seed = 42
# Seed torch's RNG for every backend (CPU and all CUDA devices) so runs are
# reproducible regardless of which device is chosen below. This subsumes the
# previous CUDA-only torch.cuda.manual_seed call.
torch.manual_seed(m_seed)
if use_cpu:
    device = torch.device('cpu')
elif torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    # No accelerator is usable: explain why MPS is unavailable, then fall back
    # to the CPU. Without this fallback `device` was never assigned here and
    # the print(device) below raised NameError on CPU-only machines.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was "
              "NOT built with MPS enabled.")
    else:
        print("MPS not available because this MacOS version is NOT 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device('cpu')
print(device)
#%%
'Make Output Dir'
# Ensure the 'results' directory exists, relative to the current working
# directory. exist_ok=True makes this idempotent and removes the race between
# a separate isdir() check and the makedirs() call.
os.makedirs('results', exist_ok=True)