feat: normalize cosine similarity
Browse files- lib/utils/model.py +4 -0
lib/utils/model.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import yaml
|
3 |
import torch
|
|
|
4 |
|
5 |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
|
6 |
from lib.IRRA.image import prepare_images
|
@@ -27,4 +28,7 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
|
27 |
image_feats = model.encode_image(imgs)
|
28 |
text_feats = model.encode_text(txt.unsqueeze(0))
|
29 |
|
|
|
|
|
|
|
30 |
return text_feats @ image_feats.t()
|
|
|
1 |
import streamlit as st
|
2 |
import yaml
|
3 |
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
|
6 |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
|
7 |
from lib.IRRA.image import prepare_images
|
|
|
28 |
image_feats = model.encode_image(imgs)
|
29 |
text_feats = model.encode_text(txt.unsqueeze(0))
|
30 |
|
31 |
+
image_feats = F.normalize(image_feats, p=2, dim=1)
|
32 |
+
text_feats = F.normalize(text_feats, p=2, dim=1)
|
33 |
+
|
34 |
return text_feats @ image_feats.t()
|