forked from dapowan/LIMU-BERT-Public
-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding.py
68 lines (54 loc) · 2.66 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 13/1/2021
# @Author : Huatao
# @Email : [email protected]
# @File : embedding.py
# @Description : generate embeddings using pretrained LIMU-BERT models
import os
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
import train
from config import load_dataset_label_names
from models import LIMUBertModel4Pretrain
from plot import plot_reconstruct_sensor, plot_embedding
from utils import LIBERTDataset4Pretrain, load_pretrain_data_config, get_device, handle_argv, \
Preprocess4Normalization, IMUDataset
def fetch_setup(args, output_embed):
    """Assemble the data, loader, model and loss used to run LIMU-BERT.

    Args:
        args: Parsed run arguments, passed unchanged to
            ``load_pretrain_data_config``.
        output_embed: Forwarded to ``LIMUBertModel4Pretrain`` as its
            ``output_embed`` keyword.

    Returns:
        A tuple ``(data, labels, data_loader, model, criterion, train_cfg)``.
    """
    # The config loader also yields mask and dataset configs; neither is used here.
    data, labels, train_cfg, model_cfg, _mask_cfg, _dataset_cfg = \
        load_pretrain_data_config(args)

    normalize = Preprocess4Normalization(model_cfg.feature_num)
    dataset = IMUDataset(data, labels, pipeline=[normalize])
    # shuffle=False: batches are drawn in dataset order.
    loader = DataLoader(dataset, shuffle=False, batch_size=train_cfg.batch_size)

    return (
        data,
        labels,
        loader,
        LIMUBertModel4Pretrain(model_cfg, output_embed=output_embed),
        nn.MSELoss(reduction='none'),
        train_cfg,
    )
def generate_embedding_or_output(args, save=False, output_embed=True):
    """Run a pretrained LIMU-BERT model over a dataset and optionally save the result.

    Args:
        args: Parsed run arguments. This function reads ``save_path``,
            ``gpu``, ``pretrain_model``, ``model_file``, ``dataset`` and
            ``dataset_version`` from it, and also passes it to ``fetch_setup``.
        save: When true, write the output to
            ``embed/embed_<model>_<dataset>_<dataset_version>.npy``, where
            ``<model>`` is the part of ``args.model_file`` before its first
            ``'.'``. The ``embed`` directory is created if it is missing.
        output_embed: Forwarded to ``fetch_setup`` (and from there to the
            model constructor).

    Returns:
        A tuple ``(data, output, labels)`` where ``output`` is whatever
        ``trainer.run`` returns for the whole data loader.
    """
    # fetch_setup also returns a loss criterion; it is not needed here.
    data, labels, data_loader, model, _criterion, train_cfg = \
        fetch_setup(args, output_embed)

    # No optimizer is supplied: the trainer is only asked to run the model.
    trainer = train.Trainer(train_cfg, model, None, args.save_path, get_device(args.gpu))

    def func_forward(model, batch):
        # Each batch is a (sequences, label) pair; only the sequences go through the model.
        seqs, label = batch
        return model(seqs), label

    output = trainer.run(func_forward, None, data_loader, args.pretrain_model)

    if save:
        model_name = args.model_file.split('.')[0]
        save_name = 'embed_' + model_name + '_' + args.dataset + '_' + args.dataset_version
        # np.save does not create parent directories; without this a fresh
        # checkout fails only after the whole forward pass has completed.
        os.makedirs('embed', exist_ok=True)
        np.save(os.path.join('embed', save_name + '.npy'), output)
    return data, output, labels
def load_embedding_label(model_file, dataset, dataset_version):
    """Load saved embeddings and the matching dataset labels from disk.

    Args:
        model_file: Model name as it appears in the embedding file name.
        dataset: Dataset name; used in the embedding file name and as the
            sub-directory of ``dataset`` that holds the labels.
        dataset_version: Dataset version; used in both file names.

    Returns:
        A tuple ``(embed, labels)`` of NumPy arrays, both cast to ``float32``.
        ``embed`` is read from
        ``embed/embed_<model_file>_<dataset>_<dataset_version>.npy`` and
        ``labels`` from ``dataset/<dataset>/label_<dataset_version>.npy``.
    """
    embed_file = '_'.join(('embed', model_file, dataset, dataset_version)) + '.npy'
    label_file = 'label_' + dataset_version + '.npy'

    embed_path = os.path.join('embed', embed_file)
    label_path = os.path.join('dataset', dataset, label_file)

    embed = np.load(embed_path).astype(np.float32)
    labels = np.load(label_path).astype(np.float32)
    return embed, labels
if __name__ == "__main__":
    # Script entry point: generate embeddings with the "base" configuration
    # and save them to disk.
    mode = "base"
    args = handle_argv('pretrain_' + mode, 'pretrain.json', mode)
    data, output, labels = generate_embedding_or_output(
        args=args,
        save=True,
        output_embed=True,
    )

    # Optional visualisation of the embeddings, left disabled:
    # label_index = 1
    # label_names, label_num = load_dataset_label_names(args.dataset_cfg, label_index)
    # data_tsne, labels_tsne = plot_embedding(output, labels, label_index=label_index, reduce=1000, label_names=label_names)