skytnt commited on
Commit
294c6ec
1 Parent(s): 5c45beb
Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +159 -116
  3. javascript/app.js +59 -35
  4. midi_model.py +46 -25
.gitignore CHANGED
@@ -151,3 +151,4 @@ cython_debug/
151
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152
  .idea/
153
  output.mid
 
 
151
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152
  .idea/
153
  output.mid
154
+ /outputs/
app.py CHANGED
@@ -18,11 +18,12 @@ from midi_model import MIDIModel, MIDIModelConfig
18
  from midi_synthesizer import MidiSynthesizer
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
 
21
  in_space = os.getenv("SYSTEM") == "spaces"
22
 
23
 
24
  @torch.inference_mode()
25
- def generate(model: MIDIModel, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
26
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
27
  tokenizer = model.tokenizer
28
  if disable_channels is not None:
@@ -33,49 +34,69 @@ def generate(model: MIDIModel, prompt=None, max_len=512, temp=1.0, top_p=0.98, t
33
  if prompt is None:
34
  input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
35
  input_tensor[0, 0] = tokenizer.bos_id # bos
 
 
36
  else:
37
- prompt = prompt[:, :max_token_seq]
 
 
 
 
 
 
 
38
  if prompt.shape[-1] < max_token_seq:
39
- prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
40
  mode="constant", constant_values=tokenizer.pad_id)
41
  input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
42
- input_tensor = input_tensor.unsqueeze(0)
43
  cur_len = input_tensor.shape[1]
44
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
45
  with bar:
46
  while cur_len < max_len:
47
- end = False
48
- hidden = model.forward(input_tensor)[0, -1].unsqueeze(0)
49
  next_token_seq = None
50
- event_name = ""
51
  for i in range(max_token_seq):
52
- mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=model.device)
53
- if i == 0:
54
- mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
55
- if disable_patch_change:
56
- mask_ids.remove(tokenizer.event_ids["patch_change"])
57
- if disable_control_change:
58
- mask_ids.remove(tokenizer.event_ids["control_change"])
59
- mask[mask_ids] = 1
60
- else:
61
- param_name = tokenizer.events[event_name][i - 1]
62
- mask_ids = tokenizer.parameter_ids[param_name]
63
- if param_name == "channel":
64
- mask_ids = [i for i in mask_ids if i not in disable_channels]
65
- mask[mask_ids] = 1
 
 
 
 
 
 
 
 
 
66
  logits = model.forward_token(hidden, next_token_seq)[:, -1:]
67
  scores = torch.softmax(logits / temp, dim=-1) * mask
68
- sample = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
69
  if i == 0:
70
- next_token_seq = sample
71
- eid = sample.item()
72
- if eid == tokenizer.eos_id:
73
- end = True
74
- break
75
- event_name = tokenizer.id_events[eid]
 
 
 
76
  else:
77
- next_token_seq = torch.cat([next_token_seq, sample], dim=1)
78
- if len(tokenizer.events[event_name]) == i:
79
  break
80
  if next_token_seq.shape[1] < max_token_seq:
81
  next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
