jykoh committed on
Commit
f3e34b0
·
1 Parent(s): e466d0e

Update to not use classes

Browse files
Files changed (1) hide show
  1. app.py +94 -112
app.py CHANGED
@@ -12,115 +12,97 @@ import huggingface_hub
12
  import tempfile
13
 
14
 
15
class FromageChatBot:
    """Gradio chat demo around the model returned by ``models.load_fromage``.

    The running text prompt (``chat_history``) and the pending uploaded image
    (``input_image``) are stored on the instance.

    NOTE(review): because that state lives on the instance rather than in
    ``gr.State``, it looks like every browser session served by one instance
    shares the same prompt history and pending image -- confirm.
    """

    def __init__(self):
        """Download the checkpoint files and load the model."""
        # Download model from HF Hub.
        ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
        args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
        self.model = models.load_fromage('./', args_path, ckpt_path)
        self.chat_history = ''
        self.input_image = None

    def reset(self):
        """Clear the prompt history and pending image.

        Returns:
            A pair of empty lists: the new ``gr.State`` value and the new
            chatbot contents.
        """
        self.chat_history = ''
        self.input_image = None
        return [], []

    def upload_image(self, state, image_input):
        """Remember an uploaded image and show it in the chat.

        Args:
            state: List of ``(user, bot)`` message pairs; extended in place.
            image_input: Uploaded file object; only its ``.name`` path is read.

        Returns:
            ``(state, state)`` -- the same list for ``gr.State`` and the chatbot.
        """
        state += [(f"![](/file={image_input.name})", "(Image received. Type or ask something to continue.)")]
        self.input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
        return state, state

    def save_image_to_local(self, image: Image.Image):
        """Save ``image`` as a uniquely named PNG in the working directory.

        Returns:
            The relative filename that was written.
        """
        # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
        # uuid replaces the private tempfile._get_candidate_names() helper.
        import uuid

        filename = uuid.uuid4().hex + '.png'
        image.save(filename)
        return filename

    def generate_for_prompt(self, input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
        """Run the model on the user's prompt and append the reply to the chat.

        Args:
            input_text: Text typed by the user.
            state: List of ``(user, bot)`` message pairs; appended to in place.
            ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
            max_nm_rets: Forwarded to the model as ``max_num_rets``.
            num_words: Forwarded to the model as ``num_words``.
            temperature: Forwarded to the model; ``0.0`` selects ``top_p=1.0``,
                any other value selects ``top_p=0.95``.

        Returns:
            ``(state, state)`` -- the same list for ``gr.State`` and the chatbot.
        """
        self.chat_history += 'Q: ' + input_text + '\nA:'
        print('Generating for', self.chat_history, flush=True)

        # If an image was uploaded, prepend it to the model.
        if self.input_image is not None:
            model_inputs = [self.input_image, self.chat_history]
        else:
            model_inputs = [self.chat_history]

        top_p = 1.0 if temperature == 0.0 else 0.95

        # gr.Number(precision=1) delivers floats; these two values are counts.
        num_words = int(num_words)
        max_nm_rets = int(max_nm_rets)

        print('Running model.generate_for_images_and_texts', flush=True)
        model_outputs = self.model.generate_for_images_and_texts(
            model_inputs,
            num_words=num_words, ret_scale_factor=ret_scale_factor, top_p=top_p,
            temperature=temperature, max_num_rets=max_nm_rets)
        print('model_outputs', model_outputs, flush=True)

        response_parts = []
        text_outputs = []
        for output in model_outputs:
            # isinstance (not type ==) so Image.Image subclasses are kept too.
            if isinstance(output, str):
                text_outputs.append(output)
                response_parts.append(output)
            elif isinstance(output, list):
                for image in output:
                    filename = self.save_image_to_local(image)
                    response_parts.append(f'<img src="/file={filename}">')
            elif isinstance(output, Image.Image):
                filename = self.save_image_to_local(output)
                response_parts.append(f'<img src="/file={filename}">')
        response = ''.join(response_parts)

        # TODO(jykoh): Persist image inputs.
        self.chat_history += ' '.join(text_outputs)
        if not self.chat_history.endswith('\n'):
            self.chat_history += '\n'
        self.input_image = None

        state.append((input_text, response))
        return state, state

    def launch(self):
        """Build the Gradio UI, wire the callbacks and start the server."""
        with gr.Blocks() as demo:
            gr.Markdown(
                '### Grounding Language Models to Images for Multimodal Generation'
            )

            chatbot = gr.Chatbot()
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.3, min_width=0):
                    ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
                    max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
                    gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
                    gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

                with gr.Column(scale=0.7, min_width=0):
                    image_btn = gr.UploadButton("Image Input", file_types=["image"])
                    text_input = gr.Textbox(label="Text Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
                    clear_btn = gr.Button("Clear History")

            text_input.submit(self.generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
            image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
            clear_btn.click(self.reset, [], [gr_state, chatbot])

        demo.launch(share=False, debug=True, server_name="0.0.0.0")
117
-
118
-
119
def main():
    """Construct the chatbot (which loads the model) and launch the UI."""
    chatbot = FromageChatBot()
    chatbot.launch()


if __name__ == "__main__":
    # Delegate to main() instead of repeating its body.
    main()
 
12
  import tempfile
13
 
14
 
15
# Fetch the checkpoint and its argument file from the HF Hub, then build the
# model once at import time so every callback shares it.
ckpt_path = huggingface_hub.hf_hub_download(
    repo_id='jykoh/fromage',
    filename='pretrained_ckpt.pth.tar',
)
args_path = huggingface_hub.hf_hub_download(
    repo_id='jykoh/fromage',
    filename='model_args.json',
)
model = models.load_fromage('./', args_path, ckpt_path)
19
+
20
+
21
def upload_image(state, image_input):
    """Record an uploaded image in the session state and show it in the chat.

    Args:
        state: Per-session list whose first element is the conversation (a
            list of ``(user, bot)`` pairs) and whose second element is the
            pending input image. Any further elements are passed through
            untouched.
        image_input: Uploaded file object; only its ``.name`` path is read.

    Returns:
        ``(new_state, conversation)`` -- the updated state for ``gr.State``
        and the conversation list for the chatbot.
    """
    # Extend the conversation itself, not the outer state list, and return a
    # flat state so the chatbot only ever receives message pairs.
    conversation = list(state[0])
    conversation.append(
        (f"![](/file={image_input.name})", "(Image received. Type or ask something to continue.)")
    )
    input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
    return [conversation, input_image, *state[2:]], conversation
25
+
26
+
27
def save_image_to_local(image: Image.Image):
    """Save ``image`` as a uniquely named PNG in the working directory.

    Args:
        image: Object exposing ``save(filename)``.

    Returns:
        The relative filename that was written.
    """
    # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
    # uuid replaces the private tempfile._get_candidate_names() helper.
    import uuid

    filename = uuid.uuid4().hex + '.png'
    image.save(filename)
    return filename
32
+
33
+
34
def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
    """Run the model on the user's prompt and append the reply to the chat.

    Args:
        input_text: Text typed by the user.
        state: Per-session list. ``state[0]`` is the conversation (a list of
            ``(user, bot)`` pairs), ``state[1]`` is the pending input image
            or ``None``, and the optional ``state[2]`` is the running text
            prompt from earlier turns (treated as ``''`` when absent).
        ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
        max_nm_rets: Forwarded to the model as ``max_num_rets``.
        num_words: Forwarded to the model as ``num_words``.
        temperature: Forwarded to the model; ``0.0`` selects ``top_p=1.0``,
            any other value selects ``top_p=0.95``.

    Returns:
        ``(new_state, conversation)`` -- ``new_state`` is
        ``[conversation, None, chat_history]`` (the pending image is cleared)
        and ``conversation`` is the list shown by the chatbot.
    """
    conversation = list(state[0])
    input_image = state[1]
    # The prompt text must be read from the state before it is extended;
    # augmenting an unassigned local raises UnboundLocalError.
    chat_history = state[2] if len(state) > 2 else ''
    chat_history += 'Q: ' + input_text + '\nA:'
    print('Generating for', chat_history, flush=True)

    # If an image was uploaded, prepend it to the model.
    if input_image is not None:
        model_inputs = [input_image, chat_history]
    else:
        model_inputs = [chat_history]

    top_p = 1.0 if temperature == 0.0 else 0.95

    # gr.Number(precision=1) delivers floats; these two values are counts.
    num_words = int(num_words)
    max_nm_rets = int(max_nm_rets)

    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(
        model_inputs,
        num_words=num_words, ret_scale_factor=ret_scale_factor, top_p=top_p,
        temperature=temperature, max_num_rets=max_nm_rets)
    print('model_outputs', model_outputs, flush=True)

    response_parts = []
    text_outputs = []
    for output in model_outputs:
        # isinstance (not type ==) so Image.Image subclasses are kept too.
        if isinstance(output, str):
            text_outputs.append(output)
            response_parts.append(output)
        elif isinstance(output, list):
            for image in output:
                filename = save_image_to_local(image)
                response_parts.append(f'<img src="/file={filename}">')
        elif isinstance(output, Image.Image):
            filename = save_image_to_local(output)
            response_parts.append(f'<img src="/file={filename}">')
    response = ''.join(response_parts)

    # TODO(jykoh): Persist image inputs.
    chat_history += ' '.join(text_outputs)
    if not chat_history.endswith('\n'):
        chat_history += '\n'

    conversation.append((input_text, response))

    # Set input image to None.
    new_state = [conversation, None, chat_history]
    print('state', new_state, flush=True)
    return new_state, conversation
82
+
83
+
84
def reset():
    """Return a fresh session state and an empty chatbot.

    The clear button is wired to this callback; without it the module fails
    with NameError as soon as the UI is built.
    """
    return [[], None, ''], []


# Build the UI. Components and event listeners are created inside the Blocks
# context; the server is started once the layout is complete.
with gr.Blocks() as demo:
    gr.Markdown(
        '### Grounding Language Models to Images for Multimodal Generation'
    )

    chatbot = gr.Chatbot()
    # Per-session state: conversation pairs, pending input image, text prompt.
    gr_state = gr.State([[], None, ''])

    with gr.Row():
        with gr.Column(scale=0.3, min_width=0):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
            max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
            gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
            gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

        with gr.Column(scale=0.7, min_width=0):
            image_btn = gr.UploadButton("Image Input", file_types=["image"])
            text_input = gr.Textbox(label="Text Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
            clear_btn = gr.Button("Clear History")

    text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])

demo.launch(share=False, debug=True, server_name="0.0.0.0")