import os
import subprocess
from typing import Optional

import gradio as gr

from .common_gui import (
    get_folder_path,
    add_pre_postfix,
    scriptdir,
    list_dirs,
    get_executable_path,
    setup_environment,
)
from .class_gui_config import KohyaSSGUIConfig
from .custom_logging import setup_logging

# Set up logging
log = setup_logging()

# Remembers the state of the "Use onnx" checkbox at the moment a "-v3" repo is
# selected, so it can be restored when a non-v3 repo is selected again
# (see `repo_id_changes` inside `gradio_wd14_caption_gui_tab`).
old_onnx_value = True


def caption_images(
    train_data_dir: str,
    caption_extension: str,
    batch_size: int,
    general_threshold: float,
    character_threshold: float,
    repo_id: str,
    recursive: bool,
    max_data_loader_n_workers: int,
    debug: bool,
    undesired_tags: str,
    frequency_tags: bool,
    always_first_tags: str,
    onnx: bool,
    append_tags: bool,
    force_download: bool,
    caption_separator: str,
    tag_replacement: str,
    character_tag_expand: bool,
    use_rating_tags: bool,
    use_rating_tags_as_last_tag: bool,
    remove_underscore: bool,
    thresh: float,
) -> None:
    """Caption the images of a folder with the WD14 tagger script.

    Builds an ``accelerate launch`` command line for
    ``sd-scripts/finetune/tag_images_by_wd14_tagger.py``, runs it, and then
    calls ``add_pre_postfix`` with ``always_first_tags`` as the prefix.

    Nothing is executed (only a log message is emitted) when
    ``train_data_dir`` or ``caption_extension`` is empty.

    Args:
        train_data_dir: Folder containing the images; passed as the final
            positional argument of the command.
        caption_extension: Extension of the caption files.
        batch_size: Forwarded as ``--batch_size`` (converted to ``int``).
        general_threshold: Forwarded as ``--general_threshold`` unless it
            equals 0.35.
        character_threshold: Forwarded as ``--character_threshold`` unless it
            equals 0.35.
        repo_id: Forwarded as ``--repo_id``; also used (with ``/`` replaced
            by ``_``) to look for ``./wd14_tagger_model/<repo>``.
        recursive: Adds ``--recursive``; also forwarded to
            ``add_pre_postfix``.
        max_data_loader_n_workers: Forwarded as
            ``--max_data_loader_n_workers`` (converted to ``int``).
        debug: Adds ``--debug``.
        undesired_tags: Forwarded as ``--undesired_tags`` when non-empty.
        frequency_tags: Adds ``--frequency_tags``.
        always_first_tags: Not passed to the script; used as the ``prefix``
            argument of ``add_pre_postfix`` after the script has run.
        onnx: Adds ``--onnx``.
        append_tags: Adds ``--append_tags``.
        force_download: Adds ``--force_download``. Forced to ``True`` when
            the ``./wd14_tagger_model/<repo>`` path does not exist.
        caption_separator: Forwarded as ``--caption_separator``.
        tag_replacement: Forwarded as ``--tag_replacement`` when non-empty.
        character_tag_expand: Adds ``--character_tag_expand``.
        use_rating_tags: Adds ``--use_rating_tags``.
        use_rating_tags_as_last_tag: Adds ``--use_rating_tags_as_last_tag``.
        remove_underscore: Adds ``--remove_underscore``.
        thresh: Forwarded as ``--thresh`` unless it equals 0.35.
    """
    # Guard clauses: truthiness also rejects None, which would otherwise be
    # rendered as the literal string "None" on the command line.
    if not train_data_dir:
        log.info("Image folder is missing...")
        return

    if not caption_extension:
        log.info("Please provide an extension for the caption files.")
        return

    # If the expected model folder is not present (path is relative to the
    # current working directory), ask the script to download the model.
    repo_id_converted = repo_id.replace("/", "_")
    if not os.path.exists(f"./wd14_tagger_model/{repo_id_converted}"):
        force_download = True

    log.info(f"Captioning files in {train_data_dir}...")
    run_cmd = [
        rf'{get_executable_path("accelerate")}',
        "launch",
        rf"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
    ]

    # NOTE: `always_first_tags` is intentionally not forwarded as
    # `--always_first_tags`; it is applied after the run via add_pre_postfix.

    if append_tags:
        run_cmd.append("--append_tags")
    run_cmd += ["--batch_size", str(int(batch_size))]
    run_cmd += ["--caption_extension", caption_extension]
    run_cmd += ["--caption_separator", caption_separator]

    if character_tag_expand:
        run_cmd.append("--character_tag_expand")
    # Thresholds are only forwarded when they differ from 0.35.
    if character_threshold != 0.35:
        run_cmd += ["--character_threshold", str(character_threshold)]
    if debug:
        run_cmd.append("--debug")
    if force_download:
        run_cmd.append("--force_download")
    if frequency_tags:
        run_cmd.append("--frequency_tags")
    if general_threshold != 0.35:
        run_cmd += ["--general_threshold", str(general_threshold)]
    run_cmd += ["--max_data_loader_n_workers", str(int(max_data_loader_n_workers))]

    if onnx:
        run_cmd.append("--onnx")
    if recursive:
        run_cmd.append("--recursive")
    if remove_underscore:
        run_cmd.append("--remove_underscore")
    run_cmd += ["--repo_id", repo_id]
    if tag_replacement:
        run_cmd += ["--tag_replacement", tag_replacement]
    if thresh != 0.35:
        run_cmd += ["--thresh", str(thresh)]
    if undesired_tags:
        run_cmd += ["--undesired_tags", undesired_tags]
    if use_rating_tags:
        run_cmd.append("--use_rating_tags")
    if use_rating_tags_as_last_tag:
        run_cmd.append("--use_rating_tags_as_last_tag")

    # Add the directory containing the training data
    run_cmd.append(rf"{train_data_dir}")

    env = setup_environment()

    # Reconstruct the command string for display only; the process itself is
    # started from the argument list (no shell involved).
    command_to_run = " ".join(run_cmd)
    log.info(f"Executing command: {command_to_run}")

    # Run the command with the prepared environment (the working directory is
    # left unchanged).
    result = subprocess.run(run_cmd, env=env)
    if result.returncode != 0:
        # Surface failures instead of silently continuing; the post-processing
        # below is still attempted, as before.
        log.error(f"WD14 tagger exited with return code {result.returncode}")

    # Add prefix and postfix
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=always_first_tags,
        recursive=recursive,
    )

    log.info("...captioning done")