@@ -84,8 +105,8 @@ def generate(model: MIDIModel, prompt=None, max_len=512, temp=1.0, top_p=0.98, t
84
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
85
  cur_len += 1
86
  bar.update(1)
87
- yield next_token_seq.reshape(-1).cpu().numpy()
88
- if end:
89
  break
90
 
91
 
@@ -96,8 +117,9 @@ def create_msg(name, data):
96
  def send_msgs(msgs):
97
  return json.dumps(msgs)
98
 
99
- def get_duration(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig,
100
- key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
 
101
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
102
  if "large" in model_name:
103
  return gen_events // 10 + 15
@@ -106,9 +128,9 @@ def get_duration(model_name, tab, mid_seq, continuation_state, instruments, drum
106
 
107
 
108
  @spaces.GPU(duration=get_duration)
109
- def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
110
- reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
111
- gen_events, temp, top_p, top_k, allow_cc):
112
  model = models[model_name]
113
  model.to(device=opt.device)
114
  tokenizer = model.tokenizer
@@ -156,8 +178,8 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
156
  patches[9] = drum_kits2number[drum_kit]
157
  for i, (c, p) in enumerate(patches.items()):
158
  mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
159
- mid_seq = mid
160
- mid = np.asarray(mid, dtype=np.int64)
161
  if len(instruments) > 0:
162
  disable_patch_change = True
163
  disable_channels = [i for i in range(16) if i not in patches]
@@ -167,84 +189,91 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
167
  remap_track_channel=remap_track_channel,
168
  add_default_instr=add_default_instr,
169
  remove_empty_channels=remove_empty_channels)
170
- mid = np.asarray(mid, dtype=np.int64)
171
  mid = mid[:int(midi_events)]
172
- mid_seq = []
173
- for token_seq in mid:
174
- mid_seq.append(token_seq.tolist())
175
  elif tab == 2 and mid_seq is not None:
176
- continuation_state.append(len(mid_seq))
177
  mid = np.asarray(mid_seq, dtype=np.int64)
 
 
 
 
 
 
178
  else:
179
  continuation_state = [0]
180
- mid_seq = []
181
- mid = None
 
182
 
183
  if mid is not None:
184
- max_len += len(mid)
185
 
186
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
187
  init_msgs = [create_msg("progress", [0, gen_events])]
188
- if tab != 2:
189
- init_msgs += [create_msg("visualizer_clear", tokenizer.version),
190
- create_msg("visualizer_append", events)]
191
- yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
192
- midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
193
- disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
194
- disable_channels=disable_channels, generator=generator)
195
- events = []
196
- t = time.time() + 1
197
- for i, token_seq in enumerate(midi_generator):
198
- token_seq = token_seq.tolist()
199
- mid_seq.append(token_seq)
200
- events.append(tokenizer.tokens2event(token_seq))
201
- ct = time.time()
202
- if ct - t > 0.5:
203
- yield (mid_seq, continuation_state, None, None, seed,
204
- send_msgs([create_msg("visualizer_append", events),
205
- create_msg("progress", [i + 1, gen_events])]))
206
- t = ct
207
- events = []
 
 
 
 
 
 
208
 
209
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
210
- mid = tokenizer.detokenize(mid_seq)
211
- audio = synthesizer.synthesis(MIDI.score2opus(mid))
212
- with open(f"output.mid", 'wb') as f:
213
- f.write(MIDI.score2midi(mid))
214
- end_msgs = [create_msg("visualizer_clear", tokenizer.version),
215
- create_msg("visualizer_append", events),
216
- create_msg("visualizer_end", None),
217
- create_msg("progress", [0, 0])]
218
- yield mid_seq, continuation_state, "output.mid", (44100, audio), seed, send_msgs(end_msgs)
219
 
220
-
221
- def cancel_run(model_name, mid_seq):
222
  if mid_seq is None:
223
  return None, None, []
224
  tokenizer = models[model_name].tokenizer
225
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
226
- mid = tokenizer.detokenize(mid_seq)
227
- audio = synthesizer.synthesis(MIDI.score2opus(mid))
228
- with open(f"output.mid", 'wb') as f:
229
- f.write(MIDI.score2midi(mid))
230
- end_msgs = [create_msg("visualizer_clear", tokenizer.version),
231
- create_msg("visualizer_append", events),
232
- create_msg("visualizer_end", None),
233
- create_msg("progress", [0, 0])]
234
- return "output.mid", (44100, audio), send_msgs(end_msgs)
 
 
 
 
 
235
 
236
 
237
  def undo_continuation(model_name, mid_seq, continuation_state):
238
  if mid_seq is None or len(continuation_state) < 2:
239
  return mid_seq, continuation_state, send_msgs([])
240
- mid_seq = mid_seq[:continuation_state[-1]]
241
- continuation_state = continuation_state[:-1]
242
  tokenizer = models[model_name].tokenizer
243
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
244
- end_msgs = [create_msg("visualizer_clear", tokenizer.version),
245
- create_msg("visualizer_append", events),
246
- create_msg("visualizer_end", None),
247
- create_msg("progress", [0, 0])]
 
 
 
 
 
 
248
  return mid_seq, continuation_state, send_msgs(end_msgs)
249
 
250
 
@@ -296,13 +325,14 @@ if __name__ == "__main__":
296
  opt = parser.parse_args()
297
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
298
  synthesizer = MidiSynthesizer(soundfont_path)
