Spaces:
Paused
Paused
import json | |
import webcolors | |
def closest_color(requested_color):
    """Return the CSS3 color name nearest to ``requested_color``.

    Nearness is the squared Euclidean distance between the first three
    channels of ``requested_color`` (read as ``[0]``, ``[1]``, ``[2]``) and
    the RGB value of every entry in ``webcolors.CSS3_HEX_TO_NAMES``.

    When several named colors are equally close, the one that appears last
    in that mapping is returned.
    """
    def squared_distance(hex_value):
        red, green, blue = webcolors.hex_to_rgb(hex_value)
        return (
            (red - requested_color[0]) ** 2
            + (green - requested_color[1]) ** 2
            + (blue - requested_color[2]) ** 2
        )

    # Keyed by distance: a later color at the same distance replaces an
    # earlier one, which is what decides ties.
    name_by_distance = {
        squared_distance(hex_value): css_name
        for hex_value, css_name in webcolors.CSS3_HEX_TO_NAMES.items()
    }
    return name_by_distance[min(name_by_distance)]
def convert_rgb_to_names(rgb_tuple):
    """Map an RGB triple to a color name.

    The result of ``webcolors.rgb_to_name`` is used when that lookup
    succeeds; if it raises ``ValueError`` the name chosen by
    ``closest_color`` is returned instead.
    """
    try:
        return webcolors.rgb_to_name(rgb_tuple)
    except ValueError:
        # The direct lookup failed: fall back to the nearest named color.
        return closest_color(rgb_tuple)
class PromptFormat:
    """Turn texts and their style dicts into a tagged prompt string.

    Each style is a dict with a ``'font-family'`` key (looked up in
    ``font_dict``) and a ``'color'`` key holding a hex color string (resolved
    to a color name, then looked up in ``color_dict``).
    """

    def __init__(
        self,
        font_path: str = 'assets/font_idx_512.json',
        color_path: str = 'assets/color_idx.json',
    ):
        """Load the font and color lookup tables.

        Args:
            font_path: JSON file mapping font-family names to the value
                placed inside ``<font-...>`` tags.
            color_path: JSON file mapping color names to the value placed
                inside ``<color-...>`` tags.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding when reading the tables.
        with open(font_path, 'r', encoding='utf-8') as f:
            self.font_dict = json.load(f)
        with open(color_path, 'r', encoding='utf-8') as f:
            self.color_dict = json.load(f)

    @staticmethod
    def _resolve_color_name(style):
        """Return the color name that ``style['color']`` (a hex string) maps to."""
        return convert_rgb_to_names(webcolors.hex_to_rgb(style['color']))

    def format_checker(self, texts, styles):
        """Validate ``texts`` and ``styles`` before a prompt is built.

        Raises:
            AssertionError: if the two sequences differ in length, a
                font-family is not in ``font_dict``, or a color resolves to
                a name that is not in ``color_dict``.
        """
        # Raised explicitly rather than with ``assert`` so the checks still
        # run under ``python -O``; the exception type callers see is the same.
        if len(texts) != len(styles):
            raise AssertionError('length of texts must be equal to length of styles')
        for style in styles:
            font_family = style['font-family']
            if font_family not in self.font_dict:
                raise AssertionError(f"invalid font-family: {font_family}")
            color_name = self._resolve_color_name(style)
            if color_name not in self.color_dict:
                raise AssertionError(
                    f"invalid color hex {style['color']} "
                    f"(resolved to color name {color_name!r})"
                )

    def format_prompt(self, texts, styles):
        """Build the prompt for ``texts`` styled by ``styles``.

        Every text contributes one sentence of the form::

            Text "{text}" in <color-{color}>, <font-{font}>.

        followed by a single space; the sentences are concatenated in order.

        Returns:
            The concatenated prompt, or ``""`` when ``texts`` is empty.

        Raises:
            AssertionError: if ``format_checker`` rejects the input.
        """
        self.format_checker(texts, styles)
        sentences = []
        for text, style in zip(texts, styles):
            color_tag = f"<color-{self.color_dict[self._resolve_color_name(style)]}>"
            font_tag = f"<font-{self.font_dict[style['font-family']]}>"
            sentences.append(f'Text "{text}" in {color_tag}, {font_tag}. ')
        return "".join(sentences)
class MultilingualPromptFormat:
    """Turn texts and their style dicts into a tagged multilingual prompt.

    Each style is a dict with a ``'font-family'`` key (looked up in
    ``font_dict``) and a ``'color'`` key holding a hex color string (resolved
    to a color name, then looked up in ``color_dict``).  The font tag is
    prefixed with the first two characters of the font-family name --
    presumably a language code such as ``cn`` or ``en``, going by the default
    font index file name; confirm against the asset file.
    """

    def __init__(
        self,
        font_path: str = 'assets/multilingual_cn-en_font_idx.json',
        color_path: str = 'assets/color_idx.json',
    ):
        """Load the font and color lookup tables.

        Args:
            font_path: JSON file mapping font-family names to the value
                placed inside ``<xx-font-...>`` tags.
            color_path: JSON file mapping color names to the value placed
                inside ``<color-...>`` tags.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding when reading the tables.
        with open(font_path, 'r', encoding='utf-8') as f:
            self.font_dict = json.load(f)
        with open(color_path, 'r', encoding='utf-8') as f:
            self.color_dict = json.load(f)

    @staticmethod
    def _resolve_color_name(style):
        """Return the color name that ``style['color']`` (a hex string) maps to."""
        return convert_rgb_to_names(webcolors.hex_to_rgb(style['color']))

    def format_checker(self, texts, styles):
        """Validate ``texts`` and ``styles`` before a prompt is built.

        Raises:
            AssertionError: if the two sequences differ in length, a
                font-family is not in ``font_dict``, or a color resolves to
                a name that is not in ``color_dict``.
        """
        # Raised explicitly rather than with ``assert`` so the checks still
        # run under ``python -O``; the exception type callers see is the same.
        if len(texts) != len(styles):
            raise AssertionError('length of texts must be equal to length of styles')
        for style in styles:
            font_family = style['font-family']
            if font_family not in self.font_dict:
                raise AssertionError(f"invalid font-family: {font_family}")
            color_name = self._resolve_color_name(style)
            if color_name not in self.color_dict:
                raise AssertionError(
                    f"invalid color hex {style['color']} "
                    f"(resolved to color name {color_name!r})"
                )

    def format_prompt(self, texts, styles):
        """Build the prompt for ``texts`` styled by ``styles``.

        Every text contributes one sentence of the form::

            Text "{text}" in <color-{color}>, <{xx}-font-{font}>.

        followed by a single space, where ``{xx}`` is the first two
        characters of the style's font-family name; the sentences are
        concatenated in order.

        Returns:
            The concatenated prompt, or ``""`` when ``texts`` is empty.

        Raises:
            AssertionError: if ``format_checker`` rejects the input.
        """
        self.format_checker(texts, styles)
        sentences = []
        for text, style in zip(texts, styles):
            font_family = style['font-family']
            color_tag = f"<color-{self.color_dict[self._resolve_color_name(style)]}>"
            font_tag = f"<{font_family[:2]}-font-{self.font_dict[font_family]}>"
            sentences.append(f'Text "{text}" in {color_tag}, {font_tag}. ')
        return "".join(sentences)