###
# Gradio UI
###


def gradio_wd14_caption_gui_tab(
    headless=False,
    default_train_dir=None,
    config: Optional[KohyaSSGUIConfig] = None,
):
    """Build the "WD14 Captioning" Gradio tab and wire its events.

    Args:
        headless: When true, the folder-picker button is hidden.
        default_train_dir: Directory whose sub-directories are offered in the
            image-folder dropdown. Defaults to ``<scriptdir>/data``.
        config: Object queried through ``config.get("wd14_caption.<name>",
            default)`` for the initial widget values. When omitted, an empty
            dict is used, so every widget falls back to its default.
    """
    from .common_gui import create_refresh_button

    # Avoid a shared mutable default argument; an empty dict still satisfies
    # the `.get(key, default)` calls below.
    if config is None:
        config = {}

    default_train_dir = (
        default_train_dir
        if default_train_dir is not None
        else os.path.join(scriptdir, "data")
    )
    current_train_dir = default_train_dir

    def list_train_dirs(path):
        """Remember ``path`` as the current directory and list its folders."""
        nonlocal current_train_dir
        current_train_dir = path
        return list(list_dirs(path))

    with gr.Tab("WD14 Captioning"):
        gr.Markdown(
            "This utility will use WD14 to caption files for each images in a folder."
        )

        # Input Settings
        # with gr.Section('Input Settings'):
        with gr.Group(), gr.Row():
            train_data_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[config.get("wd14_caption.train_data_dir", "")]
                + list_train_dirs(default_train_dir),
                value=config.get("wd14_caption.train_data_dir", ""),
                interactive=True,
                allow_custom_value=True,
            )
            create_refresh_button(
                train_data_dir,
                lambda: None,
                lambda: {"choices": list_train_dirs(current_train_dir)},
                "open_folder_small",
            )
            button_train_data_dir_input = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            repo_id = gr.Dropdown(
                label="Repo ID",
                choices=[
                    "SmilingWolf/wd-v1-4-convnext-tagger-v2",
                    "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-vit-tagger-v2",
                    "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-moat-tagger-v2",
                    "SmilingWolf/wd-swinv2-tagger-v3",
                    "SmilingWolf/wd-vit-tagger-v3",
                    "SmilingWolf/wd-convnext-tagger-v3",
                ],
                value=config.get(
                    "wd14_caption.repo_id", "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
                ),
                # Descriptive text belongs in `info`; `show_label` is a flag.
                info="Repo id for wd14 tagger on Hugging Face",
            )

            force_download = gr.Checkbox(
                label="Force model re-download",
                value=config.get("wd14_caption.force_download", False),
                info="Useful to force model re download when switching to onnx",
            )

        with gr.Row():
            caption_extension = gr.Dropdown(
                label="Caption file extension",
                choices=[".cap", ".caption", ".txt"],
                value=".txt",
                interactive=True,
                allow_custom_value=True,
            )

            caption_separator = gr.Textbox(
                label="Caption Separator",
                value=config.get("wd14_caption.caption_separator", ", "),
                interactive=True,
            )

        with gr.Row():
            tag_replacement = gr.Textbox(
                label="Tag replacement",
                # The backslash is escaped so the literal contains a single
                # `\` without relying on an invalid escape sequence.
                info="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\\`. e.g. `tag1,tag2;tag3,tag4`",
                value=config.get("wd14_caption.tag_replacement", ""),
                interactive=True,
            )

            character_tag_expand = gr.Checkbox(
                label="Character tag expand",
                info="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`",
                value=config.get("wd14_caption.character_tag_expand", False),
                interactive=True,
            )

        undesired_tags = gr.Textbox(
            label="Undesired tags",
            placeholder="(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.",
            interactive=True,
            value=config.get("wd14_caption.undesired_tags", ""),
        )

        with gr.Row():
            always_first_tags = gr.Textbox(
                label="Prefix to add to WD14 caption",
                info="comma-separated list of tags to always put at the beginning, e.g.: 1girl, 1boy, ",
                placeholder="(Optional)",
                interactive=True,
                value=config.get("wd14_caption.always_first_tags", ""),
            )

        with gr.Row():
            onnx = gr.Checkbox(
                label="Use onnx",
                value=config.get("wd14_caption.onnx", True),
                interactive=True,
                info="https://github.com/onnx/onnx",
            )
            append_tags = gr.Checkbox(
                label="Append TAGs",
                value=config.get("wd14_caption.append_tags", False),
                interactive=True,
                info="This option appends the tags to the existing tags, instead of replacing them.",
            )

            use_rating_tags = gr.Checkbox(
                label="Use rating tags",
                value=config.get("wd14_caption.use_rating_tags", False),
                interactive=True,
                info="Adds rating tags as the first tag",
            )

            use_rating_tags_as_last_tag = gr.Checkbox(
                label="Use rating tags as last tag",
                value=config.get("wd14_caption.use_rating_tags_as_last_tag", False),
                interactive=True,
                info="Adds rating tags as the last tag",
            )

        with gr.Row():
            recursive = gr.Checkbox(
                label="Recursive",
                value=config.get("wd14_caption.recursive", False),
                info="Tag subfolders images as well",
            )
            remove_underscore = gr.Checkbox(
                label="Remove underscore",
                value=config.get("wd14_caption.remove_underscore", True),
                info="replace underscores with spaces in the output tags",
            )

            debug = gr.Checkbox(
                label="Debug",
                value=config.get("wd14_caption.debug", True),
                info="Debug mode",
            )
            frequency_tags = gr.Checkbox(
                label="Show tags frequency",
                value=config.get("wd14_caption.frequency_tags", True),
                info="Show frequency of tags for images.",
            )

        with gr.Row():
            thresh = gr.Slider(
                value=config.get("wd14_caption.thresh", 0.35),
                label="Threshold",
                info="threshold of confidence to add a tag",
                minimum=0,
                maximum=1,
                step=0.05,
            )

            general_threshold = gr.Slider(
                value=config.get("wd14_caption.general_threshold", 0.35),
                label="General threshold",
                info="Adjust `general_threshold` for pruning tags (less tags, less flexible)",
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=config.get("wd14_caption.character_threshold", 0.35),
                label="Character threshold",
                minimum=0,
                maximum=1,
                step=0.05,
            )

        # Advanced Settings
        with gr.Row():
            batch_size = gr.Number(
                value=config.get("wd14_caption.batch_size", 1),
                label="Batch size",
                interactive=True,
            )

            max_data_loader_n_workers = gr.Number(
                value=config.get("wd14_caption.max_data_loader_n_workers", 2),
                label="Max dataloader workers",
                interactive=True,
            )

        def repo_id_changes(repo_id, onnx):
            """Lock the onnx checkbox to checked for "-v3" repos.

            The checkbox value seen when a "-v3" repo is selected is stored
            in the module-level ``old_onnx_value`` and restored (with the
            checkbox made interactive again) for any other repo.
            """
            global old_onnx_value

            if "-v3" in repo_id:
                old_onnx_value = onnx
                return gr.Checkbox(value=True, interactive=False)
            else:
                return gr.Checkbox(value=old_onnx_value, interactive=True)

        repo_id.change(repo_id_changes, inputs=[repo_id, onnx], outputs=[onnx])

        caption_button = gr.Button("Caption images")

        # The input order must match the positional parameters of
        # `caption_images`.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                repo_id,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                always_first_tags,
                onnx,
                append_tags,
                force_download,
                caption_separator,
                tag_replacement,
                character_tag_expand,
                use_rating_tags,
                use_rating_tags_as_last_tag,
                remove_underscore,
                thresh,
            ],
            show_progress=False,
        )

        # Refresh the dropdown choices with the sub-folders of whatever path
        # is currently selected or typed.
        train_data_dir.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
            inputs=train_data_dir,
            outputs=train_data_dir,
            show_progress=False,
        )