zamal committed on
Commit 00b84bb · verified · 1 Parent(s): 5e5b3a5

Update app.py

Files changed (1)
  1. app.py +64 -56
app.py CHANGED
@@ -1,72 +1,80 @@
 import gradio as gr
-from transformers import (
-    AutoModelForCausalLM,
-    AutoProcessor,
-    GenerationConfig,
-    BitsAndBytesConfig,
-)
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 from PIL import Image
 import torch
+import spaces
 
-# Configuration for 4-bit quantization and GPU offloading
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
+# Load the processor and model
+processor = AutoProcessor.from_pretrained(
+    'allenai/Molmo-7B-D-0924',
+    trust_remote_code=True,
+    torch_dtype='auto',
+    device_map='auto'
 )
 
-# Model repository
-repo_name = "cyan2k/molmo-7B-O-bnb-4bit"
-
-# Load the processor and model
-processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
-    repo_name,
-    torch_dtype=torch.float16,
-    device_map="auto",
+    'allenai/Molmo-7B-D-0924',
     trust_remote_code=True,
-    quantization_config=bnb_config,
+    torch_dtype='auto',
+    device_map='auto'
 )
 
-# Ensure model is on GPU
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
 
-def describe_images(images):
-    descriptions = []
-    for image in images:
-        if isinstance(image, str):
-            image = Image.open(image)
-        # Process the image
-        inputs = processor.process(
-            images=[image],
-            text="Describe this image in great detail.",
-        )
-        # Move inputs to the same device as the model
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-        # Generate output
-        with torch.no_grad():
-            output = model.generate_from_batch(
-                inputs,
-                GenerationConfig(max_new_tokens=200, stop_strings=["<|endoftext|>"]),
-                tokenizer=processor.tokenizer,
-            )
-        # Decode generated tokens to text
-        generated_tokens = output[0, inputs["input_ids"].size(1):]
-        generated_text = processor.tokenizer.decode(
-            generated_tokens, skip_special_tokens=True
-        )
-        descriptions.append(generated_text.strip())
-    return "\n\n".join(descriptions)
+@spaces.GPU(duration=120)
+def process_image_and_text(image, text):
+    # Process the image and text
+    inputs = processor.process(
+        images=[Image.fromarray(image)],
+        text=text
+    )
+
+    # Move inputs to the correct device and make a batch of size 1
+    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
+
+    # Generate output
+    output = model.generate_from_batch(
+        inputs,
+        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+        tokenizer=processor.tokenizer
+    )
 
-# Gradio interface
+    # Only get generated tokens; decode them to text
+    generated_tokens = output[0, inputs['input_ids'].size(1):]
+    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+    return generated_text
+
+def chatbot(image, text, history):
+    if image is None:
+        return history + [("Please upload an image first.", None)]
+
+    response = process_image_and_text(image, text)
+    history.append((text, response))
+    return history
+
+# Define the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("<h3><center>Image Description Generator</center></h3>")
+    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
+
     with gr.Row():
-        image_input = gr.File(
-            file_types=["image"], label="Upload Image(s)", multiple=True
-        )
-        generate_button = gr.Button("Generate Descriptions")
-    output_text = gr.Textbox(label="Descriptions", lines=15)
+        image_input = gr.Image(type="numpy")
+        chatbot_output = gr.Chatbot()
+
+    text_input = gr.Textbox(placeholder="Ask a question about the image...")
+    submit_button = gr.Button("Submit")
+
+    state = gr.State([])
+
+    submit_button.click(
+        chatbot,
+        inputs=[image_input, text_input, state],
+        outputs=[chatbot_output]
+    )
 
-    generate_button.click(describe_images, inputs=image_input, outputs=output_text)
+    text_input.submit(
+        chatbot,
+        inputs=[image_input, text_input, state],
+        outputs=[chatbot_output]
+    )
 
-demo.launch()
+demo.launch()
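
Note on the new event wiring: `state` is read as an input but never written back as an output, so chat persistence relies on in-place mutation of the state list, and the textbox is not cleared after a turn. Below is a minimal sketch of an alternative wiring that writes the state back explicitly; the stubbed `respond` function (standing in for the Molmo call) and the exact component layout are illustrative assumptions, not part of this commit.

import gradio as gr

def respond(image, text, history):
    # Stub for process_image_and_text(image, text); assumption for illustration.
    answer = f"(model answer about the image to: {text})"
    history = history + [(text, answer)]  # build a new list, no in-place mutation
    # Return history twice (once for display, once for session state) and an
    # empty string to clear the input box.
    return history, history, ""

with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()
    text_input = gr.Textbox(placeholder="Ask a question about the image...")
    submit_button = gr.Button("Submit")
    state = gr.State([])

    # Listing state among the outputs makes persistence explicit instead of
    # relying on mutation of the list object held by gr.State.
    submit_button.click(
        respond,
        inputs=[image_input, text_input, state],
        outputs=[chatbot_output, state, text_input],
    )

demo.launch()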