299
- models_info = {"generic pretrain model (tv2o-medium) by skytnt": ["skytnt/midi-model-tv2o-medium", "", "tv2o-medium"],
300
- "generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
301
- "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
302
- "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
303
- "j-pop finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "jpop-tv2o-medium/", "tv2o-medium"],
304
- "touhou finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "touhou-tv2o-medium/", "tv2o-medium"],
305
- }
 
306
  models = {}
307
  if opt.device == "cuda":
308
  torch.backends.cudnn.deterministic = True
@@ -391,7 +421,12 @@ if __name__ == "__main__":
391
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
392
  [input_midi, input_midi_events])
393
  with gr.TabItem("last output prompt") as tab3:
394
- gr.Markdown("Continue generating on the last output. Just click the generate button")
 
 
 
 
 
395
  undo_btn = gr.Button("undo the last continuation")
396
 
397
  tab1.select(lambda: 0, None, tab_select, queue=False)
@@ -413,21 +448,29 @@ if __name__ == "__main__":
413
  stop_btn = gr.Button("stop and output")
414
  output_midi_seq = gr.State()
415
  output_continuation_state = gr.State([0])
416
- output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
417
- output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
418
- output_midi = gr.File(label="output midi", file_types=[".mid"])
 
 
 
 
 
419
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
420
- input_instruments, input_drum_kit, input_bpm, input_time_sig, input_key_sig,
421
- input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
 
422
  input_add_default_instr, input_remove_empty_channels,
423
  input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
424
  input_top_k, input_allow_cc],
425
- [output_midi_seq, output_continuation_state,
426
- output_midi, output_audio, input_seed, js_msg],
427
- concurrency_limit=10)
428
- stop_btn.click(cancel_run, [input_model, output_midi_seq],
429
- [output_midi, output_audio, js_msg],
430
- cancels=run_event, queue=False)
 
 
431
  undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
432
  [output_midi_seq, output_continuation_state, js_msg], queue=False)
433
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
18
  from midi_synthesizer import MidiSynthesizer
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
+ OUTPUT_BATCH_SIZE = 4
22
  in_space = os.getenv("SYSTEM") == "spaces"
23
 
24
 
25
  @torch.inference_mode()
26
+ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
27
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
28
  tokenizer = model.tokenizer
29
  if disable_channels is not None:
 
34
  if prompt is None:
35
  input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
36
  input_tensor[0, 0] = tokenizer.bos_id # bos
37
+ input_tensor = input_tensor.unsqueeze(0)
38
+ input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
39
  else:
40
+ if len(prompt.shape) == 2:
41
+ prompt = prompt[None, :]
42
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
43
+ elif prompt.shape[0] == 1:
44
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
45
+ elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
46
+ raise ValueError(f"invalid shape for prompt, {prompt.shape}")
47
+ prompt = prompt[..., :max_token_seq]
48
  if prompt.shape[-1] < max_token_seq:
49
+ prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
50
  mode="constant", constant_values=tokenizer.pad_id)
51
  input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
 
52
  cur_len = input_tensor.shape[1]
53
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
54
  with bar:
55
  while cur_len < max_len:
56
+ end = [False] * batch_size
57
+ hidden = model.forward(input_tensor)[:, -1]
58
  next_token_seq = None
59
+ event_names = [""] * batch_size
60
  for i in range(max_token_seq):
61
+ mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
62
+ for b in range(batch_size):
63
+ if end[b]:
64
+ mask[b, tokenizer.pad_id] = 1
65
+ continue
66
+ if i == 0:
67
+ mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
68
+ if disable_patch_change:
69
+ mask_ids.remove(tokenizer.event_ids["patch_change"])
70
+ if disable_control_change:
71
+ mask_ids.remove(tokenizer.event_ids["control_change"])
72
+ mask[b, mask_ids] = 1
73
+ else:
74
+ param_names = tokenizer.events[event_names[b]]
75
+ if i > len(param_names):
76
+ mask[b, tokenizer.pad_id] = 1
77
+ continue
78
+ param_name = param_names[i - 1]
79
+ mask_ids = tokenizer.parameter_ids[param_name]
80
+ if param_name == "channel":
81
+ mask_ids = [i for i in mask_ids if i not in disable_channels]
82
+ mask[b, mask_ids] = 1
83
+ mask = mask.unsqueeze(1)
84
  logits = model.forward_token(hidden, next_token_seq)[:, -1:]
