Skip to content

Commit

Permalink
Clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
lgalke committed Oct 10, 2024
1 parent 3d695e1 commit 3e98b2c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 369 deletions.
3 changes: 1 addition & 2 deletions debug.bash
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ OUTDIR="./results-v1"

SEED=123456789
LEXP="data/LearningExp_190501_S5_001_log.txt"
# Iterate over different seeds here
echo "Seed: $SEED"
echo "Starting run with experiment data: $LEXP"
python3 train.py --as_humans "$LEXP" --seed "$SEED" --debug --iterations 5 --outdir "tmp/"
python3 train.py --as_humans "$LEXP" --seed "$SEED" --debug --iterations 5 --outdir "tmp/"
126 changes: 0 additions & 126 deletions modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,13 @@

from egg import core

# TODO: https://github.com/pytorch/examples/blob/main/word_language_model/model.py


class RelaxedLinear(nn.Linear):

"""Overwrite to allow second, ignored, input"""

def __init__(self, *args, **kwargs):
"""TODO: to be defined.
:*args: TODO
:**kwargs: TODO
"""
super().__init__(*args, **kwargs)

def forward(self, x, __aux_input=None):
Expand Down Expand Up @@ -145,33 +138,6 @@ def tie_weights(
raise AssertionError("Could not resolve type mismatch")


class DoubleAgent(nn.Module):

"""Common forward signature for an agent capable of both sending and receiving"""

def __init__(self):
nn.Module.__init__(self)
self.tie_weights()

def forward(
self,
sender_input=None,
aux_input=None,
message=None,
receiver_input=None,
lengths=None,
):
if message is not None:
outputs = self.receiver(
message, input=receiver_input, aux_input=aux_input, lengths=length
)
else:
assert sender_input is not None
outputs = self.sender(sender_input, aux_input=aux_input)
return outputs

def tie_weights(self):
tie_weights(self.sender, self.receiver)


class MLP(nn.Module):
Expand All @@ -197,7 +163,6 @@ def __init__(
self.drop = nn.Dropout(dropout)

def forward(self, x, input=None, aux_input=None):
# TODO check if we want to do something with input/aux_input receivers
h = x
for i, layer in enumerate(self.layers):

Expand Down Expand Up @@ -293,12 +258,6 @@ def __init__(
straight_through: bool = False,
tied_weights: str = "all",
):
"""TODO: to be defined.
:input2hidden: TODO
:hidden2output: TODO
"""
nn.Module.__init__(self)

self.input2hidden = input2hidden # MLP / Linear
Expand Down Expand Up @@ -366,88 +325,3 @@ def forward(
sender_input, aux_input=aux_input
) # bsz, seqlen, vocab_size
return outputs


class MLPDoubleAgent(DoubleAgent):

"""MLP-based double agent"""

def __init__(
self,
input_dim,
hidden_size,
vocab_size,
max_length,
tie_weights=True,
dropout=0.5,
):
"""TODO: to be defined.
:l: TODO
"""
DoubleAgent.__init__(self)

self.dropout = nn.Dropout(dropout)
# Seeing: scene to latent
self.input2hidden = nn.Linear(input_dim, hidden_size)

# Writing: latent to message
self.hidden2msg = nn.Linear(hidden_size, vocab_size * max_length)

# Reading: message to latent
if tie_weights:
print("Tying weights...")
self.msg2hidden = TiedLinear(self.hidden2msg)
else:
self.msg2hidden = nn.Linear(vocab_size * max_length, hidden_size)

self.init_weights()

def init_weights(self):
initrange = 0.1

# input to hidden
nn.init.uniform_(self.input2hidden.weight, -initrange, initrange)
nn.init.zeros_(self.input2hidden.bias)

# hidden to message
nn.init.uniform_(self.hidden2msg.weight, -initrange, initrange)
nn.init.zeros_(self.hidden2msg.bias)

# message to hidden (be sure even if tied)
nn.init.uniform_(self.msg2hidden.weight, -initrange, initrange)
nn.init.zeros_(self.msg2hidden.bias)

def forward(self, x, aux_input=None):
self.input2hidden(x)
# TODO fill

raise NotImplementedError


class UniformAgentSamplerWithRoleAlternation(nn.Module):
"""Extension for egg.core.population to build populations that allow role alternation
>>> game = ...
>>> agents = ...
>>> agents_loss_sampler = RoleAlternationAgentSampler(agents)
>>> game = PopulationGame(game, agents_loss_sampler)
"""

def __init__(self, agents, losses, seed=1234):
super().__init__()
np.random.seed(seed)
self.agents = nn.ModuleList(agents)
self.losses = list(losses)

def forward(self):
s_idx, r_idx, l_idx = (
np.random.choice(len(self.senders)),
np.random.choice(len(self.receivers)),
np.random.choice(len(self.losses)),
)
return (
self.agents[s_idx], # Sender
self.agents[r_idx], # Receiver
self.losses[l_idx],
)
5 changes: 4 additions & 1 deletion stats.bash
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
python3 stats.py --models_subdir models_nested_new -o ~/Documents/project-data/paper__LEADS_easy-to-learn/results-v1-stats-output-v2/ ~/Documents/project-data/paper__LEADS_easy-to-learn/results-v1
# Helper script to run the stats.py script
# Adjust the path to the results directory
RESULTS_DIR=./results
python3 stats.py -o $RESULTS_DIR --models_subdir statsmodels $RESULTS_DIR
Loading

0 comments on commit 3e98b2c

Please sign in to comment.