-
Notifications
You must be signed in to change notification settings - Fork 0
/
bandit.py
109 lines (86 loc) · 3.86 KB
/
bandit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import tensorflow as tf
import random as rand
import numpy as np
import copy
class Bandit(object):
    """A node that shares model weights with peers through weighted "tickets".

    Each bandit owns a model, a dictionary of relationship weights keyed by
    peer ``bandit_id``, and a budget of ``ticket_number`` tickets that it
    distributes among its peers in proportion to those weights.

    The model is supplied via :meth:`set_model` and must provide
    ``get_weights()``, ``get_layers()``, ``evaluate(x, y)`` and
    ``train(x, y, epochs=...)``; layers must provide ``get_weights()`` and
    ``set_weights()``.
    NOTE(review): ``tf.keras.Model`` exposes ``layers`` / ``fit`` rather than
    ``get_layers()`` / ``train()``, so this presumably expects a project
    wrapper around the Keras model -- confirm against the caller.
    """

    def __init__(
        self,
        bandit_id,
        ticket_number,
        x_train,
        y_train,
        x_test,
        y_test,
        gamma=0.15,
        eta=3,
        seed=42,
    ) -> None:
        """Store the configuration and data and initialise empty histories.

        Args:
            bandit_id: Identifier of this bandit; also offsets the RNG seed.
            ticket_number: Number of tickets distributed per sampling round.
            x_train, y_train: Training data kept on the instance.
            x_test, y_test: Data used by :meth:`utility` to score a merge.
            gamma: Mixing rate used by :meth:`scale_model` and :meth:`utility`.
            eta: Multiplier of the utility in the exponential weight update.
            seed: Base RNG seed; the effective seed is ``seed + bandit_id``.
        """
        self.bandit_id = bandit_id
        self.ticket_number = ticket_number
        self.gamma = gamma
        self.eta = eta
        self.ticket_allocation = {}
        self.relationships = {}
        self.loss_history = []
        self.accuracy_history = []
        self.relationship_history = []
        self.communication_history = []
        # Retained so that code relying on the constructor reseeding the
        # module-level generator keeps working.
        rand.seed(seed + bandit_id)
        # Ticket sampling uses a private generator: seeding only the shared
        # module-level one means the last-constructed bandit's seed wins and
        # all bandits consume the same stream, which defeats per-bandit seeds.
        self._rng = rand.Random(seed + bandit_id)
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test

    def set_model(self, model):
        """Attach the model this bandit trains, evaluates and merges."""
        self.model = model

    def get_weights(self):
        """Return ``self.model.get_weights()``."""
        return self.model.get_weights()

    def normalize_relationships(self):
        """Rescale the relationship weights in place so that they sum to 1.

        Does nothing when there are no relationships.

        Raises:
            ValueError: If the total weight is NaN or infinite.
            ZeroDivisionError: If the total weight is zero.
        """
        if not self.relationships:
            return
        total = sum(self.relationships.values())
        # NumPy floats turn x/0 and inf/inf into NaN instead of raising, and
        # NaN weights make the sampling loop spin forever, so fail loudly.
        if not np.isfinite(total):
            raise ValueError(
                f"relationship weights of bandit {self.bandit_id} sum to {total}"
            )
        if total == 0:
            raise ZeroDivisionError(
                f"relationship weights of bandit {self.bandit_id} sum to zero"
            )
        for peer_id in self.relationships:
            self.relationships[peer_id] /= total

    def scale_model(self):
        """Multiply every layer's weights by ``1 - gamma``."""
        factor = 1 - self.gamma
        for layer in self.model.get_layers():
            if hasattr(layer, 'get_weights'):
                scaled = [weight * factor for weight in layer.get_weights()]
                # get_weights() may hand back copies, so the scaled values
                # must be written back or the scaling is silently lost.
                layer.set_weights(scaled)

    def utility(self, other, tickets):
        """Merge ``other``'s weights into this model and score the change.

        Every layer weight becomes
        ``self_weight + gamma * tickets / ticket_number * other_weight``.
        The merged weights are left in ``self.model``; nothing is restored.

        Args:
            other: Bandit whose model weights are merged in.
            tickets: Number of tickets allocated to ``other``.

        Returns:
            The second entry of ``model.evaluate(x_test, y_test)`` after the
            merge minus the same entry before it.
        """
        before = self.model.evaluate(self.x_test, self.y_test)[1]
        share = self.gamma * tickets / self.ticket_number
        layer_pairs = zip(self.model.get_layers(), other.model.get_layers())
        for self_layer, other_layer in layer_pairs:
            if hasattr(self_layer, 'get_weights'):
                weight_pairs = zip(self_layer.get_weights(), other_layer.get_weights())
                self_layer.set_weights(
                    [mine + share * theirs for mine, theirs in weight_pairs]
                )
        after = self.model.evaluate(self.x_test, self.y_test)[1]
        return after - before

    def update_relationships(self, other):
        """Multiply the weight of ``other`` by ``exp(eta * utility)``.

        The utility is computed with the tickets allocated to ``other`` in the
        most recent sampling round; a peer that received none counts as zero
        tickets instead of raising ``KeyError``.
        """
        peer_id = other.bandit_id
        tickets = self.ticket_allocation.get(peer_id, 0)
        # NOTE(review): a peer missing from ``relationships`` starts at 0 and
        # therefore stays at 0 under this multiplicative update -- confirm
        # that relationships are always initialised elsewhere.
        previous = self.relationships.get(peer_id, 0)
        self.relationships[peer_id] = previous * np.exp(
            self.eta * self.utility(other, tickets)
        )

    def evaluate(self, x_test, y_test, write_to_history=False, step='train'):
        """Evaluate the model, optionally recording the result.

        Args:
            x_test, y_test: Data passed to ``model.evaluate``.
            write_to_history: When true, append ``(step, result[0])`` to the
                loss history and ``(step, result[1])`` to the accuracy history.
            step: Label stored alongside the recorded values.

        Returns:
            Whatever ``model.evaluate(x_test, y_test)`` returns.
        """
        result = self.model.evaluate(x_test, y_test)
        if write_to_history:
            self.loss_history.append((step, result[0]))
            self.accuracy_history.append((step, result[1]))
        return result

    def train(self, x_train, y_train, epochs):
        """Delegate to ``model.train`` with the given data and epoch count."""
        self.model.train(x_train, y_train, epochs=epochs)

    def get_history(self):
        """Return the recorded ``(step, loss)`` pairs."""
        return self.loss_history

    def get_relationship_history(self):
        """Return the relationship snapshots, one per sampling round."""
        return self.relationship_history

    def get_accuracy_history(self):
        """Return the recorded ``(step, accuracy)`` pairs."""
        return self.accuracy_history

    def get_ticket_allocation(self):
        """Return the ticket allocation of the latest sampling round."""
        return self.ticket_allocation

    def get_communication_history(self):
        """Return the ticket allocations of all sampling rounds."""
        return self.communication_history

    def sample_ticket_allocation(self):
        """Randomly distribute ``ticket_number`` tickets among the peers.

        The relationships are normalised first and a snapshot of them is
        appended to the relationship history. Each ticket then goes to one
        peer drawn with probability equal to its normalised weight. The
        result replaces ``ticket_allocation`` and a copy of it is appended to
        the communication history.

        Raises:
            ValueError: If tickets must be distributed but there are no peers,
                or if the total relationship weight is not finite.
            ZeroDivisionError: If the relationship weights sum to zero.
        """
        if self.ticket_number > 0 and not self.relationships:
            # Without peers the sampling loop below could never finish.
            raise ValueError(
                f"bandit {self.bandit_id} has no relationships to sample from"
            )
        self.normalize_relationships()
        self.ticket_allocation = {}
        self.relationship_history.append(copy.deepcopy(self.relationships))
        distributed_tickets = 0
        while distributed_tickets < self.ticket_number:
            # Inverse-CDF draw: the ticket goes to the first peer whose
            # cumulative weight exceeds the uniform sample. If rounding leaves
            # the total slightly below the sample, simply draw again.
            draw = self._rng.random()
            cumulative = 0
            for peer_id, weight in self.relationships.items():
                cumulative += weight
                if draw < cumulative:
                    self.ticket_allocation[peer_id] = (
                        self.ticket_allocation.get(peer_id, 0) + 1
                    )
                    distributed_tickets += 1
                    break
        # Report this bandit's own id, not the last peer that got a ticket.
        print(f"Node {self.bandit_id} distributed {self.ticket_allocation}")
        self.communication_history.append(copy.deepcopy(self.ticket_allocation))