File size: 5,797 Bytes
8f29722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
#!/usr/bin/env python
from __future__ import annotations
import enum
import gradio as gr
from huggingface_hub import HfApi
from inference import InferencePipeline
from utils import find_exp_dirs
# Model IDs offered when the "Sample" model source is selected. The first
# entry doubles as the default selection of the LoRA model dropdown.
SAMPLE_MODEL_IDS = [
    'patrickvonplaten/lora_dreambooth_dog_example',
]
class ModelSource(enum.Enum):
    """Where the inference demo looks for LoRA models to list.

    Each member's value is both the label rendered in the "Model Source"
    radio button and the string compared against when the model list is
    reloaded.
    """

    # The bundled SAMPLE_MODEL_IDS.
    SAMPLE = 'Sample'
    # Models authored by the "lora-library" account on the Hub.
    HUB_LIB = 'Hub (lora-library)'
    # Directories returned by utils.find_exp_dirs().
    LOCAL = 'Local'
class InferenceUtil:
    """Callbacks backing the model-selection widgets of the inference demo.

    Args:
        hf_token: Hugging Face access token forwarded to ``HfApi`` and to
            ``InferencePipeline.get_model_card``; may be ``None``.
    """

    def __init__(self, hf_token: str | None):
        self.hf_token = hf_token

    @staticmethod
    def load_sample_lora_model_list() -> dict:
        """Return a dropdown update listing SAMPLE_MODEL_IDS.

        The first sample ID is pre-selected.
        """
        return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])

    def load_hub_lora_model_list(self) -> dict:
        """Return a dropdown update listing Hub models by ``lora-library``.

        The first model is pre-selected; an empty listing yields
        ``value=None``.
        """
        api = HfApi(token=self.hf_token)
        # NOTE(review): newer huggingface_hub releases expose the repo id as
        # ``info.id`` and may deprecate ``modelId`` -- confirm against the
        # pinned version.
        choices = [
            info.modelId for info in api.list_models(author='lora-library')
        ]
        return gr.update(choices=choices,
                         value=choices[0] if choices else None)

    @staticmethod
    def load_local_lora_model_list() -> dict:
        """Return a dropdown update listing ``find_exp_dirs()`` results.

        The first entry is pre-selected; an empty result yields
        ``value=None``.
        """
        choices = find_exp_dirs()
        return gr.update(choices=choices,
                         value=choices[0] if choices else None)

    def reload_lora_model_list(self, model_source: str) -> dict:
        """Return the dropdown update for the given model source label.

        Args:
            model_source: One of the ``ModelSource`` values (the radio label).

        Raises:
            ValueError: If ``model_source`` is not a ``ModelSource`` value.
        """
        if model_source == ModelSource.SAMPLE.value:
            return self.load_sample_lora_model_list()
        if model_source == ModelSource.HUB_LIB.value:
            return self.load_hub_lora_model_list()
        if model_source == ModelSource.LOCAL.value:
            return self.load_local_lora_model_list()
        raise ValueError(f'Unknown model source: {model_source!r}')

    @staticmethod
    def _as_text(value: object) -> str:
        """Render a model-card metadata value for a read-only text box."""
        # Card metadata fields can be present but None, or hold a list
        # (e.g. several base models); the text boxes need a plain string.
        if value is None:
            return ''
        if isinstance(value, (list, tuple)):
            return ', '.join(str(item) for item in value)
        return str(value)

    def load_model_info(self, lora_model_id: str) -> tuple[str, str]:
        """Return ``(base_model, instance_prompt)`` read from a model's card.

        Args:
            lora_model_id: The value selected in the LoRA model dropdown
                (presumably anything ``InferencePipeline.get_model_card``
                accepts). May be ``None``/empty when the dropdown has no
                choices.

        Returns:
            Two strings for the read-only info fields. Both are empty when no
            model is selected or its card cannot be loaded.
        """
        if not lora_model_id:
            # Nothing is selected, so there is no card to fetch.
            return '', ''
        try:
            card = InferencePipeline.get_model_card(lora_model_id,
                                                    self.hf_token)
        except Exception:
            # Deliberately best-effort: an unreadable card must not break the
            # UI, so show blank fields but leave a trace for debugging.
            import logging
            logging.getLogger(__name__).warning(
                'Could not load model card for %r',
                lora_model_id,
                exc_info=True)
            return '', ''
        base_model = getattr(card.data, 'base_model', None)
        instance_prompt = getattr(card.data, 'instance_prompt', None)
        return self._as_text(base_model), self._as_text(instance_prompt)
