SkalskiP commited on
Commit
b7463e4
1 Parent(s): d21820e

Added multi-turn chat.

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  import time
3
- from typing import List, Tuple, Optional
4
 
5
  import google.generativeai as genai
6
  import gradio as gr
7
  from PIL import Image
8
 
 
9
  print("google-generativeai:", genai.__version__)
10
 
11
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
@@ -42,6 +43,18 @@ def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
42
  return image.resize((IMAGE_WIDTH, image_height))
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
46
  return "", chatbot + [[text_prompt, None]]
47
 
@@ -74,7 +87,7 @@ def bot(
74
  if image_prompt is None:
75
  model = genai.GenerativeModel('gemini-pro')
76
  response = model.generate_content(
77
- text_prompt,
78
  stream=True,
79
  generation_config=generation_config)
80
  response.resolve()
@@ -106,12 +119,13 @@ google_key_component = gr.Textbox(
106
  visible=GOOGLE_API_KEY is None
107
  )
108
 
109
- image_prompt_component = gr.Image(type="pil", label="Image", scale=1)
110
  chatbot_component = gr.Chatbot(
111
  label='Gemini',
112
  bubble_full_width=False,
113
  avatar_images=AVATAR_IMAGES,
114
- scale=2
 
115
  )
116
  text_prompt_component = gr.Textbox(
117
  placeholder="Hi there!",
 
1
  import os
2
  import time
3
+ from typing import List, Tuple, Optional, Dict
4
 
5
  import google.generativeai as genai
6
  import gradio as gr
7
  from PIL import Image
8
 
9
+
10
  print("google-generativeai:", genai.__version__)
11
 
12
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
 
43
  return image.resize((IMAGE_WIDTH, image_height))
44
 
45
 
46
+ def preprocess_chat_history(
47
+ history: List[Tuple[Optional[str], Optional[str]]]
48
+ ) -> List[Dict[str, List[str]]]:
49
+ messages = []
50
+ for user_message, model_message in history:
51
+ if user_message is not None:
52
+ messages.append({'role': 'user', 'parts': [user_message]})
53
+ if model_message is not None:
54
+ messages.append({'role': 'model', 'parts': [model_message]})
55
+ return messages
56
+
57
+
58
  def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
59
  return "", chatbot + [[text_prompt, None]]
60
 
 
87
  if image_prompt is None:
88
  model = genai.GenerativeModel('gemini-pro')
89
  response = model.generate_content(
90
+ preprocess_chat_history(chatbot),
91
  stream=True,
92
  generation_config=generation_config)
93
  response.resolve()
 
119
  visible=GOOGLE_API_KEY is None
120
  )
121
 
122
+ image_prompt_component = gr.Image(type="pil", label="Image", scale=1, height=400)
123
  chatbot_component = gr.Chatbot(
124
  label='Gemini',
125
  bubble_full_width=False,
126
  avatar_images=AVATAR_IMAGES,
127
+ scale=2,
128
+ height=400
129
  )
130
  text_prompt_component = gr.Textbox(
131
  placeholder="Hi there!",