ydshieh committed
Commit 8f85ccf
1 Parent(s): d1befcb

fix closed image issue
Files changed (2)
  1. app.py +52 -48
  2. model.py +2 -1
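
For context, the app.py hunk below sits inside a Streamlit sidebar form whose submitted file is copied into an in-memory buffer. A minimal runnable sketch of that upload pattern follows; it mirrors the form/guard structure visible in the diff but simplifies the else branch (the real app also falls back to a sample COCO image), and the widget label is illustrative:

import io

import streamlit as st
from PIL import Image

bytes_data = None

# File uploader inside a sidebar form; clear_on_submit resets the widget after upload.
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose an image")
    submitted = st.form_submit_button("Upload")
    if submitted and uploaded_file is not None:
        # Copy the upload into an in-memory buffer before the widget is cleared.
        bytes_data = io.BytesIO(uploaded_file.getvalue())

if (bytes_data is None) and submitted:
    st.write("No file is selected to upload")
elif bytes_data is not None:
    image = Image.open(bytes_data)
    st.image(image, "Selected Image")
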
app.py CHANGED
@@ -39,55 +39,59 @@ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
      submitted = st.form_submit_button("Upload")
      if submitted and uploaded_file is not None:
          bytes_data = io.BytesIO(uploaded_file.getvalue())
- uploaded_file = None
- submitted = None
-
- image_id = random_image_id
- if sample_image_id != "None":
-     assert type(sample_image_id) == int
-     image_id = sample_image_id
-
- sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
- sample_path = os.path.join(sample_dir, sample_name)
-
- if bytes_data is not None:
-     image = Image.open(bytes_data)
-     bytes_data = None
- elif os.path.isfile(sample_path):
-     image = Image.open(sample_path)
- else:
-     url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
-     image = Image.open(requests.get(url, stream=True).raw)
-
- width, height = image.size
- resized = image
- if height > 384:
-     width = int(width / height * 384)
-     height = 384
-     resized = resized.resize(size=(width, height))
- if width > 512:
-     width = 512
-     height = int(height / width * 512)
-     resized = resized.resize(size=(width, height))
-
-
- st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)")
- show = st.image(resized)
- show.image(resized, '\n\nSelected Image')
- resized.close()
+
+ if (bytes_data is None) and submitted:
+
+     st.write("No file is selected to upload")
+
+ else:
+
+     image_id = random_image_id
+     if sample_image_id != "None":
+         assert type(sample_image_id) == int
+         image_id = sample_image_id
+
+     sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
+     sample_path = os.path.join(sample_dir, sample_name)
+
+     if bytes_data is not None:
+         image = Image.open(bytes_data)
+     elif os.path.isfile(sample_path):
+         image = Image.open(sample_path)
+     else:
+         url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
+         image = Image.open(requests.get(url, stream=True).raw)
+
+     width, height = image.size
+     resized = image.resize(size=(width, height))
+     if height > 384:
+         width = int(width / height * 384)
+         height = 384
+         resized = resized.resize(size=(width, height))
+     width, height = resized.size
+     if width > 512:
+         width = 512
+         height = int(height / width * 512)
+         resized = resized.resize(size=(width, height))
+
+     if bytes_data is None:
+         st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)")
+     show = st.image(resized)
+     show.image(resized, '\n\nSelected Image')
+     resized.close()

- # For newline
- st.sidebar.write('\n')
+     # For newline
+     st.sidebar.write('\n')

- with st.spinner('Generating image caption ...'):
+     with st.spinner('Generating image caption ...'):

-     caption = predict(image)
-
- caption_en = caption
- st.header(f'Predicted caption:\n\n')
- st.subheader(caption_en)
-
- st.sidebar.header("ViT-GPT2 predicts:")
- st.sidebar.write(f"**English**: {caption}")
+         caption = predict(image)
+
+     caption_en = caption
+     st.header(f'Predicted caption:\n\n')
+     st.subheader(caption_en)
+
+     st.sidebar.header("ViT-GPT2 predicts: ")
+     st.sidebar.write(f"{caption}")

- image.close()
+     image.close()
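
The "closed image issue" in the commit title appears to come from the aliasing in the removed block above: `resized = image` made `resized` and `image` the same PIL object whenever no downscaling was triggered, so `resized.close()` also closed `image` before `predict(image)` ran. The added code instead calls `image.resize(size=(width, height))` up front, which always returns a new Image object, so only the display copy gets closed. A minimal sketch of the two behaviours, assuming standard Pillow semantics (the sample path is hypothetical):

from PIL import Image

# Old flow: `resized` is an alias, so closing it also closes `image`.
image = Image.open("samples/example.jpg")   # hypothetical file
resized = image                             # same object, not a copy
resized.close()
# image.resize((224, 224))                  # would raise ValueError: the image is closed

# New flow: resize() returns a fresh Image even at the original size,
# so the original stays open for predict().
image = Image.open("samples/example.jpg")
resized = image.resize(size=image.size)     # independent copy
resized.close()
image.resize((224, 224))                    # still works; only the copy was closed
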
model.py CHANGED
@@ -47,6 +47,7 @@ def generate(pixel_values):
  def predict(image):

      pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
+
      output_ids = generate(pixel_values)
      preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
      preds = [pred.strip() for pred in preds]
@@ -58,7 +59,7 @@ def _compile():

      image_path = 'samples/val_000000039769.jpg'
      image = Image.open(image_path)
-     caption = predict(image)
+     predict(image)
      image.close()

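
The model.py change adds a blank line in `predict()` and drops the unused binding in `_compile()`, which now discards the value returned by `predict()`. Presumably `_compile()` runs one prediction on a bundled sample at startup so that one-time costs (model loading, JIT compilation of `generate`) are paid before the first user request; the caption itself is never needed there. A rough, self-contained sketch of that warm-up pattern, with `predict` stubbed out and the bundled sample replaced by a generated image:

from PIL import Image

def predict(image):
    # Stand-in for model.predict(); the real function runs the ViT-GPT2 pipeline.
    image.load()                           # touch the pixel data, like a feature extractor would
    return "a stubbed caption"

def _compile():
    # Warm-up call: the return value is intentionally discarded.
    image = Image.new("RGB", (640, 480))   # the real code opens samples/val_000000039769.jpg
    predict(image)
    image.close()

_compile()
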