def create_inference_demo(pipe: InferencePipeline,
                          hf_token: str | None = None) -> gr.Blocks:
    """Build the Gradio UI for generating images with a LoRA model.

    Args:
        pipe: Pipeline whose ``run`` method is called with the values of
            ``(lora_model_id, prompt, seed, num_steps, guidance_scale)``, in
            that order, and whose return value is shown as the result image.
        hf_token: Hugging Face token used for Hub model listing and model-card
            lookups; may be ``None``.

    Returns:
        The assembled ``gr.Blocks`` app (not yet queued or launched).
    """
    app = InferenceUtil(hf_token)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    model_source = gr.Radio(
                        label='Model Source',
                        choices=[source.value for source in ModelSource],
                        value=ModelSource.SAMPLE.value)
                    reload_button = gr.Button('Reload Model List')
                    lora_model_id = gr.Dropdown(label='LoRA Model ID',
                                                choices=SAMPLE_MODEL_IDS,
                                                value=SAMPLE_MODEL_IDS[0])
                    with gr.Accordion(
                            label=
                            'Model info (Base model and instance prompt used for training)',
                            open=False):
                        with gr.Row():
                            base_model_used_for_training = gr.Text(
                                label='Base model', interactive=False)
                            instance_prompt_used_for_training = gr.Text(
                                label='Instance prompt', interactive=False)
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "A picture of a sks dog in a bucket"'
                )
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=0)
                with gr.Accordion('Other Parameters', open=False):
                    # At least one denoising step is required; 0 steps would
                    # skip denoising entirely.
                    num_steps = gr.Slider(label='Number of Steps',
                                          minimum=1,
                                          maximum=100,
                                          step=1,
                                          value=25)
                    guidance_scale = gr.Slider(label='CFG Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=7.5)

                run_button = gr.Button('Generate')

                gr.Markdown('''
- After training, you can press "Reload Model List" button to load your trained model names.
''')
            with gr.Column():
                result = gr.Image(label='Result')

        model_info_outputs = [
            base_model_used_for_training,
            instance_prompt_used_for_training,
        ]

        # Switching source and pressing "Reload" both repopulate the dropdown.
        model_source.change(fn=app.reload_lora_model_list,
                            inputs=model_source,
                            outputs=lora_model_id)
        reload_button.click(fn=app.reload_lora_model_list,
                            inputs=model_source,
                            outputs=lora_model_id)
        lora_model_id.change(fn=app.load_model_info,
                             inputs=lora_model_id,
                             outputs=model_info_outputs)
        # The dropdown's change event does not fire for its initial value,
        # so also fill the info fields for the default model on page load.
        demo.load(fn=app.load_model_info,
                  inputs=lora_model_id,
                  outputs=model_info_outputs)

        # Gradio passes these positionally, so the order must match the
        # parameter order of pipe.run.
        inputs = [
            lora_model_id,
            prompt,
            seed,
            num_steps,
            guidance_scale,
        ]
        prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
        run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
    return demo
if __name__ == '__main__':
    import os

    # HF_TOKEN is optional; when it is unset, None is passed through to both
    # the pipeline and the demo.
    hf_token = os.environ.get('HF_TOKEN')
    pipe = InferencePipeline(hf_token)
    demo = create_inference_demo(pipe, hf_token)
    # Blocks.queue() configures the app in place and returns it, so queuing
    # and launching can be done as two separate calls.
    demo.queue(max_size=10)
    demo.launch(share=False)
|