-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
86 lines (65 loc) · 3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, GRUCell, Embedding, Add, Concatenate, BatchNormalization
from tensorflow.keras.optimizers import RMSprop
class RIAL(tf.keras.Model):
    '''
    Q-network for RIAL (Reinforced Inter-Agent Learning): an RNN that maintains an
    internal state h across timesteps, an input network producing a task embedding z,
    and an output network producing per-action Q-values.

    Architecture (as built in __init__):
      - observation  -> 2-layer MLP (64 -> 128, relu)
      - last message -> 1-layer MLP (128, relu) + batch norm
      - last action  -> Embedding(action_space -> 128)
      - agent index  -> Embedding(2 -> 128), i.e. exactly two agents are assumed
      - the four 128-dim features are summed element-wise into z
      - z feeds a 2-layer GRU (via GRUCell, so hidden states are passed in explicitly)
      - the second GRU output feeds a 2-layer MLP head (64 relu -> action_space linear)
    '''
    def __init__(self, action_space, hidden_size, lr, moment):
        '''
        Args:
            action_space: number of discrete actions (size of the Q-value output).
            hidden_size:  number of units in each GRU cell.
            lr:           learning rate for the RMSprop optimizer.
            moment:       momentum for the RMSprop optimizer.
        '''
        super(RIAL, self).__init__()
        # 2 dense layers as problem-specific MLP: input space = 21, output dim = 128
        self.mlp = Sequential()
        self.mlp.add(Dense(64, activation='relu'))
        self.mlp.add(Dense(128, activation='relu'))
        # 1-layer MLP to process the incoming message
        self.mlp2 = Sequential()
        self.mlp2.add(Dense(128, activation='relu'))
        self.mlp2.add(BatchNormalization())
        # embedding for the previous action
        self.emb_act = Embedding(input_dim=action_space, output_dim=128)
        # embedding for the agent index (two agents)
        self.emb_ind = Embedding(input_dim=2, output_dim=128)
        # produce the state embedding z as input for the RNN through element-wise
        # summation of the processed inputs
        self.add = Add(dynamic=True)
        self.concat = Concatenate(dynamic=True)  # NOTE(review): unused in call(); kept for compatibility
        # 2-layer RNN with GRUs approximating the agent's action-observation history;
        # GRUCell (not GRU) so the previous hidden state can be fed in explicitly
        self.hidden_size = hidden_size
        self.rnn1 = GRUCell(hidden_size)
        self.rnn2 = GRUCell(hidden_size)
        # 2-layer MLP head mapping the second GRU's output to Q-values.
        # FIX: the output layer must be linear (no activation) — a relu here clamps
        # all Q-value estimates to >= 0, making negative expected returns
        # unrepresentable and zeroing gradients for negative pre-activations.
        self.q_net = Sequential()
        self.q_net.add(Dense(64, activation='relu'))
        self.q_net.add(Dense(action_space))
        self.optimizer = RMSprop(learning_rate=lr, momentum=moment)

    @tf.function
    def call(self, input):
        '''
        Forward pass for one timestep.

        Args:
            input: tuple/list of (observation, last_action, last_message, agent_id,
                   hidden_states), each with batch_size as the first dimension.
                   hidden_states is assumed to be (batch_size, 2, hidden_size),
                   one slice per GRU layer — TODO confirm against caller.

        Returns:
            (q, hidden_1, hidden_2): Q-values of shape (batch_size, action_space)
            and the new hidden states of the two GRU layers.
        '''
        #batch_size = input[0].shape[0]
        state = input[0]
        last_act = input[1]
        last_m = input[2]
        hidden = input[4]
        # move the layer axis first: (2, batch_size, hidden_size) so that
        # hidden[0] / hidden[1] index the per-layer states
        hidden = tf.transpose(hidden, [1, 0, 2])
        agent = input[3]
        x = self.mlp(state)
        last_m = tf.cast(last_m, 'float')
        last_m = self.mlp2(last_m)
        last_act = tf.cast(last_act, 'float')
        last_act = self.emb_act(last_act)
        agent = tf.cast(agent, 'float')
        agent = self.emb_ind(agent)
        # embeddings come back with an extra lookup axis; flatten to (batch, 128)
        last_act = tf.reshape(last_act, [-1, 128])
        agent = tf.reshape(agent, [-1, 128])
        # z: joint task embedding, element-wise sum of the four 128-dim features
        z = self.add([x, last_act, last_m, agent])
        hidden_1, _ = self.rnn1(inputs=z, states=hidden[0])
        hidden_2, _ = self.rnn2(inputs=hidden_1, states=hidden[1])
        q = self.q_net(hidden_2)
        return q, hidden_1, hidden_2