ki1207 commited on
Commit
9a4e757
·
verified ·
1 Parent(s): 9e66ec3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -28
app.py CHANGED
@@ -25,14 +25,6 @@ model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID).to(device
25
  def answer_ad_listing_question(
26
  image: PIL.Image.Image,
27
  title: str,
28
- decoding_method: str = "Nucleus sampling",
29
- temperature: float = 1.0,
30
- length_penalty: float = 1.0,
31
- repetition_penalty: float = 1.5,
32
- max_length: int = 50,
33
- min_length: int = 1,
34
- num_beams: int = 5,
35
- top_p: float = 0.9,
36
  ) -> str:
37
  # The prompt template with the provided title
38
  prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
@@ -46,14 +38,14 @@ def answer_ad_listing_question(
46
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
47
  generated_ids = model.generate(
48
  **inputs,
49
- do_sample=decoding_method == "Nucleus sampling",
50
- temperature=temperature,
51
- length_penalty=length_penalty,
52
- repetition_penalty=repetition_penalty,
53
- max_length=max_length,
54
- min_length=min_length,
55
- num_beams=num_beams,
56
- top_p=top_p,
57
  )
58
  result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
59
  return result
@@ -82,18 +74,7 @@ with gr.Blocks() as demo:
82
  # Logic to handle clicking on "Analyze Ad Listing"
83
  submit_button.click(
84
  fn=answer_ad_listing_question,
85
- inputs=[
86
- image,
87
- ad_title, # The title from the ad
88
- "Nucleus sampling", # Default values for decoding method, temperature, etc.
89
- 1.0, # temperature
90
- 1.0, # length_penalty
91
- 1.5, # repetition_penalty
92
- 50, # max_length
93
- 1, # min_length
94
- 5, # num_beams
95
- 0.9, # top_p
96
- ],
97
  outputs=answer_output,
98
  )
99
 
 
25
  def answer_ad_listing_question(
26
  image: PIL.Image.Image,
27
  title: str,
 
 
 
 
 
 
 
 
28
  ) -> str:
29
  # The prompt template with the provided title
30
  prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
 
38
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
39
  generated_ids = model.generate(
40
  **inputs,
41
+ do_sample=True, # Nucleus sampling is applied
42
+ temperature=1.0, # Default temperature
43
+ length_penalty=1.0, # Default length penalty
44
+ repetition_penalty=1.5, # Default repetition penalty
45
+ max_length=50, # Default max length
46
+ min_length=1, # Default min length
47
+ num_beams=5, # Default number of beams
48
+ top_p=0.9, # Default top_p for nucleus sampling
49
  )
50
  result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
51
  return result
 
74
  # Logic to handle clicking on "Analyze Ad Listing"
75
  submit_button.click(
76
  fn=answer_ad_listing_question,
77
+ inputs=[image, ad_title], # Only the image and ad title are inputs
 
 
 
 
 
 
 
 
 
 
 
78
  outputs=answer_output,
79
  )
80