-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
36 lines (33 loc) · 1.07 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
import gym
import ic3net_envs
from env_wrappers import *
def init(env_name, args, final_init=True):
    """Create the gym environment named by ``env_name`` and wrap it in ``GymWrapper``.

    Args:
        env_name: One of ``'levers'``, ``'number_pairs'``, ``'predator_prey'``,
            ``'traffic_junction'`` or ``'starcraft'``.
        args: Options object. Depending on the environment, this function reads
            ``total_agents``, ``nagents``, ``max_message`` and ``display`` from
            it, or hands the whole object to the environment's
            ``multi_agent_init``.
        final_init: Passed through to ``multi_agent_init`` for ``'starcraft'``
            only; the other environments ignore it.

    Returns:
        The ``GymWrapper`` built around the configured environment.

    Raises:
        RuntimeError: If ``env_name`` is not one of the supported names.
    """
    if env_name == 'levers':
        levers = gym.make('Levers-v0')
        levers.multi_agent_init(args.total_agents, args.nagents)
        return GymWrapper(levers)

    if env_name == 'number_pairs':
        pairs = gym.make('NumberPairs-v0')
        message_limit = args.max_message
        pairs.multi_agent_init(args.nagents, message_limit)
        return GymWrapper(pairs)

    # These two environments share the same setup sequence and differ only in
    # the gym id that gets instantiated.
    if env_name == 'predator_prey' or env_name == 'traffic_junction':
        if env_name == 'predator_prey':
            gym_id = 'PredatorPrey-v0'
        else:
            gym_id = 'TrafficJunction-v0'
        grid_env = gym.make(gym_id)
        # init_curses() runs before multi_agent_init(), and only when display
        # is requested.
        if args.display:
            grid_env.init_curses()
        grid_env.multi_agent_init(args)
        return GymWrapper(grid_env)

    if env_name == 'starcraft':
        starcraft = gym.make('StarCraftWrapper-v0')
        starcraft.multi_agent_init(args, final_init)
        # Unlike the other branches, the wrapper receives the ``.env``
        # attribute rather than the object returned by gym.make().
        return GymWrapper(starcraft.env)

    raise RuntimeError("wrong env name")