85
  scores = torch.softmax(logits / temp, dim=-1) * mask
86
+ samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
87
  if i == 0:
88
+ next_token_seq = samples
89
+ for b in range(batch_size):
90
+ if end[b]:
91
+ continue
92
+ eid = samples[b].item()
93
+ if eid == tokenizer.eos_id:
94
+ end[b] = True
95
+ else:
96
+ event_names[b] = tokenizer.id_events[eid]
97
  else:
98
+ next_token_seq = torch.cat([next_token_seq, samples], dim=1)
99
+ if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
100
  break
101
  if next_token_seq.shape[1] < max_token_seq:
102
  next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
 
105
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
106
  cur_len += 1
107
  bar.update(1)
108
+ yield next_token_seq[:, 0].cpu().numpy()
109
+ if all(end):
110
  break
111
 
112
 
 
117
  def send_msgs(msgs):
118
  return json.dumps(msgs)
119
 
120
+
121
+ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
122
+ time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
123
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
124
  if "large" in model_name:
125
  return gen_events // 10 + 15
 
128
 
129
 
130
  @spaces.GPU(duration=get_duration)
131
+ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
132
+ key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
133
+ seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
134
  model = models[model_name]
135
  model.to(device=opt.device)
136
  tokenizer = model.tokenizer
 
178
  patches[9] = drum_kits2number[drum_kit]
179
  for i, (c, p) in enumerate(patches.items()):
180
  mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
181
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
182
+ mid_seq = mid.tolist()
183
  if len(instruments) > 0:
184
  disable_patch_change = True
185
  disable_channels = [i for i in range(16) if i not in patches]
 
189
  remap_track_channel=remap_track_channel,
190
  add_default_instr=add_default_instr,
191
  remove_empty_channels=remove_empty_channels)
 
192
  mid = mid[:int(midi_events)]
193
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
194
+ mid_seq = mid.tolist()
 
195
  elif tab == 2 and mid_seq is not None:
 
196
  mid = np.asarray(mid_seq, dtype=np.int64)
197
+ if continuation_select > 0:
198
+ continuation_state.append(mid_seq)
199
+ mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
200
+ mid_seq = mid.tolist()
201
+ else:
202
+ continuation_state.append(mid.shape[1])
203
  else:
204
  continuation_state = [0]
205
+ mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
206
+ mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
207
+ mid_seq = mid.tolist()
208
 
209
  if mid is not None:
210
+ max_len += mid.shape[1]
211
 
 
212
  init_msgs = [create_msg("progress", [0, gen_events])]
213
+ if not (tab == 2 and continuation_select == 0):
214
+ for i in range(OUTPUT_BATCH_SIZE):
215
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
216
+ init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
217
+ create_msg("visualizer_append", [i, events])]
218
+ yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
219
+ midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
220
+ top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
221
+ disable_control_change=not allow_cc, disable_channels=disable_channels,
222
+ generator=generator)
223
+ events = [list() for i in range(OUTPUT_BATCH_SIZE)]
224
+ t = time.time()
225
+ for i, token_seqs in enumerate(midi_generator):
226
+ token_seqs = token_seqs.tolist()
227
+ for j in range(OUTPUT_BATCH_SIZE):
228
+ token_seq = token_seqs[j]
229
+ mid_seq[j].append(token_seq)
230
+ events[j].append(tokenizer.tokens2event(token_seq))
231
+ if time.time() - t > 0.2:
232
+ msgs = [create_msg("progress", [i + 1, gen_events])]
233
+ for j in range(OUTPUT_BATCH_SIZE):
234
+ msgs += [create_msg("visualizer_append", [j, events[j]])]
235
+ events[j] = list()
236
+ yield mid_seq, continuation_state, seed, send_msgs(msgs)
237
+ t = time.time()
238
+ yield mid_seq, continuation_state, seed, send_msgs([])
239
 
 
 
 
 
 
 
 
 
 
 
240
 
241
+ def finish_run(model_name, mid_seq):
 
242
  if mid_seq is None:
243
  return None, None, []
244
  tokenizer = models[model_name].tokenizer
