JingyeChen22 committed on
Commit
9de996f
1 Parent(s): 4595437

Update app.py

Files changed (1)
  1. app.py +133 -60
app.py CHANGED
@@ -26,10 +26,6 @@ os.system('wget https://huggingface.co/datasets/JingyeChen22/TextDiffuser/resolve/main/Arial.ttf')
 if not os.path.exists('Arial.ttf'):
     os.system('wget https://huggingface.co/datasets/JingyeChen22/TextDiffuser/resolve/main/Arial.ttf')
 
-
-os.system('echo finish')
-os.system('ls -a')
-
 import cv2
 import random
 import logging
@@ -67,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
 import transformers
 from transformers import CLIPTextModel, CLIPTokenizer
 
-from util import segmentation_mask_visualization, make_caption_pil, combine_image, transform_mask, transform_mask_pil, filter_segmentation_mask, inpainting_merge_image
+from util import segmentation_mask_visualization, make_caption_pil, combine_image, transform_mask_pil, filter_segmentation_mask, inpainting_merge_image
 from model.layout_generator import get_layout_from_prompt
 from model.text_segmenter.unet import UNet
 
@@ -364,20 +360,40 @@ if accelerator.is_main_process:
     print(args.output_dir)
 
 # Load scheduler, tokenizer and models.
-tokenizer = CLIPTokenizer.from_pretrained(
-    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+tokenizer15 = CLIPTokenizer.from_pretrained(
+    'runwayml/stable-diffusion-v1-5', subfolder="tokenizer", revision=args.revision
 )
-text_encoder = CLIPTextModel.from_pretrained(
-    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+tokenizer21 = CLIPTokenizer.from_pretrained(
+    'stabilityai/stable-diffusion-2-1', subfolder="tokenizer", revision=args.revision
+)
+
+text_encoder15 = CLIPTextModel.from_pretrained(
+    'runwayml/stable-diffusion-v1-5', subfolder="text_encoder", revision=args.revision
 )
-vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).cuda()
-unet = UNet2DConditionModel.from_pretrained(
-    args.resume_from_checkpoint, subfolder="unet", revision=None
+text_encoder21 = CLIPTextModel.from_pretrained(
+    'stabilityai/stable-diffusion-2-1', subfolder="text_encoder", revision=args.revision
+)
+
+vae15 = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae", revision=args.revision).cuda()
+unet15 = UNet2DConditionModel.from_pretrained(
+    'textdiffuser-ckpt/diffusion_backbone_1.5', subfolder="unet", revision=None
+).cuda()
+
+vae21 = AutoencoderKL.from_pretrained('stabilityai/stable-diffusion-2-1', subfolder="vae", revision=args.revision).cuda()
+unet21 = UNet2DConditionModel.from_pretrained(
+    'textdiffuser-ckpt/diffusion_backbone_2.1', subfolder="unet", revision=None
 ).cuda()
 
+scheduler15 = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
+scheduler21 = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-2-1', subfolder="scheduler")
+
+
+
 # Freeze vae and text_encoder
-vae.requires_grad_(False)
-text_encoder.requires_grad_(False)
+vae15.requires_grad_(False)
+vae21.requires_grad_(False)
+text_encoder15.requires_grad_(False)
+text_encoder21.requires_grad_(False)
 
 if args.enable_xformers_memory_efficient_attention:
     if is_xformers_available():
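To orient readers of this hunk: the commit now keeps two complete sets of components in memory (the `*15` and `*21` globals above) and later picks one set per request based on a version string. A minimal sketch of that dispatch, using a dictionary registry instead of the if/elif chain that appears further down in app.py; the registry name and helper are illustrative, and only the version strings, the 512/768 sizes, and the batch clamp mirror the diff:

```python
# Hypothetical registry mirroring the commit's *15/*21 globals; the dict form is
# just an alternative to the if/elif dispatch used in app.py.
MODEL_REGISTRY = {
    "Stable Diffusion v1.5": {"suffix": "15", "size": 512, "max_batch": 4},
    "Stable Diffusion v2.1": {"suffix": "21", "size": 768, "max_batch": 2},
}

def pick_components(version: str, slider_batch: int):
    """Return (config, clamped batch size) for the requested backbone version."""
    if version not in MODEL_REGISTRY:
        raise ValueError(f"Version Not Found: {version}")
    cfg = MODEL_REGISTRY[version]
    # SD v2.1 runs at 768x768, so the batch is clamped to avoid OOM (as in the diff).
    return cfg, min(slider_batch, cfg["max_batch"])
```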
@@ -421,7 +437,6 @@ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
 
 
 # setup schedulers
-scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 # sample_num = args.vis_num
 
 def to_tensor(image):
@@ -461,7 +476,25 @@ def has_chinese_char(string):
 
 image_404 = Image.open('404.jpg')
 
-def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
+def text_to_image(prompt,slider_step,slider_guidance,slider_batch, version):
+    print(f'【version】{version}')
+    if version == 'Stable Diffusion v2.1':
+        vae = vae21
+        unet = unet21
+        text_encoder = text_encoder21
+        tokenizer = tokenizer21
+        scheduler = scheduler21
+        slider_batch = min(slider_batch, 2)
+        size = 768
+    elif version == 'Stable Diffusion v1.5':
+        vae = vae15
+        unet = unet15
+        text_encoder = text_encoder15
+        tokenizer = tokenizer15
+        scheduler = scheduler15
+        size = 512
+    else:
+        assert False, 'Version Not Found'
 
     if has_chinese_char(prompt):
         print('trigger')
@@ -484,7 +517,7 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
     set_seed(seed)
     scheduler.set_timesteps(slider_step)
 
-    noise = torch.randn((sample_num, 4, 64, 64)).to("cuda") # (b, 4, 64, 64)
+    noise = torch.randn((sample_num, 4, size//8, size//8)).to("cuda") # (b, 4, 64, 64)
     input = noise # (b, 4, 64, 64)
 
     captions = [args.prompt] * sample_num
@@ -504,25 +537,18 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
     encoder_hidden_states_nocond = text_encoder(inputs_nocond)[0].cuda() # (b, 77, 768)
     print(f'{colored("[√]", "green")} encoder_hidden_states_nocond: {encoder_hidden_states_nocond.shape}.')
 
-    # load character-level segmenter
-    segmenter = UNet(3, 96, True).cuda()
-    segmenter = torch.nn.DataParallel(segmenter)
-    segmenter.load_state_dict(torch.load(args.character_segmenter_path))
-    segmenter.eval()
-    print(f'{colored("[√]", "green")} Text segmenter is successfully loaded.')
-
     #### text-to-image ####
     render_image, segmentation_mask_from_pillow = get_layout_from_prompt(args)
 
     segmentation_mask = torch.Tensor(np.array(segmentation_mask_from_pillow)).cuda() # (512, 512)
 
     segmentation_mask = filter_segmentation_mask(segmentation_mask)
-    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(256, 256), mode='nearest')
+    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(size//2, size//2), mode='nearest')
     segmentation_mask = segmentation_mask.squeeze(1).repeat(sample_num, 1, 1).long().to('cuda') # (1, 1, 256, 256)
     print(f'{colored("[√]", "green")} character-level segmentation_mask: {segmentation_mask.shape}.')
 
-    feature_mask = torch.ones(sample_num, 1, 64, 64).to('cuda') # (b, 1, 64, 64)
-    masked_image = torch.zeros(sample_num, 3, 512, 512).to('cuda') # (b, 3, 512, 512)
+    feature_mask = torch.ones(sample_num, 1, size//8, size//8).to('cuda') # (b, 1, 64, 64)
+    masked_image = torch.zeros(sample_num, 3, size, size).to('cuda') # (b, 3, 512, 512)
     masked_feature = vae.encode(masked_image).latent_dist.sample() # (b, 4, 64, 64)
     masked_feature = masked_feature * vae.config.scaling_factor
     print(f'{colored("[√]", "green")} feature_mask: {feature_mask.shape}.')
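The repeated size//8 and size//2 arithmetic in this and the following hunks is the whole point of the new `size` variable: the Stable Diffusion VAE downsamples images by a factor of 8, so latents and `feature_mask` live at size//8, while the character-level segmentation mask is kept at half the image resolution. A small, self-contained sketch of the resulting shapes (assumes only torch; the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

def make_buffers(sample_num: int, size: int):
    # Latent-space tensors: the SD VAE downsamples 8x, so 512 -> 64 and 768 -> 96.
    noise = torch.randn(sample_num, 4, size // 8, size // 8)
    feature_mask = torch.ones(sample_num, 1, size // 8, size // 8)
    # The character-level segmentation mask is kept at half the image resolution.
    seg = torch.zeros(size, size)
    seg = F.interpolate(seg[None, None].float(), size=(size // 2, size // 2), mode="nearest")
    return noise.shape, feature_mask.shape, seg.shape

print(make_buffers(4, 512))  # latents at 64x64, mask at 256x256
print(make_buffers(2, 768))  # latents at 96x96, mask at 384x384
```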
@@ -543,10 +569,11 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
     input = 1 / vae.config.scaling_factor * input
     sample_images = vae.decode(input.float(), return_dict=False)[0] # (b, 3, 512, 512)
 
-    image_pil = render_image.resize((512,512))
+    image_pil = render_image.resize((size,size))
     segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
-    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((512,512))
+    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((size,size))
     character_mask_highlight_pil = segmentation_mask_visualization(args.font_path,segmentation_mask)
+    character_mask_highlight_pil = character_mask_highlight_pil.resize((size, size))
     caption_pil = make_caption_pil(args.font_path, captions)
 
     # save pred_img
@@ -557,12 +584,12 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
         image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
         pred_image_list.append(image)
 
-    blank_pil = combine_image(args, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
+    blank_pil = combine_image(args, size, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
 
-    intermediate_result = Image.new('RGB', (512*3, 512))
+    intermediate_result = Image.new('RGB', (size*3, size))
     intermediate_result.paste(image_pil, (0, 0))
-    intermediate_result.paste(character_mask_pil, (512, 0))
-    intermediate_result.paste(character_mask_highlight_pil, (512*2, 0))
+    intermediate_result.paste(character_mask_pil, (size, 0))
+    intermediate_result.paste(character_mask_highlight_pil, (size*2, 0))
 
     return blank_pil, intermediate_result
 
@@ -577,7 +604,25 @@ print(f'{colored("[√]", "green")} Text segmenter is successfully loaded.')
 
 
 
-def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
+def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary, version):
+
+    if version == 'Stable Diffusion v2.1':
+        vae = vae21
+        unet = unet21
+        text_encoder = text_encoder21
+        tokenizer = tokenizer21
+        scheduler = scheduler21
+        slider_batch = min(slider_batch, 2)
+        size = 768
+    elif version == 'Stable Diffusion v1.5':
+        vae = vae15
+        unet = unet15
+        text_encoder = text_encoder15
+        tokenizer = tokenizer15
+        scheduler = scheduler15
+        size = 512
+    else:
+        assert False, 'Version Not Found'
 
     if has_chinese_char(prompt):
         print('trigger')
@@ -586,7 +631,7 @@ def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
     if slider_step>=50:
         slider_step = 50
 
-    orig_template_image = template_image.resize((512,512)).convert('RGB')
+    orig_template_image = template_image.resize((size,size)).convert('RGB')
     args.prompt = prompt
     sample_num = slider_batch
     # If passed along, set the training seed now.
@@ -595,7 +640,7 @@ def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
     set_seed(seed)
     scheduler.set_timesteps(slider_step)
 
-    noise = torch.randn((sample_num, 4, 64, 64)).to("cuda") # (b, 4, 64, 64)
+    noise = torch.randn((sample_num, 4, size//8, size//8)).to("cuda") # (b, 4, 64, 64)
     input = noise # (b, 4, 64, 64)
 
     captions = [args.prompt] * sample_num
@@ -634,12 +679,12 @@ def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
     segmentation_mask = segmentation_mask.max(1)[1].squeeze(0) # (256, 256)
     segmentation_mask = filter_segmentation_mask(segmentation_mask) # (256, 256)
 
-    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(256, 256), mode='nearest') # (b, 1, 256, 256)
+    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(size//2, size//2), mode='nearest') # (b, 1, 256, 256)
     segmentation_mask = segmentation_mask.squeeze(1).repeat(sample_num, 1, 1).long().to('cuda') # (b, 1, 256, 256)
     print(f'{colored("[√]", "green")} Character-level segmentation_mask: {segmentation_mask.shape}.')
 
-    feature_mask = torch.ones(sample_num, 1, 64, 64).to('cuda') # (b, 1, 64, 64)
-    masked_image = torch.zeros(sample_num, 3, 512, 512).to('cuda') # (b, 3, 512, 512)
+    feature_mask = torch.ones(sample_num, 1, size//8, size//8).to('cuda') # (b, 1, 64, 64)
+    masked_image = torch.zeros(sample_num, 3, size, size).to('cuda') # (b, 3, 512, 512)
     masked_feature = vae.encode(masked_image).latent_dist.sample() # (b, 4, 64, 64)
     masked_feature = masked_feature * vae.config.scaling_factor # (b, 4, 64, 64)
 
@@ -660,8 +705,9 @@ def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
 
     image_pil = None
     segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
-    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((512,512))
+    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((size,size))
     character_mask_highlight_pil = segmentation_mask_visualization(args.font_path,segmentation_mask)
+    character_mask_highlight_pil = character_mask_highlight_pil.resize((size, size))
     caption_pil = make_caption_pil(args.font_path, captions)
 
     # save pred_img
@@ -672,17 +718,35 @@ def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
         image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
         pred_image_list.append(image)
 
-    blank_pil = combine_image(args, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
+    blank_pil = combine_image(args, size, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
 
-    intermediate_result = Image.new('RGB', (512*3, 512))
+    intermediate_result = Image.new('RGB', (size*3, size))
     intermediate_result.paste(orig_template_image, (0, 0))
-    intermediate_result.paste(character_mask_pil, (512, 0))
-    intermediate_result.paste(character_mask_highlight_pil, (512*2, 0))
+    intermediate_result.paste(character_mask_pil, (size, 0))
+    intermediate_result.paste(character_mask_highlight_pil, (size*2, 0))
 
     return blank_pil, intermediate_result
 
 
-def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
+def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch, version):
+
+    if version == 'Stable Diffusion v2.1':
+        vae = vae21
+        unet = unet21
+        text_encoder = text_encoder21
+        tokenizer = tokenizer21
+        scheduler = scheduler21
+        slider_batch = min(slider_batch, 2)
+        size = 768
+    elif version == 'Stable Diffusion v1.5':
+        vae = vae15
+        unet = unet15
+        text_encoder = text_encoder15
+        tokenizer = tokenizer15
+        scheduler = scheduler15
+        size = 512
+    else:
+        assert False, 'Version Not Found'
 
     if has_chinese_char(prompt):
         print('trigger')
@@ -699,7 +763,7 @@ def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
     set_seed(seed)
     scheduler.set_timesteps(slider_step)
 
-    noise = torch.randn((sample_num, 4, 64, 64)).to("cuda") # (b, 4, 64, 64)
+    noise = torch.randn((sample_num, 4, size//8, size//8)).to("cuda") # (b, 4, 64, 64)
     input = noise # (b, 4, 64, 64)
 
     captions = [args.prompt] * sample_num
@@ -719,7 +783,7 @@ def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
     encoder_hidden_states_nocond = text_encoder(inputs_nocond)[0].cuda() # (b, 77, 768)
     print(f'{colored("[√]", "green")} encoder_hidden_states_nocond: {encoder_hidden_states_nocond.shape}.')
 
-    mask_image = cv2.resize(mask_image, (512,512))
+    mask_image = cv2.resize(mask_image, (size,size))
     # mask_image = mask_image.resize((512,512)).convert('RGB')
     text_mask = np.array(mask_image)
     threshold = 128
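The threshold = 128 context line above belongs to a simple binarization of the user-provided mask before it is fed to the segmenter. A standalone NumPy sketch of that step (the function name is illustrative, not from the repo):

```python
import numpy as np

def binarize_mask(mask_image: np.ndarray, threshold: int = 128) -> np.ndarray:
    """Collapse an RGB or grayscale mask to {0, 1} using a fixed threshold."""
    if mask_image.ndim == 3:          # RGB -> grayscale by channel mean
        mask_image = mask_image.mean(axis=-1)
    return (mask_image > threshold).astype(np.uint8)

demo = np.array([[0, 100, 200], [255, 30, 180]], dtype=np.uint8)
print(binarize_mask(demo))  # [[0 0 1] [1 0 1]]
```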
@@ -732,21 +796,21 @@ def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
 
     segmentation_mask = segmentation_mask.max(1)[1].squeeze(0)
     segmentation_mask = filter_segmentation_mask(segmentation_mask)
-    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(256, 256), mode='nearest')
+    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(size//2, size//2), mode='nearest')
 
-    image_mask = transform_mask_pil(mask_image)
+    image_mask = transform_mask_pil(mask_image, size)
     image_mask = torch.from_numpy(image_mask).cuda().unsqueeze(0).unsqueeze(0)
 
-    orig_image = orig_image.convert('RGB').resize((512,512))
+    orig_image = orig_image.convert('RGB').resize((size,size))
     image = orig_image
     image_tensor = to_tensor(image).unsqueeze(0).cuda().sub_(0.5).div_(0.5)
     masked_image = image_tensor * (1-image_mask)
     masked_feature = vae.encode(masked_image).latent_dist.sample().repeat(sample_num, 1, 1, 1)
     masked_feature = masked_feature * vae.config.scaling_factor
 
-    image_mask = torch.nn.functional.interpolate(image_mask, size=(256, 256), mode='nearest').repeat(sample_num, 1, 1, 1)
+    image_mask = torch.nn.functional.interpolate(image_mask, size=(size//2, size//2), mode='nearest').repeat(sample_num, 1, 1, 1)
     segmentation_mask = segmentation_mask * image_mask
-    feature_mask = torch.nn.functional.interpolate(image_mask, size=(64, 64), mode='nearest')
+    feature_mask = torch.nn.functional.interpolate(image_mask, size=(size//8, size//8), mode='nearest')
 
     # diffusion process
     intermediate_images = []
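The # diffusion process marker is where the sampling loop starts; the loop itself is unchanged by this commit, so it does not appear in the diff. For orientation, a generic classifier-free-guidance loop with a diffusers-style scheduler looks roughly like the sketch below; the TextDiffuser UNet additionally receives segmentation_mask, feature_mask, and masked_feature, which are folded into a placeholder callable here because the exact call is not part of this hunk:

```python
import torch

def denoise(latents, scheduler, unet_forward, cond, uncond, guidance_scale=7.5):
    """Generic CFG sampling loop; unet_forward(latents, t, states) is a stand-in
    for the TextDiffuser UNet call, whose extra mask arguments are omitted here."""
    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_cond = unet_forward(latents, t, cond)
            noise_uncond = unet_forward(latents, t, uncond)
        # Classifier-free guidance: push the prediction toward the text condition.
        noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
        latents = scheduler.step(noise, t, latents).prev_sample
    return latents
```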
@@ -767,6 +831,7 @@ def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
     segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
     character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((512,512))
     character_mask_highlight_pil = segmentation_mask_visualization(args.font_path,segmentation_mask)
+    character_mask_highlight_pil = character_mask_highlight_pil.resize((size, size))
     caption_pil = make_caption_pil(args.font_path, captions)
 
     # save pred_img
@@ -786,7 +851,7 @@ def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
     character_mask_highlight_pil.save('character_mask_highlight_pil.png')
 
 
-    blank_pil = combine_image(args, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
+    blank_pil = combine_image(args, size, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
 
 
     background = orig_image.resize((512, 512))
@@ -825,6 +890,11 @@ with gr.Blocks() as demo:
         We propose <b>TextDiffuser</b>, a flexible and controllable framework to generate images with visually appealing text that is coherent with backgrounds.
         Main features include: (a) <b><font color="#A52A2A">Text-to-Image</font></b>: The user provides a prompt and encloses the keywords with single quotes (e.g., a text image of ‘hello’). The model first determines the layout of the keywords and then draws the image based on the layout and prompt. (b) <b><font color="#A52A2A">Text-to-Image with Templates</font></b>: The user provides a prompt and a template image containing text, which can be a printed, handwritten, or scene text image. These template images can be used to determine the layout of the characters. (c) <b><font color="#A52A2A">Text Inpainting</font></b>: The user provides an image and specifies the region to be modified along with the desired text content. The model is able to modify the original text or add text to areas without text.
         </h2>
+        <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
+        🔥 <b>News</b>: We further trained TextDiffuser based on <b>Stable Diffusion v2.1</b> pre-trained model, enlarging the resolution from 512x512 to <b>768x768</b> to enhance the legibility of small text. Additionally, we fine-tuned the model with images with <b>high aesthetical score</b>, enabling generating images with richer details.
+        </h2>
+
+
         <img src="file/images/huggingface_blank.jpg" alt="textdiffuser">
         </div>
         """)
@@ -833,9 +903,10 @@ with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column(scale=1):
                 prompt = gr.Textbox(label="Input your prompt here. Please enclose keywords with 'single quotes', you may refer to the examples below. The current version only supports input in English characters.", placeholder="Placeholder 'Team' hat")
+                radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v2.1")
                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
                 slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
-                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled. Maximum number is set to 【2】 for SD v2.1 to avoid OOM.")
                 # slider_seed = gr.Slider(minimum=1, maximum=10000, label="Seed", randomize=True)
                 button = gr.Button("Generate")
 
@@ -851,7 +922,7 @@ with gr.Blocks() as demo:
                 [
                     ["Distinguished poster of 'SPIDERMAN'. Trending on ArtStation and Pixiv. A vibrant digital oil painting. A highly detailed fantasy character illustration by Wayne Reynolds and Charles Monet and Gustave Dore and Carl Critchlow and Bram Sels"],
                     ["A detailed portrait of a fox guardian with a shield with 'Kung Fu' written on it, by victo ngai and justin gerard, digital art, realistic painting, very detailed, fantasy, high definition, cinematic light, dnd, trending on artstation"],
-                    ["portrait of a 'dragon', concept art, sumi - e style, intricate linework, green smoke, artstation, trending, highly detailed, smooth, focus, art by yoji shinkawa,"],
+                    ["portrait of a 'dragon', concept art, sumi - e style, intricate linework, green smoke, artstation, trending, highly detailed, smooth, focus, art by yoji shinkawa,"],
                     ["elderly woman dressed in extremely colorful clothes with many strange patterns posing for a high fashion photoshoot of 'FASHION', haute couture, golden hour, artstation, by J. C. Leyendecker and Peter Paul Rubens"],
                     ["epic digital art of a luxury yacht named 'Time Machine' driving through very dark hard edged city towers from tron movie, faint tall mountains in background, wlop, pixiv"],
                     ["A poster of 'Adventurer'. A beautiful so tall boy with big eyes and small nose is in the jungle, he wears normal clothes and shows his full length, which we see from the front, unreal engine, cozy indoor lighting, artstation, detailed"],
@@ -876,16 +947,17 @@ with gr.Blocks() as demo:
                 examples_per_page=100
             )
 
-        button.click(text_to_image, inputs=[prompt,slider_step,slider_guidance,slider_batch], outputs=[output,intermediate_results])
+        button.click(text_to_image, inputs=[prompt,slider_step,slider_guidance,slider_batch,radio], outputs=[output,intermediate_results])
 
     with gr.Tab("Text-to-Image-with-Template"):
         with gr.Row():
             with gr.Column(scale=1):
                 prompt = gr.Textbox(label='Input your prompt here.')
                 template_image = gr.Image(label='Template image', type="pil")
+                radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v2.1")
                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
                 slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
-                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled. Maximum number is set to 【2】 for SD v2.1 to avoid OOM.")
                 # binary = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
                 binary = gr.Checkbox(label="Binarization", bool=True, info="Whether to binarize the template image? You may need it when using handwritten images as templates.")
                 button = gr.Button("Generate")
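Every UI change in this and the remaining hunks follows the same pattern: a gr.Radio is added to the tab and appended to the inputs list of button.click, so the selected backbone string reaches the handler as its last positional argument. A minimal, self-contained Gradio sketch of that wiring (labels and the handler are illustrative):

```python
import gradio as gr

def handler(prompt, batch, version):
    # `version` arrives as the Radio's selected string, exactly like in app.py.
    return f"{version}: would sample {batch} image(s) for {prompt!r}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"],
                     label="Pre-trained Model", value="Stable Diffusion v2.1")
    batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size")
    button = gr.Button("Generate")
    output = gr.Textbox(label="Output")
    button.click(handler, inputs=[prompt, batch, radio], outputs=[output])

# demo.launch()  # uncomment to run locally
```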
@@ -923,7 +995,7 @@ with gr.Blocks() as demo:
                 examples_per_page=100
             )
 
-        button.click(text_to_image_with_template, inputs=[prompt,template_image,slider_step,slider_guidance,slider_batch,binary], outputs=[output,intermediate_results])
+        button.click(text_to_image_with_template, inputs=[prompt,template_image,slider_step,slider_guidance,slider_batch,binary,radio], outputs=[output,intermediate_results])
 
     with gr.Tab("Text-Inpainting"):
         with gr.Row():
@@ -932,9 +1004,10 @@ with gr.Blocks() as demo:
                 with gr.Row():
                     orig_image = gr.Image(label='Original image', type="pil")
                     mask_image = gr.Image(label='Mask image', type="numpy")
+                radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v2.1")
                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
                 slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
-                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled. Maximum number is set to 【2】 for SD v2.1 to avoid OOM.")
                 button = gr.Button("Generate")
             with gr.Column(scale=1):
                 output = gr.Image(label='Generated image')
@@ -969,7 +1042,7 @@ with gr.Blocks() as demo:
             )
 
 
-        button.click(text_inpainting, inputs=[prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch], outputs=[output, intermediate_results])
+        button.click(text_inpainting, inputs=[prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch,radio], outputs=[output, intermediate_results])
 
 
 
 