demo_gradio.py (64 changes: 44 additions & 20 deletions)
@@ -31,13 +31,17 @@
parser = argparse.ArgumentParser()
parser.add_argument('--share', action='store_true')
parser.add_argument("--server", type=str, default='0.0.0.0')
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--port", type=int, required=False)
parser.add_argument("--inbrowser", action='store_true')
args = parser.parse_args()

# for win desktop probably use --server 127.0.0.1 --inbrowser
# For linux server probably use --server 127.0.0.1 or do not use any cmd flags

print(args)

free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 40
high_vram = free_mem_gb > 60

print(f'Free VRAM {free_mem_gb} GB')
print(f'High-VRAM Mode: {high_vram}')
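# Note: with required=False and no explicit default, args.port is None when --port is
# omitted, and Gradio's launch() then falls back to its own port selection (typically
# 7860, or the next free port).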
@@ -64,7 +68,6 @@
vae.enable_tiling()

transformer.high_quality_fp32_output_for_inference = True
print('transformer.high_quality_fp32_output_for_inference = True')

transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.float16)
@@ -94,9 +97,25 @@
outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)

def _encode_prompt_and_generate_attention_mask(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2):
prompts = prompt.split(';')
llama_vecs = []
clip_l_poolers = []
llama_attention_masks = []
for prompt in prompts:
llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
llama_vec = llama_vec.to(transformer.dtype)
clip_l_pooler = clip_l_pooler.to(transformer.dtype)
llama_attention_masks.append(llama_attention_mask)
llama_vecs.append(llama_vec)
clip_l_poolers.append(clip_l_pooler)
return list(llama_vecs), list(clip_l_poolers), list(llama_attention_masks)
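# The helper above splits the single prompt textbox on ';' and encodes each segment
# separately, so "a man waves; he sits down" yields two llama embeddings (each cropped
# or padded to 512 tokens with its own attention mask) plus two CLIP-L poolers.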


@torch.no_grad()
def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache):
def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
total_latent_sections = int(max(round(total_latent_sections), 1))

@@ -119,14 +138,13 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=gpu)

llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
llama_vecs, clip_l_poolers, llama_attention_masks = _encode_prompt_and_generate_attention_mask(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

if cfg == 1:
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vecs[0]), torch.zeros_like(clip_l_poolers[0])
else:
llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

# Processing input image
@@ -136,7 +154,6 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
H, W, C = input_image.shape
height, width = find_nearest_bucket(H, W, resolution=640)
input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))

input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
@@ -162,10 +179,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state

# Dtype

llama_vec = llama_vec.to(transformer.dtype)
llama_vec_n = llama_vec_n.to(transformer.dtype)
clip_l_pooler = clip_l_pooler.to(transformer.dtype)
clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)

@@ -189,6 +203,10 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
# use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
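# e.g. for total_latent_sections = 5 this yields [3, 2, 2, 1, 0], whereas the
# commented-out variant list(reversed(range(5))) would give [4, 3, 2, 1, 0].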

first_llama_vec = llama_vecs[0]
first_clip_l_pooler = clip_l_poolers[0]
first_llama_attention_mask = llama_attention_masks[0]
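# The first prompt's embeddings are kept as a fallback: each section below pops one
# prompt set off the end of these lists, and once they are exhausted (or when only a
# single prompt was entered) the remaining sections reuse the first prompt. Since
# sampling runs back to front, the last ';'-separated prompt is consumed first.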

for latent_padding in latent_paddings:
is_last_section = latent_padding == 0
latent_padding_size = latent_padding * latent_window_size
@@ -246,9 +264,9 @@ def callback(d):
# shift=3.0,
num_inference_steps=steps,
generator=rnd,
prompt_embeds=llama_vec,
prompt_embeds_mask=llama_attention_mask,
prompt_poolers=clip_l_pooler,
prompt_embeds=llama_vecs.pop() if llama_vecs else first_llama_vec,
prompt_embeds_mask=llama_attention_masks.pop() if llama_attention_masks else first_llama_attention_mask,
prompt_poolers=clip_l_poolers.pop() if clip_l_poolers else first_clip_l_pooler,
negative_prompt_embeds=llama_vec_n,
negative_prompt_embeds_mask=llama_attention_mask_n,
negative_prompt_poolers=clip_l_pooler_n,
@@ -291,7 +309,7 @@ def callback(d):

output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')

save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)
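# The new crf argument forwards the "MP4 Compression" slider value to the video writer;
# per the slider description, lower values mean better quality and 0 is uncompressed.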

print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')

@@ -311,15 +329,15 @@ def callback(d):
return


def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache):
def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
global stream
assert input_image is not None, 'No input image!'

yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

stream = AsyncStream()

async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache)
async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf)

output_filename = None

@@ -356,8 +374,8 @@ def end_process():
gr.Markdown('# FramePack')
with gr.Row():
with gr.Column():
input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320)
prompt = gr.Textbox(label="Prompt", value='')
input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320)
prompt = gr.Textbox(label="Prompt 1", value='', lines=5)
example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Quick List', samples_per_page=1000, components=[prompt])
example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False)

@@ -381,13 +399,18 @@ def end_process():

gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.")

mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. ")

with gr.Column():
preview_image = gr.Image(label="Next Latents", height=200, visible=False)
result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True)
gr.Markdown('Note that the ending actions will be generated before the starting actions due to the inverted sampling. If the starting action is not in the video, you just need to wait, and it will be generated later.')
progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
progress_bar = gr.HTML('', elem_classes='no-generating-animation')
ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache]

gr.HTML('<div style="text-align:center; margin-top:20px;">Share your results and find ideas at the <a href="https://x.com/search?q=framepack&f=live" target="_blank">FramePack Twitter (X) thread</a></div>')

ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf]
start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
end_button.click(fn=end_process)

@@ -396,4 +419,5 @@ def end_process():
server_name=args.server,
server_port=args.port,
share=args.share,
inbrowser=args.inbrowser,
)
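
Taken together, the prompt changes let a user type several prompts in one textbox separated by ';', roughly one per generated section, with the first prompt reused once the list runs out. A minimal standalone sketch of that consumption order (the example prompt text and the section count of 5 are illustrative, and no encoders are invoked):

prompts = "a man walks in; he waves; he sits down".split(';')

queue = list(prompts)
first = prompts[0]

# Sections for the end of the video are generated first (inverted sampling);
# once the queue is empty, every remaining section falls back to the first prompt.
for section in range(5):
    current = queue.pop() if queue else first
    print(section, repr(current))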