245
+ outputs = []
246
+ end_msgs = [create_msg("progress", [0, 0])]
247
+ if not os.path.exists("outputs"):
248
+ os.mkdir("outputs")
249
+ for i in range(OUTPUT_BATCH_SIZE):
250
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
251
+ mid = tokenizer.detokenize(mid_seq[i])
252
+ audio = synthesizer.synthesis(MIDI.score2opus(mid))
253
+ with open(f"outputs/output{i + 1}.mid", 'wb') as f:
254
+ f.write(MIDI.score2midi(mid))
255
+ outputs += [(44100, audio), f"outputs/output{i + 1}.mid"]
256
+ end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
257
+ create_msg("visualizer_append", [i, events]),
258
+ create_msg("visualizer_end", i)]
259
+ return *outputs, send_msgs(end_msgs)
260
 
261
 
262
  def undo_continuation(model_name, mid_seq, continuation_state):
263
  if mid_seq is None or len(continuation_state) < 2:
264
  return mid_seq, continuation_state, send_msgs([])
 
 
265
  tokenizer = models[model_name].tokenizer
266
+ if isinstance(continuation_state[-1], list):
267
+ mid_seq = continuation_state[-1]
268
+ else:
269
+ mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
270
+ continuation_state = continuation_state[:-1]
271
+ end_msgs = [create_msg("progress", [0, 0])]
272
+ for i in range(OUTPUT_BATCH_SIZE):
273
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
274
+ end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
275
+ create_msg("visualizer_append", [i, events]),
276
+ create_msg("visualizer_end", i)]
277
  return mid_seq, continuation_state, send_msgs(end_msgs)
278
 
279
 
 
325
  opt = parser.parse_args()
326
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
327
  synthesizer = MidiSynthesizer(soundfont_path)
328
+ models_info = {
329
+ "generic pretrain model (tv2o-medium) by skytnt": ["skytnt/midi-model-tv2o-medium", "", "tv2o-medium"],
330
+ "generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
331
+ "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
332
+ "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
333
+ "j-pop finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "jpop-tv2o-medium/", "tv2o-medium"],
334
+ "touhou finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "touhou-tv2o-medium/", "tv2o-medium"],
335
+ }
336
  models = {}
337
  if opt.device == "cuda":
338
  torch.backends.cudnn.deterministic = True
 
421
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
422
  [input_midi, input_midi_events])
423
  with gr.TabItem("last output prompt") as tab3:
424
+ gr.Markdown("Continue generating on the last output.")
425
+ input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
426
+ choices=["all"] + [f"output{i + 1}" for i in
427
+ range(OUTPUT_BATCH_SIZE)],
428
+ type="index"
429
+ )
430
  undo_btn = gr.Button("undo the last continuation")
431
 
432
  tab1.select(lambda: 0, None, tab_select, queue=False)
 
448
  stop_btn = gr.Button("stop and output")
449
  output_midi_seq = gr.State()
450
  output_continuation_state = gr.State([0])
451
+ batch_outputs = []
452
+ with gr.Tabs(elem_id="output_tabs"):
453
+ for i in range(OUTPUT_BATCH_SIZE):
454
+ with gr.TabItem(f"output {i + 1}") as tab1:
455
+ output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
456
+ output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
457
+ output_midi = gr.File(label="output midi", file_types=[".mid"])
458
+ batch_outputs += [output_audio, output_midi]
459
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
460
+ input_continuation_select, input_instruments, input_drum_kit, input_bpm,
461
+ input_time_sig, input_key_sig, input_midi, input_midi_events,
462
+ input_reduce_cc_st, input_remap_track_channel,
463
  input_add_default_instr, input_remove_empty_channels,
464
  input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
465
  input_top_k, input_allow_cc],
466
+ [output_midi_seq, output_continuation_state, input_seed, js_msg],
467
+ concurrency_limit=10, queue=True)
468
+ run_event.then(fn=finish_run,
469
+ inputs=[input_model, output_midi_seq],
470
+ outputs=batch_outputs + [js_msg],
471
+ queue=False)
472
+ stop_btn.click(None, [], [], cancels=run_event,
473
+ queue=False)
474
  undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
475
  [output_midi_seq, output_continuation_state, js_msg], queue=False)
