File size: 3,993 Bytes
d3653d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
import webcolors


def _css3_hex_to_names():
    """Return a ``{hex: name}`` mapping covering every CSS3 named colour.

    Older webcolors releases expose this directly as
    ``webcolors.CSS3_HEX_TO_NAMES``; recent releases removed that constant
    from the public API and provide ``webcolors.names()`` instead. The legacy
    constant is preferred when present so results on older installs are
    unchanged.
    """
    legacy = getattr(webcolors, "CSS3_HEX_TO_NAMES", None)
    if legacy is not None:
        return legacy
    # Several CSS3 names share one hex value (e.g. "gray"/"grey"). Deduplicate
    # hex values in a stable order (dict, not set, so iteration order is
    # deterministic), then resolve each through hex_to_name so the spelling
    # matches what webcolors.rgb_to_name would return for an exact match.
    unique_hexes = dict.fromkeys(
        webcolors.name_to_hex(name) for name in webcolors.names("css3")
    )
    return {hex_value: webcolors.hex_to_name(hex_value) for hex_value in unique_hexes}


def closest_color(requested_color):
    """Return the CSS3 colour name nearest to ``requested_color``.

    Nearness is squared Euclidean distance in RGB space.

    Args:
        requested_color: indexable ``(r, g, b)`` triple of integers.

    Returns:
        The CSS3 name of the closest colour. When several colours are equally
        close, the last one in the mapping's iteration order is returned.
    """
    red, green, blue = requested_color[0], requested_color[1], requested_color[2]
    best_name = None
    best_distance = None
    for hex_value, name in _css3_hex_to_names().items():
        r_c, g_c, b_c = webcolors.hex_to_rgb(hex_value)
        distance = (r_c - red) ** 2 + (g_c - green) ** 2 + (b_c - blue) ** 2
        # "<=" rather than "<" keeps the last of equally close colours, which
        # preserves the tie-breaking of the earlier dict-keyed-by-distance code.
        if best_distance is None or distance <= best_distance:
            best_name, best_distance = name, distance
    return best_name

def convert_rgb_to_names(rgb_tuple):
    """Name an ``(r, g, b)`` colour, falling back to the nearest CSS3 name.

    An exact match from ``webcolors.rgb_to_name`` is returned as-is. When
    webcolors has no exact name for the triple (it raises ``ValueError``),
    the nearest CSS3 colour chosen by ``closest_color`` is returned instead.
    """
    try:
        exact_name = webcolors.rgb_to_name(rgb_tuple)
    except ValueError:
        return closest_color(rgb_tuple)
    return exact_name

class PromptFormat:
    """Build text-rendering prompts annotated with colour and font tokens.

    Each ``(text, style)`` pair becomes the segment
    ``Text "<text>" in <color-C>, <font-F>. `` where ``C`` and ``F`` are the
    values looked up in the colour and font JSON tables loaded at
    construction time.
    """

    def __init__(
        self,
        font_path: str = 'assets/font_idx_512.json',
        color_path: str = 'assets/color_idx.json',
    ):
        """Load the font and colour lookup tables.

        Args:
            font_path: JSON file mapping font-family names to the value used
                in ``<font-...>`` tokens.
            color_path: JSON file mapping CSS3 colour names to the value used
                in ``<color-...>`` tokens.
        """
        # JSON is UTF-8 by specification; be explicit so loading does not
        # depend on the platform's locale encoding (e.g. cp1252 on Windows).
        with open(font_path, 'r', encoding='utf-8') as f:
            self.font_dict = json.load(f)
        with open(color_path, 'r', encoding='utf-8') as f:
            self.color_dict = json.load(f)

    def format_checker(self, texts, styles):
        """Validate that ``texts`` and ``styles`` can be turned into a prompt.

        Args:
            texts: sequence of strings to render.
            styles: sequence of dicts, one per text, each with a
                ``'font-family'`` key and a hex ``'color'`` key.

        Raises:
            AssertionError: if the lengths differ, a font family is not in the
                font table, or a colour's CSS3 name (exact or nearest) is not
                in the colour table.
        """
        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; otherwise zip() in format_prompt would silently drop
        # unmatched texts. AssertionError is kept for caller compatibility.
        if len(texts) != len(styles):
            raise AssertionError('length of texts must be equal to length of styles')
        for style in styles:
            font_family = style['font-family']
            if font_family not in self.font_dict:
                raise AssertionError(f"invalid font-family: {font_family}")
            hex_color = style['color']
            color_name = convert_rgb_to_names(webcolors.hex_to_rgb(hex_color))
            if color_name not in self.color_dict:
                raise AssertionError(
                    f"invalid color hex {hex_color} (CSS3 name: {color_name})"
                )

    def format_prompt(self, texts, styles):
        """Return the prompt string for ``texts`` rendered with ``styles``.

        Segments are concatenated with no separator, each of the form
        ``Text "{text}" in <color-{color}>, <font-{font}>. `` (every segment,
        including the last, ends with ". ").

        Raises:
            AssertionError: propagated from :meth:`format_checker`.
        """
        self.format_checker(texts, styles)

        segments = []
        for text, style in zip(texts, styles):
            color_name = convert_rgb_to_names(webcolors.hex_to_rgb(style["color"]))
            color_token = f"<color-{self.color_dict[color_name]}>"
            font_token = f"<font-{self.font_dict[style['font-family']]}>"
            segments.append(f'Text "{text}" in {color_token}, {font_token}. ')
        return "".join(segments)


