Suggestion: generate text with contrastive search

#25 opened by joaogante

Hi @Gustavosta 👋

I'm Joao and I'm one of the Hugging Face engineers working on text generation! First of all, I'm in love with this Space <3

It came to my attention that most outputs start by adding a variation to the input, followed by art styles, names of artists, and other keywords. This follows the pipeline's default behavior, which samples one token at a time. Sampling makes complete sense for the first part of the output (adding variations to the input). For the second part, however, it seems to me that we would benefit from gathering the best keywords (according to the model) given the text generated so far, and for that task sampling tends to underperform. Unless we want pure randomness, that is :)
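
For illustration, here is a minimal sketch of the two decoding modes side by side (the prompt string is just a placeholder; the parameters are the ones used in the app below):

from transformers import pipeline

pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')

# Sampling: each token is drawn from the model's distribution, so keywords are diverse but noisy
sampled = pipe("a portrait of a cat", do_sample=True, max_new_tokens=20)

# Contrastive search: do_sample=False plus penalty_alpha/top_k makes generate() pick
# high-scoring tokens while penalizing repetition of the context
contrastive = pipe("a portrait of a cat", do_sample=False, penalty_alpha=0.6, top_k=6, max_new_tokens=20)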

I've played around with your demo and added a slider to control the number of sampled tokens, populating the rest of the text with our brand new Contrastive Search, which balances the model's confidence in each candidate token against a degeneration penalty, picking high-scoring tokens while avoiding repetition. In my limited observation, generating ~10 tokens with sampling and the rest with contrastive search keeps the variability of the results high while raising the quality of the generated images.

I'm pasting my local version of your app below. Have a look at it and, if you feel like it would improve the results, I'd be delighted to help with any integration issues!

import random
import re

import gradio as grad
from transformers import pipeline, set_seed


gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')
with open("ideas.txt", "r") as f:
    lines = f.readlines()


def generate(starting_text, sampled_tokens):
    seed = random.randint(100, 1000000)
    set_seed(seed)

    if starting_text == "":
        starting_text = lines[random.randrange(0, len(lines))].replace("\n", "").lower().capitalize()
        starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)

    # Put the starting text through sampling to add some entropy
    # (do_sample=True makes the sampling step explicit)
    intermediary_response = gpt2_pipe(starting_text, do_sample=True, max_new_tokens=sampled_tokens, num_return_sequences=4)
    intermediary_response_list = []
    for x in intermediary_response:
        resp = x['generated_text'].strip()
        if resp != starting_text and not resp.endswith((":", "-", "—")):
            intermediary_response_list.append(resp)

    # Generate the rest of the text using contrastive search
    # (do_sample=False combined with penalty_alpha/top_k switches generate() to contrastive search)
    contrastive_tokens = 90 - sampled_tokens  # 90 matches the slider maximum below
    response = gpt2_pipe(
        intermediary_response_list, max_new_tokens=contrastive_tokens, do_sample=False, penalty_alpha=0.6, top_k=6
    )
    response_list = []
    for x in response:
        resp = x[0]['generated_text'].strip()
        if resp != starting_text and len(resp) > (len(starting_text) + 4) and not resp.endswith((":", "-", "—")):
            if resp not in response_list:
                response_list.append(resp)

    response_end = "\n\n".join(response_list)
    response_end = re.sub(r'[^ ]+\.[^ ]+', '', response_end)  # drop tokens with internal dots (e.g. stray URLs)
    response_end = response_end.replace("<", "").replace(">", "")

    if response_end != "":
        return response_end


txt = grad.Textbox(lines=1, label="Initial Text", placeholder="English Text here")
rand_len = grad.Slider(minimum=1, maximum=90, value=10, step=1, label="Sampled new tokens (entropy level)")
out = grad.Textbox(lines=4, label="Generated Prompts")

# examples = []
# for x in range(8):
#     examples.append(lines[random.randrange(0, len(lines))].replace("\n", "").lower().capitalize())

title = "Stable Diffusion Prompt Generator"
description = 'This is a demo of the model series: "MagicPrompt", in this case, aimed at: "Stable Diffusion". To use it, simply submit your text or click on one of the examples. To learn more about the model, [click here](https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion).<br>'

grad.Interface(fn=generate,
               inputs=[txt, rand_len],
               outputs=out,
            #    examples=examples,
               title=title,
               description=description,
               article='',
               allow_flagging='never',
               cache_examples=False,
               theme="default").launch(enable_queue=True, debug=True)
