Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lsf dev #59

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions atari/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: decision-transformer-atari
channels:
- pytorch
dependencies:
- python=3.7.9
- pytorch=1.2
- cudatoolkit=10.
- python=3.8
#- pytorch=1.2
#- cudatoolkit=10.
- numpy
- psutil
- opencv
Expand Down
4 changes: 2 additions & 2 deletions gym/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: decision-transformer-gym
channels:
- pytorch
dependencies:
- python=3.8.5
- python=3.8
- anaconda
- cudatoolkit=10.
#- cudatoolkit=10.
- numpy
- pip
- pip:
Expand Down
4 changes: 2 additions & 2 deletions gym/decision_transformer/evaluation/evaluate_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ def evaluate_episode(
model,
max_ep_len=1000,
device='cuda',
target_return=None,
target_return=0,
mode='normal',
state_mean=0.,
state_std=1.,
):

print("i am in")
model.eval()
model.to(device=device)

Expand Down
11 changes: 9 additions & 2 deletions gym/decision_transformer/models/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
self.predict_return = torch.nn.Linear(hidden_size, 1)

def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):

batch_size, seq_length = states.shape[0], states.shape[1]

if attention_mask is None:
Expand Down Expand Up @@ -89,8 +88,16 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
)
x = transformer_outputs['last_hidden_state']

# for k, v in transformer_outputs.items():
# # print(k, type(v))
# if type(v) == tuple:
# print(k, v[0].detach().numpy().shape)
# else:
# print(k, v.detach().numpy().shape)

# reshape x so that the second dimension corresponds to the original
# returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
# 64, 20, 3, 128 -->> 64, 3, 20, 128
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

