Adding infinite sampling functionality + updating lm code to take multiple prompts

#3
by shyamsn97 - opened
Files changed (2) hide show
  1. app.py +105 -23
  2. mario_gpt/lm.py +1 -1
app.py CHANGED
@@ -46,25 +46,79 @@ def make_html_file(generated_level):
46
  </html>''')
47
  return f"demo-{unique_id}.html"
48
 
49
- def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size = 1399, prompt = ""):
50
- if prompt == "":
51
- prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
52
- print(f"Using prompt: {prompt}")
53
- prompts = [prompt]
54
- generated_level = mario_lm.sample(
 
 
 
 
 
 
 
 
 
55
  prompts=prompts,
56
  num_steps=level_size,
57
  temperature=temperature,
58
- use_tqdm=True
 
59
  )
60
- filename = make_html_file(generated_level)
61
- img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
62
-
 
 
63
  gradio_html = f'''<div>
64
  <iframe width=512 height=512 style="margin: 0 auto" src="static/{filename}"></iframe>
65
  <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
66
  </div>'''
67
- return [img, gradio_html]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  with gr.Blocks().queue() as demo:
70
  gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
@@ -83,27 +137,55 @@ with gr.Blocks().queue() as demo:
83
 
84
  with gr.Accordion(label="Advanced settings", open=False):
85
  temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
86
- level_size = gr.Number(value=1399, precision=0, label="level_size")
87
-
88
- btn = gr.Button("Generate level")
 
 
 
89
  with gr.Row():
90
  with gr.Box():
91
- level_play = gr.HTML()
92
- level_image = gr.Image()
93
- btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  gr.Examples(
95
  examples=[
96
- ["many", "many", "some", "high"],
97
  ["no", "some", "many", "high", 2.0],
98
- ["many", "many", "little", "low", 2.0],
99
- ["no", "no", "many", "high", 2.4],
100
  ],
101
- inputs=[pipes, enemies, blocks, elevation],
102
  outputs=[level_image, level_play],
103
- fn=generate,
104
  cache_examples=True,
105
  )
106
 
107
  app.mount("/static", StaticFiles(directory="static", html=True), name="static")
108
  app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")
109
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
46
  </html>''')
47
  return f"demo-{unique_id}.html"
48
 
49
+ def trim_level(level):
50
+ mod = level.shape[-1] % 14
51
+ if mod > 0:
52
+ return level[:, :-mod]
53
+ return level
54
+
55
+ def reset_state(seed_state):
56
+ length = len(seed_state)
57
+ print(f"Resetting state with {length} levels!")
58
+ for _ in range(length):
59
+ seed_state.pop()
60
+
61
+ def _generate_level(prompts, seed, level_size, temperature):
62
+ print(f"Using prompts: {prompts}")
63
+ generated_levels = mario_lm.sample(
64
  prompts=prompts,
65
  num_steps=level_size,
66
  temperature=temperature,
67
+ use_tqdm=True,
68
+ seed = seed
69
  )
70
+ generated_levels = trim_level(generated_levels)
71
+ return generated_levels
72
+
73
+ def _make_gradio_html(level):
74
+ filename = make_html_file(level)
75
  gradio_html = f'''<div>
76
  <iframe width=512 height=512 style="margin: 0 auto" src="static/{filename}"></iframe>
77
  <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
78
  </div>'''
