added generate_seed functions + easier continued sampling with SampleOutput objects
shyamsn97 committed Feb 21, 2023
1 parent 74d9ff8 commit 595de76
Showing 5 changed files with 148 additions and 29 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -91,6 +91,15 @@ generated_level.play()
 # run Astar agent
 generated_level.run_astar()
 
+# Continue generation
+generated_level_continued = mario_lm.sample(
+    seed=generated_level,
+    prompts=prompts,
+    num_steps=1400,
+    temperature=2.0,
+    use_tqdm=True
+)
+
 # load from text file
 loaded_level = SampleOutput.load("generated_level.txt")
 
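For context, a full continued-generation session following the updated README might look like the sketch below. It assumes the package's public API as shown in the README (MarioLM, SampleOutput); the prompt string and step counts are illustrative, and save() is assumed as the counterpart of the SampleOutput.load() call shown above.

```python
from mario_gpt import MarioLM, SampleOutput

mario_lm = MarioLM()
prompts = ["many pipes, many enemies, some blocks, high elevation"]

# initial generation, as in the README
generated_level = mario_lm.sample(
    prompts=prompts, num_steps=1400, temperature=2.0, use_tqdm=True
)
generated_level.save("generated_level.txt")  # assumed counterpart of SampleOutput.load

# continue generating from the returned SampleOutput
generated_level_continued = mario_lm.sample(
    seed=generated_level,
    prompts=prompts,
    num_steps=1400,
    temperature=2.0,
    use_tqdm=True,
)
```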
8 changes: 8 additions & 0 deletions mario_gpt/lm/bert.py
@@ -27,6 +27,14 @@ def __init__(
         super().__init__(lm, tokenizer, context_len)
         self.mask_proportion = mask_proportion
 
+    def generate_seed(self, length: int, batch_size: Optional[int] = None):
+        seed = self.tokenizer("X", return_tensors="pt").input_ids.squeeze()[
+            1:-1
+        ]  # remove start and end tokens
+        if batch_size is None:
+            return seed.repeat(length)
+        return seed.view(1, 1).repeat(batch_size, length)
+
     def load_pretrained_lm(self) -> RobertaModel:
         print(f"Using {PRETRAINED_MODEL_MASK_PATH} model")
         return AutoModelForMaskedLM.from_pretrained(PRETRAINED_MODEL_MASK_PATH)
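The `[1:-1]` slice exists because RoBERTa-style tokenizers wrap input in start/end special tokens. A minimal sketch of the idea, using `roberta-base` as a stand-in for the actual PRETRAINED_MODEL_MASK_PATH checkpoint:

```python
from transformers import AutoTokenizer

# stand-in tokenizer; the class above loads PRETRAINED_MODEL_MASK_PATH instead
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
ids = tokenizer("X", return_tensors="pt").input_ids.squeeze()
# ids holds [<s>, "X", </s>]; slicing [1:-1] keeps only the "X" token id,
# so the seed contains no special tokens
seed = ids[1:-1]
```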
6 changes: 6 additions & 0 deletions mario_gpt/lm/gpt.py
@@ -32,6 +32,12 @@ def __init__(
         if prompter is None:
             self.prompter = Prompter(self.tokenizer)
 
+    def generate_seed(self, length: int, batch_size: Optional[int] = None):
+        seed = self.tokenizer("X", return_tensors="pt").input_ids.squeeze()
+        if batch_size is None:
+            return seed.repeat(length)
+        return seed.view(1, 1).repeat(batch_size, length)
+
     def load_pretrained_lm(self) -> GPT2Model:
         print(f"Using {PRETRAINED_MODEL_PATH} model")
         return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)
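Both generate_seed methods follow the same shape rule: with batch_size=None the result is a 1-D tensor of `length` repeated "X" token ids, otherwise a (batch_size, length) tensor. A standalone sketch of that logic in plain torch (the token id is made up; the real one comes from tokenizing "X"):

```python
import torch

# standalone re-creation of the generate_seed shape logic with a fake "X" token id
x_token_id = torch.tensor(35)                  # illustrative id only
seed = x_token_id.repeat(14)                   # batch_size=None -> 1-D tensor, shape (14,)
batched = x_token_id.view(1, 1).repeat(2, 14)  # batch_size=2 -> 2-D tensor, shape (2, 14)
print(seed.shape, batched.shape)               # torch.Size([14]) torch.Size([2, 14])
```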
22 changes: 15 additions & 7 deletions mario_gpt/sampler.py
Expand Up @@ -174,7 +174,7 @@ def step(

     def sample(
         self,
-        seed: Optional[torch.Tensor] = None,
+        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
         prompts: Optional[List[str]] = None,
         num_steps: int = 1,
         encoder_hidden_states: torch.Tensor = None,
@@ -184,12 +184,18 @@
         self.mario_lm.lm.eval()
         with torch.no_grad():
             if seed is None:
-                seed = (
-                    self.mario_lm.tokenizer("X", return_tensors="pt")
-                    .input_ids.view(1, 1)
-                    .repeat(len(prompts), 1)
-                )
-            out_tensor = seed.to(self.device)
+                seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
+                    self.device
+                )
+                out_tensor = seed.to(self.device)
+            elif isinstance(seed, SampleOutput):
+                out_tensor = seed.level_tensor.to(self.device).squeeze()
+            else:
+                out_tensor = seed.to(self.device).squeeze()
+            if len(out_tensor.shape) < 2:
+                # if we pass in a single seed vector, then we repeat for each prompt
+                # Otherwise, we treat inputs as separate seed-prompt pairs
+                out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
             if encoder_hidden_states is None:
                 if prompts is not None:
                     encoder_hidden_states = torch.stack(
@@ -208,7 +214,9 @@
             encoder_hidden_states = encoder_hidden_states.to(
                 self.device
             )  # b x 1 x hidden_dim
-            encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
+            encoder_hidden_states = encoder_hidden_states.view(
+                out_tensor.shape[0], 1, -1
+            )
             if not self.use_tqdm:
                 bar = np.arange(num_steps)
             else:
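The new shape check means a single 1-D seed is broadcast across all prompts, while a 2-D seed is treated as one row per prompt. A self-contained sketch of that branch in plain torch (the tensor values and prompt strings are illustrative):

```python
import torch

prompts = ["many pipes", "no enemies"]
out_tensor = torch.tensor([3, 1, 4])  # a single 1-D seed vector, shape (3,)
if len(out_tensor.shape) < 2:
    # repeat the single seed once per prompt
    out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
print(out_tensor.shape)  # torch.Size([2, 3]), one row per prompt
```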