# model.py
import json
from pathlib import Path

import pandas as pd
import torch
from torch import nn
from torchvision.io import read_video
from torchvision.models.video import (
    r3d_18, R3D_18_Weights,
    mc3_18, MC3_18_Weights,
    r2plus1d_18, R2Plus1D_18_Weights,
)
from tqdm import tqdm

from SoccerFoulProject import LOGGER
from SoccerFoulProject.config.classes import *
from SoccerFoulProject.train import MVFoulTrainer
from SoccerFoulProject.utils import CFG, plot_results
from SoccerFoulProject.validator import MVFoulValidator


class Model(nn.Module):
    """Multi-view foul classification model: a shared video encoder followed by two
    classification heads, one for the action class and one for offence/severity."""

    def __init__(self, video_encoder_name='r3d_18', clip_aggregation='mean', feat_dim=100):
        super().__init__()
        # Backbone: a pretrained 18-layer video network producing 400-dim Kinetics logits.
        if video_encoder_name == 'r3d_18':
            self.video_encoder = r3d_18(weights=R3D_18_Weights.DEFAULT)
        elif video_encoder_name == 'mc3_18':
            self.video_encoder = mc3_18(weights=MC3_18_Weights.DEFAULT)
        elif video_encoder_name == 'r2plus1d_18':
            self.video_encoder = r2plus1d_18(weights=R2Plus1D_18_Weights.DEFAULT)
        else:
            raise ValueError(f'Unknown video encoder: {video_encoder_name}')
        self.clip_aggregation = clip_aggregation
        # Head predicting the action class from the aggregated clip features.
        self.action_classification_net = nn.Sequential(
            nn.LayerNorm(400),
            nn.Linear(400, feat_dim),
            nn.Sigmoid(),
            nn.Linear(feat_dim, len(EVENT_DICTIONARY_action_class)),
        )
        # Head predicting the offence/severity class from the same features.
        self.offence_classification_net = nn.Sequential(
            nn.LayerNorm(400),
            nn.Linear(400, feat_dim),
            nn.Sigmoid(),
            nn.Linear(feat_dim, len(EVENT_DICTIONARY_offence_severity_class)),
        )
        self.cfg = None
        self.results = pd.DataFrame(columns=[
            'epochs',
            'Train Action loss', 'Val Action loss',
            'Train Offence severity loss', 'Val Offence severity loss',
            'Train Action Accuracy', 'Val Action Accuracy',
            'Train Offence severity Accuracy', 'Val Offence severity Accuracy',
        ])

    def forward(self, batch_clips):
        """batch_clips: iterable of tensors of shape (num_clips, C, T, H, W), one entry per action."""
        batched_pred_action = torch.empty(0, len(EVENT_DICTIONARY_action_class), device=self.cfg.device)
        batched_pred_offence_severity = torch.empty(0, len(EVENT_DICTIONARY_offence_severity_class), device=self.cfg.device)
        for clips in batch_clips:
            # Encode every clip (camera view) of the action.
            all_clip_features = self.video_encoder(clips)
            # Aggregate the per-clip features into a single action representation.
            if self.clip_aggregation == 'mean':
                action_features = torch.mean(all_clip_features, dim=0)
            elif self.clip_aggregation == 'max':
                action_features, _ = torch.max(all_clip_features, dim=0)
            else:
                raise ValueError('Clip aggregation method should be mean or max')
            pred_action = self.action_classification_net(action_features)
            pred_offence_severity = self.offence_classification_net(action_features)
            batched_pred_action = torch.cat((batched_pred_action, pred_action.unsqueeze(0)), dim=0)
            batched_pred_offence_severity = torch.cat((batched_pred_offence_severity, pred_offence_severity.unsqueeze(0)), dim=0)
        return batched_pred_action, batched_pred_offence_severity

    def do_train(self, train_dataset, val_dataset, cfg):
        self.cfg = cfg
        trainer = MVFoulTrainer(self, train_dataset, cfg)
        validator = MVFoulValidator(self, val_dataset, cfg)
        best_averaged_accuracy = 0
        num_epochs_with_no_improvement = 0
        self.to(cfg.device)
        for epoch in tqdm(range(cfg.num_epochs)):
            trainer.train_step()
            validator.validation_step()
            # Append this epoch's training and validation results.
            new_row = pd.concat([pd.Series({'epochs': epoch + 1}), trainer.new_series, validator.new_series])
            self.results.loc[epoch] = new_row
            # Save the model's latest weights.
            self.save(trainer.save_folder / 'weights' / 'last.pth')
            # Early stopping and saving of the best model's parameters.
            curr_overall_accuracy = (self.results.loc[epoch, 'Val Action Accuracy']
                                     + self.results.loc[epoch, 'Val Offence severity Accuracy']) / 2
            if best_averaged_accuracy < curr_overall_accuracy:
                best_averaged_accuracy = curr_overall_accuracy
                self.save(trainer.save_folder / 'weights' / 'best.pth')
                num_epochs_with_no_improvement = 0
            else:
                num_epochs_with_no_improvement += 1
                if num_epochs_with_no_improvement >= cfg.patience:
                    LOGGER.info('Stopping early after %d epochs with no improvement in accuracy', cfg.patience)
                    break
        # Plot training curves.
        plot_results(self.results, path=trainer.save_folder, save=True, plot=False)
        # Save the hyperparameters.
        with open(trainer.save_folder / 'hyperparameters.json', 'w') as f:
            json.dump(cfg.to_dictionnary(), f)
        # Save the confusion matrices.
        trainer.action_conf_matrix.plot(normalized=False, show=False, save=True, path=trainer.save_folder, prefix="train_action")
        validator.action_conf_matrix.plot(normalized=False, show=False, save=True, path=trainer.save_folder, prefix="val_action")
        trainer.off_sev_conf_matrix.plot(normalized=False, show=False, save=True, path=trainer.save_folder, prefix="train_off_sev")
        validator.off_sev_conf_matrix.plot(normalized=False, show=False, save=True, path=trainer.save_folder, prefix="val_off_sev")

    def save(self, path: Path):
        # Create the parent directory if it does not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), path)

    def load(self, path):
        # Load weights onto the CPU first; they are moved to the target device later.
        weights = torch.load(path, map_location='cpu')
        self.load_state_dict(weights)

    def predict(self, path, start=0, end=5, device=torch.device('cpu')):
        """Run inference on a single .mp4 clip or on a directory of .mp4 clips of the same action."""
        self.cfg = CFG(device=device)
        path = Path(path)
        if not path.exists():
            raise ValueError(f'{path} does not exist')
        elif path.is_dir():
            # Read every clip of the action and stack them along the clip dimension.
            videos = []
            for clip_path in path.glob('*.mp4'):
                video = read_video(clip_path, pts_unit='sec', output_format='TCHW', start_pts=start, end_pts=end)[0].float()
                if self.cfg and self.cfg.transform:
                    video = self.cfg.transform(video)
                videos.append(video.permute(1, 0, 2, 3).unsqueeze(0))
            videos = torch.vstack(videos)
        elif path.is_file():
            video = read_video(path, pts_unit='sec', output_format='TCHW', start_pts=start, end_pts=end)[0].float()
            if self.cfg and self.cfg.transform:
                video = self.cfg.transform(video)
            videos = video.permute(1, 0, 2, 3).unsqueeze(0)
        self.to(self.cfg.device)
        self.eval()
        with torch.inference_mode():
            pred_action, pred_offence_severity = self(videos.unsqueeze(0).to(self.cfg.device))
        action = INVERSE_EVENT_DICTIONARY_action_class[torch.argmax(torch.sigmoid(pred_action.squeeze())).item()]
        off_severity = INVERSE_EVENT_DICTIONARY_offence_severity_class[torch.argmax(torch.sigmoid(pred_offence_severity.squeeze())).item()]
        return action, off_severity
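

# Minimal usage sketch (not part of the original module): it shows how the Model
# defined above could be loaded and used for inference. The checkpoint and video
# folder paths below are hypothetical placeholders, not paths provided by this repo.
if __name__ == '__main__':
    # Build the model with one of the supported encoders and an aggregation mode.
    model = Model(video_encoder_name='r3d_18', clip_aggregation='max')
    # Hypothetical checkpoint path; Model.save() writes weights in this layout during training.
    model.load('runs/train/weights/best.pth')
    # Hypothetical folder containing the .mp4 camera views of a single action.
    action, offence_severity = model.predict(
        'data/actions/action_0/',
        start=0,
        end=5,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    )
    print(f'Predicted action: {action} | offence/severity: {offence_severity}')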