class MultilingualPromptFormat:
    """Build multilingual text-rendering prompts with colour and font tokens.

    Each ``(text, style)`` pair becomes the segment
    ``Text "<text>" in <color-C>, <XX-font-F>. `` where ``XX`` is the first two
    characters of the style's font-family name and ``C`` / ``F`` are the
    values looked up in the colour and font JSON tables.
    """

    def __init__(
        self,
        font_path: str = 'assets/multilingual_cn-en_font_idx.json',
        color_path: str = 'assets/color_idx.json',
    ):
        """Load the font and colour lookup tables.

        Args:
            font_path: JSON file mapping font-family names to the value used
                in ``<XX-font-...>`` tokens.
            color_path: JSON file mapping CSS3 colour names to the value used
                in ``<color-...>`` tokens.
        """
        # JSON is UTF-8 by specification; be explicit so non-ASCII font names
        # load regardless of the platform's locale encoding.
        with open(font_path, 'r', encoding='utf-8') as f:
            self.font_dict = json.load(f)
        with open(color_path, 'r', encoding='utf-8') as f:
            self.color_dict = json.load(f)

    def format_checker(self, texts, styles):
        """Validate that ``texts`` and ``styles`` can be turned into a prompt.

        Args:
            texts: sequence of strings to render.
            styles: sequence of dicts, one per text, each with a
                ``'font-family'`` key and a hex ``'color'`` key.

        Raises:
            AssertionError: if the lengths differ, a font family is not in the
                font table, or a colour's CSS3 name (exact or nearest) is not
                in the colour table.
        """
        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; otherwise zip() in format_prompt would silently drop
        # unmatched texts. AssertionError is kept for caller compatibility.
        if len(texts) != len(styles):
            raise AssertionError('length of texts must be equal to length of styles')
        for style in styles:
            font_family = style['font-family']
            if font_family not in self.font_dict:
                raise AssertionError(f"invalid font-family: {font_family}")
            hex_color = style['color']
            color_name = convert_rgb_to_names(webcolors.hex_to_rgb(hex_color))
            if color_name not in self.color_dict:
                raise AssertionError(
                    f"invalid color hex {hex_color} (CSS3 name: {color_name})"
                )

    def format_prompt(self, texts, styles):
        """Return the prompt string for ``texts`` rendered with ``styles``.

        Segments are concatenated with no separator, each of the form
        ``Text "{text}" in <color-{color}>, <{prefix}-font-{font}>. `` where
        ``prefix`` is ``font-family[:2]`` (every segment, including the last,
        ends with ". ").

        Raises:
            AssertionError: propagated from :meth:`format_checker`.
        """
        self.format_checker(texts, styles)

        segments = []
        for text, style in zip(texts, styles):
            font_family = style['font-family']
            color_name = convert_rgb_to_names(webcolors.hex_to_rgb(style["color"]))
            color_token = f"<color-{self.color_dict[color_name]}>"
            # The first two characters of the font-family key become the token
            # prefix. Presumably keys look like "en-..."/"cn-..." (a language
            # code) — TODO confirm against the font index JSON.
            font_token = f"<{font_family[:2]}-font-{self.font_dict[font_family]}>"
            segments.append(f'Text "{text}" in {color_token}, {font_token}. ')
        return "".join(segments)