79
+ return gradio_html
80
+
81
+ def initialize_generate(pipes, enemies, blocks, elevation, temperature = 2.4, level_size = 1400):
82
+ prompts = [f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"]
83
+ generated_levels = _generate_level(prompts, None, level_size, temperature)
84
+ level = generated_levels.squeeze().detach().cpu()
85
+ img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
86
+ return [img, _make_gradio_html(level)]
87
+
88
+ def generate_choices(pipes, enemies, blocks, elevation, temperature = 2.4, level_size = 1400, prompt = "", seed_state = []):
89
+ NUM_SAMPLES = 2
90
+ if prompt == "":
91
+ prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
92
+ prompts = [prompt] * NUM_SAMPLES
93
+
94
+ seed = None
95
+ if len(seed_state) > 0:
96
+ seed = torch.cat(seed_state).squeeze()[-48*14:].view(1, -1).repeat(NUM_SAMPLES, 1) # context length
97
+
98
+ generated_levels = _generate_level(prompts, seed, level_size, temperature).detach().cpu().squeeze()
99
+ level_choices = [generated_level[-level_size:] for generated_level in generated_levels]
100
+ level_choice_images = [convert_level_to_png(generated_level[-level_size:], TILE_DIR, mario_lm.tokenizer)[0] for generated_level in generated_levels]
101
+
102
+ # level choices + separate images
103
+ return [level_choices, *level_choice_images]
104
+
105
+ def update_level_state(choice_id, level_choices, seed_state):
106
+ num_choice = int(choice_id)
107
+ level_choice = level_choices[num_choice]
108
+
109
+ # append level choice to seed state
110
+ seed_state.append(level_choice)
111
+
112
+ # get new level from concatenation
113
+ level = torch.cat(seed_state).squeeze()
114
+
115
+ # final image and gradio html
116
+ img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
117
+ gradio_html = _make_gradio_html(level)
118
+
119
+ # return img, gradio html, seed state, level_choice, choice_image_1, choice_image_2, current_level_size
120
+ return img, gradio_html, seed_state, None, None, None, level.shape[-1]
121
+
122
 
123
  with gr.Blocks().queue() as demo:
124
  gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
 
137
 
138
  with gr.Accordion(label="Advanced settings", open=False):
139
  temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
140
+ level_size = gr.Number(value=1400, precision=0, label="level_size")
141
+
142
+ generate_btn = gr.Button("Generate Level")
143
+ reset_btn = gr.Button("Reset Level")
144
+
145
+
146
  with gr.Row():
147
  with gr.Box():
148
+ level_play = gr.HTML()
149
+ level_image = gr.Image(label="Current Level")
150
+ with gr.Box():
151
+ with gr.Column():
152
+ level_choice1_image = gr.Image(label="Sample Choice 1")
153
+ level_choice1_btn = gr.Button("Sample Choice 1")
154
+ with gr.Column():
155
+ level_choice2_image = gr.Image(label="Sample Choice 2")
156
+ level_choice2_btn = gr.Button("Sample Choice 2")
157
+ current_level_size = gr.Number(0, visible=True, label="Current Level Size")
158
+
159
+
160
+ seed_state = gr.State([])
161
+ state_choices = gr.State(None)
162
+
163
+ image_choice_1_id = gr.Number(0, visible=False)
164
+ image_choice_2_id = gr.Number(1, visible=False)
165
+
166
+ # choice buttons
167
+ level_choice1_btn.click(fn=update_level_state, inputs=[image_choice_1_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
168
+ level_choice2_btn.click(fn=update_level_state, inputs=[image_choice_2_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
169
+
170
+ # generate_btn
171
+ generate_btn.click(fn=generate_choices, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt, seed_state], outputs=[state_choices, level_choice1_image, level_choice2_image])
172
+
173
+ # reset btn
174
+ reset_btn.click(fn=reset_state, inputs=[seed_state], outputs=[])
175
+
176
  gr.Examples(
177
  examples=[
178
+ ["many", "many", "some", "high", 2.0],
179
  ["no", "some", "many", "high", 2.0],
180
+ ["many", "many", "little", "low", 2.4],
181
+ ["no", "no", "many", "high", 2.8],
182
  ],
183
+ inputs=[pipes, enemies, blocks, elevation, temperature, level_size],
184
  outputs=[level_image, level_play],
185
+ fn=initialize_generate,
186
  cache_examples=True,
187
  )
188
 
189
  app.mount("/static", StaticFiles(directory="static", html=True), name="static")
190
  app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")
191
+ uvicorn.run(app, host="0.0.0.0", port=7860)
mario_gpt/lm.py CHANGED
@@ -105,7 +105,7 @@ class MarioLM:
105
  self.lm.eval()
106
  with torch.no_grad():
107
  if seed is None:
108
- seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
109
  out = seed.to(self.device)
110
  if encoder_hidden_states is None:
111
  if prompts is not None:
 
105
  self.lm.eval()
106
  with torch.no_grad():
107
  if seed is None:
108
+ seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1).repeat(len(prompts), 1)
109
  out = seed.to(self.device)
110
  if encoder_hidden_states is None:
111
  if prompts is not None: