Files changed (1)
  1. app.py +20 -26
app.py CHANGED
@@ -10,7 +10,7 @@ from mistral_inference.transformer import Transformer
 from mistral_inference.generate import generate
 
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
-from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
+from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 
 models_path = Path.home().joinpath('pixtral', 'Pixtral')
@@ -28,10 +28,23 @@ def image_to_base64(image_path):
         encoded_string = base64.b64encode(img.read()).decode('utf-8')
     return f"data:image/jpeg;base64,{encoded_string}"
 
-@spaces.GPU(duration=30)
-def run_inference(image_url, prompt):
-    base64 = image_to_base64(image_url)
-    completion_request = ChatCompletionRequest(messages=[UserMessage(content=[ImageURLChunk(image_url=base64), TextChunk(text=prompt)])])
+@spaces.GPU(duration=60)
+def run_inference(message, history):
+    # Rebuild earlier turns: image-only turns are buffered until the next text turn.
+    messages = []
+    images = []
+    for couple in history:
+        if isinstance(couple[0], tuple):
+            images += couple[0]
+        elif couple[0]:
+            messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(path)) for path in images] + [TextChunk(text=couple[0])]))
+            messages.append(AssistantMessage(content=couple[1]))
+            images = []
+
+    # Current turn: attach any newly uploaded files plus the text prompt.
+    messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(file["path"])) for file in message["files"]] + [TextChunk(text=message["text"])]))
+
+    completion_request = ChatCompletionRequest(messages=messages)
 
     encoded = tokenizer.encode_chat_completion(completion_request)
 
@@ -40,26 +53,7 @@ def run_inference(image_url, prompt):
 
     out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512, temperature=0.45, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
     result = tokenizer.decode(out_tokens[0])
-    return [[prompt, result]]
-
-with gr.Blocks() as demo:
-    with gr.Row():
-        image_box = gr.Image(type="filepath")
-
-        chatbot = gr.Chatbot(
-            scale=2,
-            height=750
-        )
-    text_box = gr.Textbox(
-        placeholder="Enter your text and press enter, or upload an image.",
-        container=False,
-    )
-
-
-    btn = gr.Button("Submit")
-    clicked = btn.click(run_inference,
-                        [image_box, text_box],
-                        chatbot
-                        )
+    return result
 
+demo = gr.ChatInterface(fn=run_inference, title="Pixtral 12B", multimodal=True)
 demo.queue().launch()
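
For context, a minimal sketch of the inputs the new run_inference expects from gr.ChatInterface(multimodal=True) with tuple-style chat history. The exact shapes depend on the Gradio version, and every path and string below is illustrative only, not taken from this Space:

# Illustrative only: hypothetical values showing the assumed shapes.
message = {
    "text": "What is in this picture?",
    "files": [{"path": "/tmp/gradio/cat.jpg"}],    # each upload is a dict carrying a local "path"
}
history = [
    [("/tmp/gradio/dog.jpg",), None],              # image-only turn: the user part is a tuple of file paths
    ["Describe the image.", "A dog on a beach."],  # text turn paired with the model's earlier reply
]
# run_inference(message, history) folds these into UserMessage/AssistantMessage
# chunks and encodes them with tokenizer.encode_chat_completion(...).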