# image_modification/dataset_creation/generate_txt_dataset.py
# (Hugging Face page header: holynski — "Updates for v1 release", commit 926ff6c, 4.41 kB)
from __future__ import annotations
import json
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional
import datasets
import numpy as np
import openai
from tqdm.auto import tqdm
# Separator appended to the input caption when building the completion prompt.
DELIMITER_0 = "\n##\n"
# Separator the fine-tuned model emits between the edit instruction and the
# edited caption inside its completion text.
DELIMITER_1 = "\n%%\n"
# Stop sequence that terminates the model's completion.
STOP = "\nEND"
def generate(
    openai_model: str,
    caption: str,
    num_retries: int = 3,
    max_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 1.0,
    frequency_penalty: float = 0.1,
    presence_penalty: float = 0.0,
    sleep_on_error: float = 1.0,
) -> Optional[tuple[str, str]]:
    """Generate an (instruction, edited_caption) pair for *caption*.

    Queries the fine-tuned OpenAI completion model with ``caption + DELIMITER_0``
    and expects a completion of the form ``<instruction>DELIMITER_1<edited caption>``.

    A fresh attempt is made (up to ``num_retries`` extra times) when the API call
    raises, the completion is malformed, either output string is flagged by the
    moderation endpoint, or the edited caption is — ignoring case, surrounding
    whitespace, and trailing punctuation — identical to the input caption.

    Returns:
        The ``(instruction, edited_caption)`` tuple, or ``None`` if no
        acceptable pair was produced within the retry budget.
    """

    def _normalize(text: str) -> str:
        # Canonical form used to reject no-op edits: ignore surrounding
        # whitespace, trailing punctuation, and case.
        return text.strip().strip(".!?").lower()

    for _ in range(1 + num_retries):
        try:
            response = openai.Completion.create(
                model=openai_model,
                prompt=caption + DELIMITER_0,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=[STOP],
            )
        except Exception as e:
            # Transient API failure (rate limit, network, ...): back off, retry.
            print(e)
            time.sleep(sleep_on_error)
            continue
        output = response["choices"][0]["text"].split(DELIMITER_1)
        if len(output) != 2:
            continue  # Malformed completion — missing or extra delimiter.
        instruction, edited_caption = output
        # Reject pairs the moderation endpoint flags as unsafe.
        results = openai.Moderation.create([instruction, edited_caption])["results"]
        if results[0]["flagged"] or results[1]["flagged"]:
            continue
        # Reject trivial edits where the "edited" caption equals the input.
        if _normalize(caption) != _normalize(edited_caption):
            return instruction, edited_caption
    return None  # Explicit: retry budget exhausted without an acceptable pair.
def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
    """Generate up to *num_samples* (caption, edit, output, url) prompt records.

    Loads the LAION-Aesthetics 6.5+ caption dataset, takes the ``partition``-th
    of ``num_partitions`` random shards (shuffled deterministically with *seed*),
    and for each unseen caption asks the fine-tuned model for an edit
    instruction and edited caption, appending accepted results as JSON lines to
    a partition-specific output file. If the output file already exists, the
    run resumes: previously generated captions/URLs are skipped and counted
    toward the target.
    """
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Other datasets we considered that may be worth trying:
    # dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
    # dataset = datasets.load_dataset("laion/laion-coco", split="train")
    np.random.seed(seed)
    # Deterministic shuffle, then keep only this partition's slice of indices.
    permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
    dataset = dataset[permutation]
    captions = dataset["TEXT"]
    urls = dataset["URL"]
    output_path = f"data/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl"  # fmt: skip
    print(f"Prompt file path: {output_path}")
    count = 0
    caption_set = set()
    url_set = set()
    # Resume support: count prompts already on disk and record their
    # captions/URLs so we neither regenerate nor duplicate them.
    if Path(output_path).exists():
        with open(output_path, "r") as f:
            for line in tqdm(f, desc="Resuming from existing prompts"):
                prompt = json.loads(line)
                if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
                    caption_set.add(prompt["caption"])
                    url_set.add(prompt["url"])
                    count += 1
    with open(output_path, "a") as fp:
        with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
            for caption, url in zip(captions, urls):
                # BUGFIX: check >= at the top of the loop instead of == after a
                # successful write. With the old check, a resumed run whose
                # existing count already met or exceeded the target never hit
                # the break and iterated (and paid API calls) over the whole
                # partition before finishing.
                if count >= num_samples:
                    break
                if caption in caption_set or url in url_set:
                    continue
                # Skip source captions flagged by the moderation endpoint.
                if openai.Moderation.create(caption)["results"][0]["flagged"]:
                    continue
                edit_output = generate(openai_model, caption)
                if edit_output is not None:
                    edit, output = edit_output
                    fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
                    count += 1
                    progress_bar.update()
                caption_set.add(caption)
                url_set.add(url)
if __name__ == "__main__":
    parser = ArgumentParser()
    # API credentials and the fine-tuned completion model to query.
    parser.add_argument("--openai-api-key", required=True, type=str)
    parser.add_argument("--openai-model", required=True, type=str)
    # Target number of prompt records to generate for this partition.
    parser.add_argument("--num-samples", default=10000, type=int)
    # Shard the shuffled dataset into num-partitions pieces and process only
    # the given partition, so several processes can run in parallel.
    parser.add_argument("--num-partitions", default=1, type=int)
    parser.add_argument("--partition", default=0, type=int)
    # Seed for the deterministic dataset shuffle (must match across partitions).
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()
    # Configure the global OpenAI client before any API call is made.
    openai.api_key = args.openai_api_key
    main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)