Bingsu commited on
Commit
2a88d86
ยท
1 Parent(s): eb6e722

feat: clip model test

Browse files
Files changed (2) hide show
  1. app.py +22 -10
  2. pyproject.toml +12 -12
app.py CHANGED
@@ -5,25 +5,38 @@ import pandas as pd
5
  import streamlit as st
6
  import torch
7
  from sentence_transformers.util import semantic_search
8
- from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
9
 
10
  st.title("VitB32 Bert Ko Small Clip Test")
11
  st.markdown("Unsplash data์—์„œ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.")
12
 
13
 
14
  @st.cache(allow_output_mutation=True, show_spinner=False)
15
- def get_model():
16
  with st.spinner("Loading model..."):
17
- model = VisionTextDualEncoderModel.from_pretrained(
18
- "Bingsu/vitB32_bert_ko_small_clip"
19
- ).eval()
20
- processor = VisionTextDualEncoderProcessor.from_pretrained(
21
- "Bingsu/vitB32_bert_ko_small_clip"
22
- )
23
  return model, processor
24
 
25
 
26
- model, processor = get_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  info = pd.read_csv("info.csv")
29
  with open("img_id.pkl", "rb") as f:
@@ -35,7 +48,6 @@ tokens = processor(text=text, return_tensors="pt")
35
 
36
  with torch.no_grad():
37
  text_emb = model.get_text_features(**tokens)
38
- text_emb = text_emb / text_emb.norm(dim=1, keepdim=True)
39
 
40
  result = semantic_search(text_emb, img_emb, top_k=15)[0]
41
  _result = iter(result)
 
5
  import streamlit as st
6
  import torch
7
  from sentence_transformers.util import semantic_search
8
+ from transformers import AutoModel, AutoProcessor
9
 
10
  st.title("VitB32 Bert Ko Small Clip Test")
11
  st.markdown("Unsplash data์—์„œ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.")
12
 
13
 
14
  @st.cache(allow_output_mutation=True, show_spinner=False)
15
+ def get_dual_encoder_model():
16
  with st.spinner("Loading model..."):
17
+ model = AutoModel.from_pretrained("Bingsu/vitB32_bert_ko_small_clip").eval()
18
+ processor = AutoProcessor.from_pretrained("Bingsu/vitB32_bert_ko_small_clip")
 
 
 
 
19
  return model, processor
20
 
21
 
22
+ @st.cache(allow_output_mutation=True, show_spinner=False)
23
+ def get_clip_model():
24
+ with st.spinner("Loading model..."):
25
+ model = AutoModel.from_pretrained("Bingsu/clip-vit-base-patch32-ko").eval()
26
+ processor = AutoProcessor.from_pretrained("Bingsu/clip-vit-base-patch32-ko")
27
+ return model, processor
28
+
29
+
30
+ model_type = st.radio(
31
+ "Select model",
32
+ ["Bingsu/clip-vit-base-patch32-ko", "Bingsu/vitB32_bert_ko_small_clip"],
33
+ horizontal=True,
34
+ )
35
+
36
+ if model_type == "Bingsu/clip-vit-base-patch32-ko":
37
+ model, processor = get_clip_model()
38
+ else:
39
+ model, processor = get_dual_encoder_model()
40
 
41
  info = pd.read_csv("info.csv")
42
  with open("img_id.pkl", "rb") as f:
 
48
 
49
  with torch.no_grad():
50
  text_emb = model.get_text_features(**tokens)
 
51
 
52
  result = semantic_search(text_emb, img_emb, top_k=15)[0]
53
  _result = iter(result)
pyproject.toml CHANGED
@@ -1,25 +1,25 @@
1
  [tool.poetry]
2
  name = "vitb32_bert_ko_small_clip_test"
3
- version = "0.1.0"
4
  description = ""
5
  authors = ["Bingsu <[email protected]>"]
6
  license = "MIT"
7
 
8
  [tool.poetry.dependencies]
9
  python = "^3.9"
10
- torch = "^1.11.0"
11
- transformers = "^4.19.2"
12
- streamlit = "^1.10.0"
13
- pandas = "^1.4.2"
14
- sentence-transformers = "^2.2.0"
15
 
16
  [tool.poetry.dev-dependencies]
17
- black = "^22.3.0"
18
- isort = "^5.10.1"
19
- mypy = "^0.961"
20
- flake8 = "^4.0.1"
21
- flake8-bugbear = "^22.4.25"
22
- pre-commit = "^2.19.0"
23
 
24
  [build-system]
25
  requires = ["poetry-core>=1.0.0"]
 
1
  [tool.poetry]
2
  name = "vitb32_bert_ko_small_clip_test"
3
+ version = "0.2.0"
4
  description = ""
5
  authors = ["Bingsu <[email protected]>"]
6
  license = "MIT"
7
 
8
  [tool.poetry.dependencies]
9
  python = "^3.9"
10
+ torch = "^1.12"
11
+ transformers = "*"
12
+ streamlit = "*"
13
+ pandas = "*"
14
+ sentence-transformers = "*"
15
 
16
  [tool.poetry.dev-dependencies]
17
+ black = "*"
18
+ isort = "*"
19
+ mypy = "*"
20
+ flake8 = "*"
21
+ flake8-bugbear = "*"
22
+ pre-commit = "*"
23
 
24
  [build-system]
25
  requires = ["poetry-core>=1.0.0"]