# get predictions
Expand Down Expand Up @@ -125,7 +132,7 @@ def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwarg
device=actions.device), actions],
dim=1).to(dtype=torch.float32)
returns_to_go = torch.cat(
[torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
[torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1) , device=returns_to_go.device), returns_to_go],
dim=1).to(dtype=torch.float32)
timesteps = torch.cat(
[torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
Expand Down
67 changes: 67 additions & 0 deletions gym/decision_transformer/models/load_cql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# import fire
import argparse, pickle
import os
import sys

# Extend the module search path with the current working directory and its
# parent and grandparent directories. Presumably this lets the
# `decision_transformer...` imports below resolve when the script is launched
# from a nested directory -- confirm against how the script is invoked.
sys.path.append(os.getcwd())
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

import torch, numpy as np, gym
from decision_transformer.models.offlinerl.algo import algo_select
from decision_transformer.models.offlinerl.data.d4rl import load_d4rl_buffer
from decision_transformer.models.offlinerl.evaluation import OnlineCallBackFunction

from decision_transformer.models.offlinerl.config.algo import cql_config, plas_config, mopo_config, moose_config, bcqd_config, bcq_config, bc_config, crr_config, combo_config, bremen_config, maple_config
from decision_transformer.models.offlinerl.utils.config import parse_config
from decision_transformer.models.offlinerl.algo.modelfree import cql, plas, bcqd, bcq, bc, crr
from decision_transformer.models.offlinerl.algo.modelbase import mopo, moose, combo, bremen, maple

def load_cql_q_network(env_name, dataset, mode, device, state_dim=None, action_dim=None):
    """Create a CQL trainer, load its Q parameters, and return its ``get_q``.

    The CQL default configuration is parsed and then overridden with the
    arguments given here before the algorithm is initialised.

    Args:
        env_name: Stored under the ``'env'`` config key and used as the first
            directory level of the checkpoint path.
        dataset: Stored under ``'dataset'``; second directory level.
        mode: Stored under ``'mode'``; third directory level.
        device: Stored under ``'device'`` and forwarded to ``load_q``.
        state_dim: Stored under ``'state_dim'`` (may be ``None``).
        action_dim: Stored under ``'act_dim'`` (may be ``None``).

    Returns:
        The trainer's bound ``get_q`` attribute.
    """
    # Resolve the working directory up front; checkpoints are looked up
    # relative to it.
    base_dir = os.getcwd()

    settings = parse_config(cql_config)
    overrides = (
        ('device', device),
        ('env', env_name),
        ('dataset', dataset),
        ('mode', mode),
        ('state_dim', state_dim),
        ('act_dim', action_dim),
    )
    for key, value in overrides:
        settings[key] = value

    init_result = cql.algo_init(settings)
    trainer = cql.AlgoTrainer(init_result, settings)

    checkpoint_dir = base_dir + "/saved_para/CQL/%s/%s/%s/" % (env_name, dataset, mode)
    # NOTE(review): 300 is presumably the checkpoint epoch/iteration to
    # restore -- confirm against AlgoTrainer.load_q.
    trainer.load_q(checkpoint_dir, 300, device)

    return trainer.get_q

def load_cql_actor(env_name, dataset, mode, device):
    """Create a CQL trainer, load its policy parameters, and return its ``get_action``.

    The CQL default configuration is parsed and then overridden with the
    arguments given here before the algorithm is initialised.

    Args:
        env_name: Stored under the ``'env'`` config key and used as the first
            directory level of the checkpoint path.
        dataset: Stored under ``'dataset'``; second directory level.
        mode: Stored under ``'mode'``; third directory level.
        device: Stored under ``'device'`` and forwarded to ``load_pi``.

    Returns:
        The trainer's bound ``get_action`` attribute.
    """
    # Resolve the working directory up front; checkpoints are looked up
    # relative to it.
    base_dir = os.getcwd()

    settings = parse_config(cql_config)
    overrides = (
        ('device', device),
        ('env', env_name),
        ('dataset', dataset),
        ('mode', mode),
    )
    for key, value in overrides:
        settings[key] = value

    init_result = cql.algo_init(settings)
    trainer = cql.AlgoTrainer(init_result, settings)

    checkpoint_dir = base_dir + "/saved_para/CQL/%s/%s/%s/" % (env_name, dataset, mode)
    # NOTE(review): 300 is presumably the checkpoint epoch/iteration to
    # restore -- confirm against AlgoTrainer.load_pi.
    trainer.load_pi(checkpoint_dir, 300, device=device)

    return trainer.get_action
83 changes: 83 additions & 0 deletions gym/decision_transformer/models/offlinerl/algo/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import uuid
import json
from abc import ABC, abstractmethod

import torch
from collections import OrderedDict
from loguru import logger
# from offlinerl.utils.exp import init_exp_logger
from offlinerl.utils.io import create_dir, download_helper, read_json
from offlinerl.utils.logger import log_path


class BaseAlgo(ABC):
    """Abstract base for algorithm trainers.

    Sets up an on-disk directory for model checkpoints and a JSON metrics
    log, and provides shared helpers for logging results, soft-updating
    target networks, and saving/loading a policy with ``torch``.

    Subclasses must implement ``train`` and ``get_policy``.
    """

    def __init__(self, args):
        """Prepare the output directories and the in-memory metric log.

        Args:
            args: Algorithm configuration. It is not read here: the
                experiment-logger setup that used ``exp_name``/``aim_path``
                is disabled, so the ``.aim`` directory under ``log_path()``
                is always used.
        """
        logger.info('Init AlgoTrainer')

        repo = os.path.join(log_path(), "./.aim")
        if not os.path.exists(repo):
            logger.info('{} dir is not exist, create {}', repo, repo)
            # Runs `aim init` from the parent of the repo directory through
            # the shell; the exit status is ignored.
            os.system(str("cd " + os.path.join(repo, "../") + "&& aim init"))
        self.index_path = repo

        self.models_save_dir = os.path.join(self.index_path, "models")
        self.metric_logs = OrderedDict()
        self.metric_logs_path = os.path.join(self.index_path, "metric_logs.json")
        create_dir(self.models_save_dir)

    def log_res(self, epoch, result):
        """Log one epoch's metrics, persist them, and checkpoint the policy.

        Args:
            epoch: Epoch identifier; ``str(epoch)`` is used as the key in the
                metrics log and as the checkpoint file stem.
            result: Mapping of metric name to value. It must be
                JSON-serialisable because the whole log is rewritten with
                ``json.dump`` on every call.
        """
        logger.info('Epoch : {}', epoch)

        # `exp_logger` is no longer created in __init__, so an unconditional
        # `self.exp_logger.track(...)` raised AttributeError and prevented
        # the metrics file and checkpoint below from being written. Track
        # only when a subclass/caller has attached an experiment logger.
        exp_logger = getattr(self, "exp_logger", None)
        for k, v in result.items():
            logger.info('{} : {}', k, v)
            if exp_logger is not None:
                exp_logger.track(v, name=k.split(" ")[0], epoch=epoch,)

        self.metric_logs[str(epoch)] = result
        with open(self.metric_logs_path, "w") as f:
            json.dump(self.metric_logs, f)
        self.save_model(os.path.join(self.models_save_dir, str(epoch) + ".pt"))

    @abstractmethod
    def train(self,
              history_buffer,
              eval_fn=None,):
        """Run training; must be implemented by subclasses."""
        pass

    def _sync_weight(self, net_target, net, soft_target_tau=5e-3):
        """Soft-update ``net_target`` towards ``net`` in place.

        Each target parameter becomes
        ``(1 - soft_target_tau) * target + soft_target_tau * source``.
        Parameters are paired positionally via ``zip``, so both networks are
        expected to expose their parameters in the same order.
        """
        for o, n in zip(net_target.parameters(), net.parameters()):
            o.data.copy_(o.data * (1.0 - soft_target_tau) + n.data * soft_target_tau)

    @abstractmethod
    def get_policy(self,):
        """Return the object that ``save_model`` should serialise."""
        pass

    def save_model(self, model_path):
        """Serialise ``self.get_policy()`` to ``model_path`` with ``torch.save``."""
        torch.save(self.get_policy(), model_path)

    def load_model(self, model_path):
        """Load and return whatever ``torch.load`` reads from ``model_path``."""
        model = torch.load(model_path)

        return model
Loading