476
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
javascript/app.js CHANGED
@@ -1,3 +1,5 @@
 
 
1
  /**
2
  * 自动绕过 shadowRoot 的 querySelector
3
  * @param {string} selector - 要查询的 CSS 选择器
@@ -594,33 +596,49 @@ class MidiVisualizer extends HTMLElement{
594
  customElements.define('midi-visualizer', MidiVisualizer);
595
 
596
  (()=>{
597
- let midi_visualizer_container_inited = null
598
- let midi_audio_audio_inited = null;
599
- let midi_audio_cursor_inited = null;
600
- let midi_visualizer = document.createElement('midi-visualizer')
601
- onUiUpdate((m)=>{
602
- let app = gradioApp()
603
- let midi_visualizer_container = app.querySelector("#midi_visualizer_container");
604
- if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){
605
- midi_visualizer_container.appendChild(midi_visualizer)
606
- midi_visualizer_container_inited = midi_visualizer_container;
607
- }
608
- let midi_audio = app.querySelector("#midi_audio");
609
- if (!!midi_audio){
610
- let midi_audio_cursor = midi_audio.deepQuerySelector(".cursor");
611
- if(!!midi_audio_cursor && midi_audio_cursor_inited!==midi_audio_cursor){
612
- midi_visualizer.bindWaveformCursor(midi_audio_cursor)
613
- midi_audio_cursor_inited = midi_audio_cursor
614
  }
615
- let midi_audio_audio = midi_audio.deepQuerySelector("audio");
616
- if(!!midi_audio_audio && midi_audio_audio_inited!==midi_audio_audio){
617
- midi_visualizer.bindAudioPlayer(midi_audio_audio)
618
- midi_audio_audio_inited = midi_audio_audio
 
 
 
 
 
 
 
 
619
  }
620
- }
621
- })
 
 
 
 
 
 
 
622
 
623
  let hasProgressBar = false;
 
 
 
 
 
 
 
 
624
 
625
  function createProgressBar(progressbarContainer){
626
  let parentProgressbar = progressbarContainer.parentNode;
@@ -653,15 +671,15 @@ customElements.define('midi-visualizer', MidiVisualizer);
653
  hasProgressBar = false;
654
  }
655
 
656
- function setProgressBar(progressbarContainer, progress, total){
657
  if (!hasProgressBar)
658
- createProgressBar(midi_visualizer_container_inited)
659
  if (hasProgressBar && total === 0){
660
- removeProgressBar(midi_visualizer_container_inited)
661
  return
662
  }
663
- let parentProgressbar = progressbarContainer.parentNode;
664
- let divProgress = parentProgressbar.querySelector(".progressDiv");
665
  let divInner = parentProgressbar.querySelector(".progress");
666
  if(total===0)
667
  total = 1;
@@ -679,24 +697,30 @@ customElements.define('midi-visualizer', MidiVisualizer);
679
  }
680
  })
681
  function handleMsg(msg){
 
682
  switch (msg.name) {
683
  case "visualizer_clear":
684
- midi_visualizer.clearMidiEvents(false);
685
- midi_visualizer.version = msg.data
 
 
686
  break;
687
  case "visualizer_append":
688
- msg.data.forEach( value => {
689
- midi_visualizer.appendMidiEvent(value);
 
 
690
  })
691
  break;
692
  case "visualizer_end":
693
- midi_visualizer.finishAppendMidiEvent()
694
- midi_visualizer.setPlayTime(0);
 
695
  break;
696
  case "progress":
697
  let progress = msg.data[0]
698
  let total = msg.data[1]
699
- setProgressBar(midi_visualizer_container_inited, progress, total)
700
  break;
701
  default:
702
  }
 
1
+ const MIDI_OUTPUT_BATCH_SIZE=4;
2
+
3
  /**
4
  * 自动绕过 shadowRoot 的 querySelector
5
  * @param {string} selector - 要查询的 CSS 选择器
 
596
  customElements.define('midi-visualizer', MidiVisualizer);
597
 
598
  (()=>{
599
+ function midi_visualizer_setup(idx, midi_visualizer){
600
+ let midi_visualizer_container_inited = null
601
+ let midi_audio_audio_inited = null;
602
+ let midi_audio_cursor_inited = null;
603
+ onUiUpdate((m)=>{
604
+ let app = gradioApp()
605
+ let midi_visualizer_container = app.querySelector(`#midi_visualizer_container_${idx}`);
606
+ if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){
607
+ midi_visualizer_container.appendChild(midi_visualizer)
608
+ midi_visualizer_container_inited = midi_visualizer_container;
 
 
 
 
 
 
 
609
  }
610
+ let midi_audio = app.querySelector(`#midi_audio_${idx}`);
611
+ if (!!midi_audio){
612
+ let midi_audio_cursor = midi_audio.deepQuerySelector(".cursor");
613
+ if(!!midi_audio_cursor && midi_audio_cursor_inited!==midi_audio_cursor){
614
+ midi_visualizer.bindWaveformCursor(midi_audio_cursor)
615
+ midi_audio_cursor_inited = midi_audio_cursor
616
+ }
617
+ let midi_audio_audio = midi_audio.deepQuerySelector("audio");
618
+ if(!!midi_audio_audio && midi_audio_audio_inited!==midi_audio_audio){
619
+ midi_visualizer.bindAudioPlayer(midi_audio_audio)
620
+ midi_audio_audio_inited = midi_audio_audio
621
+ }
622
  }
623
+ });
624
+ }
625
+
626
+ let midi_visualizers = []
627
+ for (let i = 0; i < MIDI_OUTPUT_BATCH_SIZE ; i++){
628
+ let midi_visualizer = document.createElement('midi-visualizer');
629
+ midi_visualizers.push(midi_visualizer);
630
+ midi_visualizer_setup(i, midi_visualizer)
631
+ }
632
 
633
  let hasProgressBar = false;
634
+ let output_tabs_inited = null;
635
+ onUiUpdate((m)=>{
636
+ let app = gradioApp()
637
+ let output_tabs = app.querySelector("#output_tabs");
638
+ if(!!output_tabs && output_tabs_inited!== output_tabs){
639
+ output_tabs_inited = output_tabs;
640
+ }
641
+ });
642
 
643
  function createProgressBar(progressbarContainer){
644
  let parentProgressbar = progressbarContainer.parentNode;
 
671
  hasProgressBar = false;
672
  }
673
 
674
+ function setProgressBar(progress, total){
675
  if (!hasProgressBar)
676
+ createProgressBar(output_tabs_inited)
677
  if (hasProgressBar && total === 0){
678
+ removeProgressBar(output_tabs_inited)
679
  return
680
  }
681
+ let parentProgressbar = output_tabs_inited.parentNode;
682
+ // let divProgress = parentProgressbar.querySelector(".progressDiv");
683
  let divInner = parentProgressbar.querySelector(".progress");
684
  if(total===0)
685
  total = 1;
 
697
  }
698
  })
699
  function handleMsg(msg){
700
+ let idx;
701
  switch (msg.name) {
702
  case "visualizer_clear":
703
+ idx = msg.data[0];
704
+ let ver = msg.data[1];
705
+ midi_visualizers[idx].clearMidiEvents(false);
706
+ midi_visualizers[idx].version = ver;
707
  break;
708
  case "visualizer_append":
709
+ idx = msg.data[0];
710
+ let events = msg.data[1];
711
+ events.forEach( value => {
712
+ midi_visualizers[idx].appendMidiEvent(value);
713
  })
714
  break;
715
  case "visualizer_end":
716
+ idx = msg.data;
717
+ midi_visualizers[idx].finishAppendMidiEvent()
718
+ midi_visualizers[idx].setPlayTime(0);
719
  break;
720
  case "progress":
721
  let progress = msg.data[0]
722
  let total = msg.data[1]
723
+ setProgressBar(progress, total)
724
  break;
725
  default:
726
  }
midi_model.py CHANGED
@@ -111,49 +111,69 @@ class MIDIModel(nn.Module):
111
  return next_token
112
 
113
  @torch.inference_mode()
114
- def generate(self, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
115
  tokenizer = self.tokenizer
116
  max_token_seq = tokenizer.max_token_seq
117
  if prompt is None:
118
  input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
119
  input_tensor[0, 0] = tokenizer.bos_id # bos
 
 
120
  else:
121
- prompt = prompt[:, :max_token_seq]
 
 
 
 
 
 
 
122
  if prompt.shape[-1] < max_token_seq:
123
- prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
124
  mode="constant", constant_values=tokenizer.pad_id)
125
  input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
126
- input_tensor = input_tensor.unsqueeze(0)
127
  cur_len = input_tensor.shape[1]
128
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
129
  with bar:
130
  while cur_len < max_len:
131
- end = False
132
- hidden = self.forward(input_tensor)[0, -1].unsqueeze(0)
133
  next_token_seq = None
134
- event_name = ""
135
  for i in range(max_token_seq):
136
- mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=self.device)
137
- if i == 0:
138
- mask[list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
139
- else:
140
- param_name = tokenizer.events[event_name][i - 1]
141
- mask[tokenizer.parameter_ids[param_name]] = 1
142
-
 
 
 
 
 
 
 
143
  logits = self.forward_token(hidden, next_token_seq)[:, -1:]
144
  scores = torch.softmax(logits / temp, dim=-1) * mask
145
- sample = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
146
  if i == 0:
147
- next_token_seq = sample
148
- eid = sample.item()
149
- if eid == tokenizer.eos_id:
150
- end = True
151
- break
152
- event_name = tokenizer.id_events[eid]
 
 
 
153
  else:
154
- next_token_seq = torch.cat([next_token_seq, sample], dim=1)
155
- if len(tokenizer.events[event_name]) == i:
156
  break
 
157
  if next_token_seq.shape[1] < max_token_seq:
158
  next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
159
  "constant", value=tokenizer.pad_id)
@@ -161,6 +181,7 @@ class MIDIModel(nn.Module):
161
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
162
  cur_len += 1
163
  bar.update(1)
164
- if end:
 
165
  break
166
- return input_tensor[0].cpu().numpy()
 
111
  return next_token
112
 
113
  @torch.inference_mode()
114
+ def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
115
  tokenizer = self.tokenizer
116
  max_token_seq = tokenizer.max_token_seq
117
  if prompt is None:
118
  input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
119
  input_tensor[0, 0] = tokenizer.bos_id # bos
120
+ input_tensor = input_tensor.unsqueeze(0)
121
+ input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
122
  else:
123
+ if len(prompt.shape) == 2:
124
+ prompt = prompt[None, :]
125
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
126
+ elif prompt.shape[0] == 1:
127
+ prompt = np.repeat(prompt, repeats=batch_size, axis=0)
128
+ else:
129
+ raise ValueError(f"invalid shape for prompt, {prompt.shape}")
130
+ prompt = prompt[..., :max_token_seq]
131
  if prompt.shape[-1] < max_token_seq:
132
+ prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
133
  mode="constant", constant_values=tokenizer.pad_id)
134
  input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
135
+
136
  cur_len = input_tensor.shape[1]
137
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
138
  with bar:
139
  while cur_len < max_len:
140
+ end = [False] * batch_size
141
+ hidden = self.forward(input_tensor)[:, -1]
142
  next_token_seq = None
143
+ event_names = [""] * batch_size
144
  for i in range(max_token_seq):
145
+ mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
146
+ for b in range(batch_size):
147
+ if end[b]:
148
+ mask[b, tokenizer.pad_id] = 1
149
+ continue
150
+ if i == 0:
151
+ mask[b, list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
152
+ else:
153
+ param_names = tokenizer.events[event_names[b]]
154
+ if i > len(param_names):
155
+ mask[b, tokenizer.pad_id] = 1
156
+ continue
157
+ mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
158
+ mask = mask.unsqueeze(1)
159
  logits = self.forward_token(hidden, next_token_seq)[:, -1:]
160
  scores = torch.softmax(logits / temp, dim=-1) * mask
161
+ samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
162
  if i == 0:
163
+ next_token_seq = samples
164
+ for b in range(batch_size):
165
+ if end[b]:
166
+ continue
167
+ eid = samples[b].item()
168
+ if eid == tokenizer.eos_id:
169
+ end[b] = True
170
+ else:
171
+ event_names[b] = tokenizer.id_events[eid]
172
  else:
173
+ next_token_seq = torch.cat([next_token_seq, samples], dim=1)
174
+ if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
175
  break
176
+
177
  if next_token_seq.shape[1] < max_token_seq:
178
  next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
179
  "constant", value=tokenizer.pad_id)
 
181
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
182
  cur_len += 1
183
  bar.update(1)
184
+
185
+ if all(end):
186
  break
187
+ return input_tensor.cpu().numpy()