"""PyWebIO front end for generating images with the local FLUX.1-schnell checkpoint.

Generated PNGs are written into the first of the known Immich directories that
exists on this machine (see ``set_dir``).
"""
from pywebio import start_server
from pywebio.input import *
from pywebio.output import *
import pywebio.pin as pin
import time
from diffusers import FluxPipeline
import torch
from pathlib import Path
import re
from datetime import datetime, timedelta
import logging
import threading

logger = logging.getLogger(__name__)

default_prompt = "A cat wearing an old racing helmet and goggles holding a sign that says zoom zoom"
default_width = 1024
default_height = 512
default_numimgs = 4
ckpt_id = "./FLUX.1-schnell"

# Longest prompt, in characters, that the UI accepts.
max_prompt_len = 370

# (width, height) applied to the form by the preset buttons.
_SIZE_PRESETS = {
    'Portrait': (512, 1024),
    'Landscape': (1024, 512),
    'Square': (512, 512),
}

# Held for the duration of a generation. ``pipe`` below is a single module-level
# object used by every session, so a second request (e.g. from a phone while the
# computer is already generating) is turned away instead of running concurrently.
_generation_lock = threading.Lock()

# Denoising pipeline, loaded once at import time.
pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Offloads modules to CPU on a submodule level (rather than model level).
pipe.enable_sequential_cpu_offload()


def slugify(text):
    """Turn *text* into a filename-friendly slug.

    Characters that are neither word characters nor whitespace are dropped,
    surrounding whitespace is trimmed, and each remaining whitespace run becomes
    a single ``-``. The pattern is a ``str`` pattern, so word characters are
    matched Unicode-aware: non-ASCII letters and digits are kept.
    """
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", "-", text.strip())


def set_dir():
    """Return the first existing output directory as a ``Path``, else ``False``.

    The "blazar" location is checked before the "magnetar" one.
    """
    candidates = (
        "/home/tonydero/remdirs/immich/",  # blazar
        "/nas/dockerdata/immich/local/",   # magnetar
    )
    for candidate in candidates:
        path = Path(candidate)
        if path.is_dir():
            return path
    return False


def chk_prompt_len(prompt, max_len=max_prompt_len):
    """Return ``True`` if *prompt* is at most *max_len* characters long."""
    return len(prompt) <= max_len


def flux_run(prompt, height, width, num_images_per_prompt, num_inference_steps,
             max_sequence_length, dirpath):
    """Generate images for *prompt* with the shared pipeline and save them as PNGs.

    Args:
        prompt: Text prompt passed to the pipeline.
        height, width: Output image size in pixels.
        num_images_per_prompt: Number of images to generate.
        num_inference_steps: Denoising steps passed to the pipeline.
        max_sequence_length: Forwarded unchanged to the pipeline.
        dirpath: ``Path`` of the directory the PNGs are written to.

    Returns:
        List of the ``Path`` objects the images were saved to. Files are named
        ``<slug>-<index>-<YYYYmmdd-HHMMSS>.png``.
    """
    output = pipe(
        prompt,
        height=height,
        width=width,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
        max_sequence_length=max_sequence_length,
        guidance_scale=0.0,
    )
    # One zero-padded timestamp per batch so the names sort chronologically.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    # Fall back to a fixed stem when the prompt has no slug-able characters.
    stem = slugify(prompt[:64]) or "image"
    saved = []
    for idx, image in enumerate(output.images):
        image_path = dirpath / f"{stem}-{idx}-{timestamp}.png"
        image.save(image_path)
        saved.append(image_path)
    return saved


def _positive_int(value):
    """Return *value* converted to ``int`` if that is positive, otherwise ``None``."""
    try:
        number = int(value)
    except (TypeError, ValueError):
        # Empty number fields and non-numeric input end up here.
        return None
    return number if number > 0 else None


def _generate(dirpath):
    """Validate the form, run the pipeline into *dirpath* and report the result on the page."""
    prompt = pin.pin.prompt or ""
    if not prompt.strip():
        put_error("Prompt is empty.", closable=True)
        return
    if not chk_prompt_len(prompt):
        put_error("Prompt too long.", closable=True)
        return

    width = _positive_int(pin.pin.width)
    height = _positive_int(pin.pin.height)
    numimgs = _positive_int(pin.pin.numimgs)
    if None in (width, height, numimgs):
        put_error("Width, height and number of images must be positive whole numbers.",
                  closable=True)
        return

    # Non-blocking: a second request is rejected rather than queued behind the first.
    if not _generation_lock.acquire(blocking=False):
        put_warning("Another generation is already running; try again when it finishes.",
                    closable=True)
        return
    try:
        with put_loading():
            start_time = datetime.now()
            put_text("Started generating images at " + start_time.strftime("%H:%M:%S"))
            flux_run(prompt, height, width, numimgs, 4, 128, dirpath)
            run_time = datetime.now() - start_time
    except Exception as err:
        # UI callback boundary: log the traceback and tell the user instead of failing silently.
        logger.exception("Image generation failed")
        put_error(f"Image generation failed: {err}", closable=True)
    else:
        # Round to whole seconds so the message reads like 0:01:23.
        elapsed = timedelta(seconds=round(run_time.total_seconds()))
        put_success("Images finished generating in " + str(elapsed), closable=True)
    finally:
        _generation_lock.release()


def btn_click(btn_val):
    """Handle a click on one of the buttons rendered by ``main``.

    'Generate' runs the pipeline (only the generate path needs the output
    directory). The preset buttons overwrite the width/height fields.
    """
    if btn_val == 'Generate':
        dirpath = set_dir()
        if not dirpath:
            put_error("Failed to generate. Check directory location.", closable=True)
            return
        _generate(dirpath)
        return

    preset = _SIZE_PRESETS.get(btn_val)
    if preset is not None:
        width, height = preset
        pin.pin.width = width
        pin.pin.height = height


def main():
    """PyWebIO application function: render the generation form and its buttons."""
    # TODO: put the labels in line with the inputs and make the number ones smaller
    pin.put_input("prompt", label="Prompt:", type="text", value=default_prompt)
    pin.put_input("width", label="Width:", type="number", value=default_width)
    pin.put_input("height", label="Height:", type="number", value=default_height)
    pin.put_input("numimgs", label="Number of images:", type="number", value=default_numimgs)
    # TODO: add the remaining parameters?
    # TODO: add ability to loop for more generation in one run?
    put_buttons(['Generate', 'Portrait', 'Landscape', 'Square'], onclick=btn_click)


if __name__ == "__main__":
    start_server(main, port=4972, debug=True, auto_open_webbrowser=False)