radames's picture
add-quicksearch (#1)
52d5c1d verified
import gradio as gr
import yaml
from gradio_huggingfacehub_search import HuggingfaceHubSearch
# Markdown header rendered at the top of the page (passed to gr.Markdown below).
MARKDOWN_DESCRIPTION = """
# mergekit config.yaml generator
GUI to template a YAML configuration file for mergekit, which you can then copy/paste into [mergekit-gui](https://huggingface.co/spaces/arcee-ai/mergekit-gui) 🔥
"""
# Initial content of the "Merge Parameters" code editor.
# create_config_yaml() feeds this text to yaml.safe_load(), so it must be valid
# YAML: the list entries are nested under ``t`` and each ``value`` belongs to
# the list item opened by the preceding ``- filter`` line.
DEFAULT_PARAMETERS = """
t:
  - filter: self_attn
    value: [0, 0.5, 0.3, 0.7, 1]
  - filter: mlp
    value: [1, 0.5, 0.7, 0.3, 0]
  - value: 0.5
"""
def _parse_layer_range(raw, label):
    """Parse a user-typed ``[start, end]`` layer range into a two-integer list.

    Raises:
        gr.Error: if *raw* is empty, is not valid YAML, or does not parse to a
            list of exactly two integers.
    """
    try:
        # Blank/None input is treated as "nothing parsed" and rejected below.
        parsed = yaml.safe_load(raw) if raw else None
    except yaml.YAMLError as err:
        raise gr.Error(f"{label} layer range is not valid YAML: {err}") from err

    is_int_pair = (
        isinstance(parsed, list)
        and len(parsed) == 2
        # bool is a subclass of int; reject it so "[true, false]" is not accepted.
        and all(isinstance(v, int) and not isinstance(v, bool) for v in parsed)
    )
    if not is_int_pair:
        raise gr.Error(
            f"{label} layer range must be a list of two integers, e.g. [0, 32]."
        )
    return parsed


def create_config_yaml(
    model1,
    model1_layers,
    model2,
    model2_layers,
    merge_method,
    base_model,
    parameters,
    dtype,
) -> str:
    """Build the YAML text of a two-model merge configuration.

    Args:
        model1: Identifier of the first source model.
        model1_layers: Layer range for ``model1`` as text, e.g. ``"[0, 32]"``.
        model2: Identifier of the second source model.
        model2_layers: Layer range for ``model2`` as text, e.g. ``"[0, 32]"``.
        merge_method: Value written under ``merge_method``.
        base_model: Value written under ``base_model``; the key is omitted
            when this is empty.
        parameters: YAML text written (parsed) under ``parameters``; the key
            is omitted when this is empty.
        dtype: Value written under ``dtype``; the key is omitted when empty.

    Returns:
        The configuration serialized with ``yaml.dump``, keys kept in
        insertion order (slices, merge_method, base_model, parameters, dtype).

    Raises:
        gr.Error: if a layer range is not a two-integer list, or if
            ``parameters`` is not valid YAML.
    """
    sources = [
        {"model": model1, "layer_range": _parse_layer_range(model1_layers, "Model 1")},
        {"model": model2, "layer_range": _parse_layer_range(model2_layers, "Model 2")},
    ]
    dict_config = {
        "slices": [{"sources": sources}],
        "merge_method": merge_method,
    }
    # Optional keys are only emitted when the user supplied a value, so the
    # generated file never contains placeholders such as ``base_model: ''``.
    if base_model:
        dict_config["base_model"] = base_model
    if parameters:
        try:
            dict_config["parameters"] = yaml.safe_load(parameters)
        except yaml.YAMLError as err:
            raise gr.Error(f"Merge parameters are not valid YAML: {err}") from err
    if dtype:
        dict_config["dtype"] = dtype
    # sort_keys=False keeps the insertion order built above.
    return yaml.dump(dict_config, sort_keys=False)
# make sure to add the themes as well
# Page layout: description, input widgets, a button that runs
# create_config_yaml() on the widget values, and a code view for the result.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)
    # NOTE(review): the extent of this Row was reconstructed from a
    # whitespace-stripped view of the file — confirm which inputs belong in it.
    with gr.Row():
        # model_name_input = gr.Textbox(label="Model Name", value="my-merge")
        model1_input = HuggingfaceHubSearch(
            label="Model 1",
            placeholder="Search for model 1 on Huggingface",
            search_type="model",
            value="BioMistral/BioMistral-7B"
        )
        model1_layers_input = gr.Textbox(
            label="Model 1 Layer Range", placeholder="[start, end]", value="[0, 32]"
        )
        model2_input = HuggingfaceHubSearch(
            label="Model 2",
            placeholder="Search for model 2 on Huggingface",
            search_type="model",
            value="CorticalStack/pastiche-crown-clown-7b-dare-dpo"
        )
        model2_layers_input = gr.Textbox(
            label="Model 2 Layer Range", placeholder="[start, end]", value="[0, 32]"
        )
        merge_method_input = gr.Dropdown(
            label="Merge Method", choices=["slerp", "linear"], value="slerp"
        )
        base_model_input = gr.Textbox(label="Base Model", value="BioMistral/BioMistral-7B")
        # YAML editor pre-filled with DEFAULT_PARAMETERS.
        parameters_input = gr.Code(
            language="yaml",
            label="Merge Parameters",
            value=DEFAULT_PARAMETERS,
        )
        dtype_input = gr.Textbox(label="Dtype", value="bfloat16")
    create_button = gr.Button("Create config.yaml", variant="primary")
    # Displays the string returned by create_config_yaml().
    output_zone = gr.Code(language="yaml", lines=10)
    # The order of `inputs` must match create_config_yaml's positional
    # parameters: model1, model1_layers, model2, model2_layers, merge_method,
    # base_model, parameters, dtype.
    create_button.click(
        fn=create_config_yaml,
        inputs=[
            model1_input,
            model1_layers_input,
            model2_input,
            model2_layers_input,
            merge_method_input,
            base_model_input,
            parameters_input,
            dtype_input,
        ],
        outputs=[output_zone],
    )
    gr.Markdown("A Space by [1littlecoder](https://huggingface.co/1littlecoder)")
# Starts the app as soon as this module is executed.
demo.launch()