import base64
from enum import auto, IntEnum
from io import BytesIO

from pydantic import BaseModel


class ImageFormat(IntEnum):
    """Image formats."""

    URL = auto()
    LOCAL_FILEPATH = auto()
    PIL_IMAGE = auto()
    BYTES = auto()
    DEFAULT = auto()


class Image(BaseModel):
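    """An image referenced by a URL or local filepath, optionally carrying its
    content as a base64-encoded string."""
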
    url: str = ""
    filetype: str = ""
    image_format: ImageFormat = ImageFormat.BYTES
    base64_str: str = ""

    def convert_image_to_base64(self):
        """Given an image, return the base64 encoded image string."""
        from PIL import Image
        import requests

        # Load image if it has not been loaded in yet
        if self.image_format == ImageFormat.URL:
            response = requests.get(self.url)
            image = Image.open(BytesIO(response.content)).convert("RGBA")
            image_bytes = BytesIO()
            image.save(image_bytes, format="PNG")
        elif self.image_format == ImageFormat.LOCAL_FILEPATH:
            image = Image.open(self.url).convert("RGBA")
            image_bytes = BytesIO()
            image.save(image_bytes, format="PNG")
        elif self.image_format == ImageFormat.BYTES:
            # The image content is already stored as a base64 string.
            return self.base64_str
        else:
            raise ValueError(f"Unsupported image format: {self.image_format}")

        img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()

        return img_b64_str

    def to_openai_image_format(self):
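        """Return the image in the form expected by the OpenAI vision API:
        either a plain URL or a ``data:`` URI with base64-encoded content."""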
        if self.image_format == ImageFormat.URL:  # input is a url
            return self.url
        elif self.image_format == ImageFormat.LOCAL_FILEPATH:  # input is a local image
            self.base64_str = self.convert_image_to_base64()
            return f"data:image/{self.filetype};base64,{self.base64_str}"
        elif self.image_format == ImageFormat.BYTES:
            return f"data:image/{self.filetype};base64,{self.base64_str}"
        else:
            raise ValueError(
                f"This file is not valid or not currently supported by the OpenAI API: {self.url}"
            )

    def resize_image_and_return_image_in_bytes(self, image, max_image_size_mb):
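        """Downscale ``image`` so its longest edge is at most 1024 pixels and,
        if ``max_image_size_mb`` is set, shrink it once more so the encoded PNG
        roughly fits under that size. Returns the image format name ("png") and
        a BytesIO containing the PNG data."""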
        import math

        image_format = "png"
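        # Cap the longest edge at 1024 pixels while preserving the aspect ratio.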
        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 1024, 1024
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if longest_edge != max(image.size):
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))

        image_bytes = BytesIO()
        image.save(image_bytes, format="PNG")
        if max_image_size_mb:
            target_size_bytes = max_image_size_mb * 1024 * 1024

            current_size_bytes = image_bytes.tell()
            if current_size_bytes > target_size_bytes:
                resize_factor = (target_size_bytes / current_size_bytes) ** 0.5
                new_width = math.floor(image.width * resize_factor)
                new_height = math.floor(image.height * resize_factor)
                image = image.resize((new_width, new_height))

                image_bytes = BytesIO()
                image.save(image_bytes, format="PNG")
                current_size_bytes = image_bytes.tell()

            image_bytes.seek(0)

        return image_format, image_bytes

    def convert_url_to_image_bytes(self, max_image_size_mb):
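        """Open the image at ``url`` (a local path; SVGs are first rasterized
        to PNG via cairosvg), resize it, and return the image format name
        together with a base64-encoded PNG string."""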
        from PIL import Image

        if self.url.endswith(".svg"):
            import cairosvg

            with open(self.url, "rb") as svg_file:
                svg_data = svg_file.read()

            png_data = cairosvg.svg2png(bytestring=svg_data)
            pil_image = Image.open(BytesIO(png_data)).convert("RGBA")
        else:
            pil_image = Image.open(self.url).convert("RGBA")

        image_format, image_bytes = self.resize_image_and_return_image_in_bytes(
            pil_image, max_image_size_mb
        )

        img_base64_str = base64.b64encode(image_bytes.getvalue()).decode()

        return image_format, img_base64_str

    def to_conversation_format(self, max_image_size_mb):
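        """Load the image from ``url``, resize it to respect
        ``max_image_size_mb``, and store the result on this model as a
        base64-encoded PNG. Returns ``self``."""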
        image_format, base64_str = self.convert_url_to_image_bytes(
            max_image_size_mb=max_image_size_mb
        )

        self.filetype = image_format
        self.image_format = ImageFormat.BYTES
        self.base64_str = base64_str

        return self


if __name__ == "__main__":
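    # Example: encode a local image as a base64 PNG and print the serialized model.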
    image = Image(url="fastchat/serve/example_images/fridge.jpg")
    image.to_conversation_format(max_image_size_mb=5 / 1.5)

    json_str = image.model_dump_json()
    print(json_str)