Upload 3 files
- description_validator.py +65 -0
- image_validator.py +64 -0
- product_update_validator.py +21 -0
description_validator.py
ADDED
@@ -0,0 +1,65 @@
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

from warnings import filterwarnings
filterwarnings("ignore")

models = ["MPNet-base-v2", "DistilRoBERTa-v1", "MiniLM-L12-v2", "MiniLM-L6-v2"]
models_info = {
    "MPNet-base-v2": {
        "model_size": "420MB",
        "model_url": "sentence-transformers/all-mpnet-base-v2",
        "efficiency": "Moderate",
        "chunk_size": 512
    },
    "DistilRoBERTa-v1": {
        "model_size": "263MB",
        "model_url": "sentence-transformers/all-distilroberta-v1",
        "efficiency": "High",
        "chunk_size": 512
    },
    "MiniLM-L12-v2": {
        "model_size": "118MB",
        "model_url": "sentence-transformers/all-MiniLM-L12-v2",
        "efficiency": "High",
        "chunk_size": 512
    },
    "MiniLM-L6-v2": {
        "model_size": "82MB",
        "model_url": "sentence-transformers/all-MiniLM-L6-v2",
        "efficiency": "Very High",
        "chunk_size": 512
    }
}

class Description_Validator:
    def __init__(self, model_name=None):
        if model_name is None:
            model_name = "DistilRoBERTa-v1"

        self.model_info = models_info[model_name]
        model_url = self.model_info["model_url"]

        self.model = SentenceTransformer(model_url)
        self.tokenizer = AutoTokenizer.from_pretrained(model_url)
        self.chunk_size = self.model_info["chunk_size"]

    def tokenize_and_chunk(self, text):
        # Tokenize without truncation, then split the token ids into
        # windows that fit the model's 512-token context.
        tokens = self.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        token_chunks = [tokens[i:i + self.chunk_size] for i in range(0, len(tokens), self.chunk_size)]
        return token_chunks

    def get_average_embedding(self, text):
        # Embed each chunk separately, then mean-pool so long descriptions
        # collapse to a single fixed-size vector.
        token_chunks = self.tokenize_and_chunk(text)
        chunk_embeddings = []
        for chunk in token_chunks:
            chunk_embedding = self.model.encode(self.tokenizer.decode(chunk), show_progress_bar=False)
            chunk_embeddings.append(chunk_embedding)
        return np.mean(chunk_embeddings, axis=0)

    def similarity_score(self, desc1, desc2):
        embedding1 = self.get_average_embedding(desc1).reshape(1, -1)
        embedding2 = self.get_average_embedding(desc2).reshape(1, -1)
        similarity = cosine_similarity(embedding1, embedding2)
        return similarity[0][0]
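
Usage note (not part of the upload): a minimal sketch of the Description_Validator API. The two product descriptions are made-up examples, and the model choice is arbitrary.

# Hypothetical usage sketch -- inputs are illustrative only.
validator = Description_Validator(model_name="MiniLM-L6-v2")
score = validator.similarity_score(
    "Stainless steel water bottle, 750ml, vacuum insulated.",
    "Insulated 750ml steel bottle that keeps drinks cold for 24h.",
)
print(f"description similarity: {score:.3f}")  # cosine similarity, roughly in [-1, 1]
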
image_validator.py
ADDED
@@ -0,0 +1,64 @@
from transformers import CLIPProcessor, CLIPModel, ViTImageProcessor, ViTModel
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import torch

from warnings import filterwarnings
filterwarnings("ignore")

models = ["CLIP-ViT Base", "ViT Base", "DINO ViT-S16"]
models_info = {
    "CLIP-ViT Base": {
        "model_size": "386MB",
        "model_url": "openai/clip-vit-base-patch32",
        "efficiency": "High",
    },
    "ViT Base": {
        "model_size": "304MB",
        "model_url": "google/vit-base-patch16-224",
        "efficiency": "High",
    },
    "DINO ViT-S16": {
        "model_size": "1.34GB",
        "model_url": "facebook/dino-vits16",
        "efficiency": "Moderate",
    },
}

class Image_Validator:
    def __init__(self, model_name=None):
        if model_name is None:
            model_name = "ViT Base"

        self.model_info = models_info[model_name]
        model_url = self.model_info["model_url"]

        if model_name == "CLIP-ViT Base":
            self.model = CLIPModel.from_pretrained(model_url)
            self.processor = CLIPProcessor.from_pretrained(model_url)
        elif model_name in ("ViT Base", "DINO ViT-S16"):
            # Plain ViT and DINO checkpoints both load through the same classes.
            self.model = ViTModel.from_pretrained(model_url)
            self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)

    def get_image_embedding(self, image_path):
        # Convert to RGB so grayscale/RGBA images do not break the processors.
        image = Image.open(image_path).convert("RGB")

        # Process the image according to the model family; no_grad avoids
        # building an autograd graph during pure inference.
        with torch.no_grad():
            if hasattr(self, "processor"):  # CLIP models
                inputs = self.processor(images=image, return_tensors="pt")
                outputs = self.model.get_image_features(**inputs)
            elif hasattr(self, "feature_extractor"):  # ViT / DINO models
                inputs = self.feature_extractor(images=image, return_tensors="pt")
                outputs = self.model(**inputs).last_hidden_state

        return outputs

    def similarity_score(self, image_path_1, image_path_2):
        embedding1 = self.get_image_embedding(image_path_1).reshape(1, -1)
        embedding2 = self.get_image_embedding(image_path_2).reshape(1, -1)
        similarity = cosine_similarity(embedding1.detach().numpy(), embedding2.detach().numpy())
        return similarity[0][0]
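
Usage note (not part of the upload): the same pattern for images. The file paths are placeholders, not files from this commit.

# Hypothetical usage sketch -- the image paths are placeholders.
validator = Image_Validator(model_name="CLIP-ViT Base")
score = validator.similarity_score("listing_photo_old.jpg", "listing_photo_new.jpg")
print(f"image similarity: {score:.3f}")
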
product_update_validator.py
ADDED
@@ -0,0 +1,21 @@
from model_factory.description_validator import Description_Validator
from model_factory.image_validator import Image_Validator

class Update_Validator:
    def __init__(self, text_model=None, image_model=None, threshold=0.7):
        self.description_validator = Description_Validator(model_name=text_model)
        self.image_validator = Image_Validator(model_name=image_model)
        self.threshold = threshold

    def validate(self, text1, text2, image_path_1, image_path_2, threshold=None, return_score=False):
        description_similarity = self.description_validator.similarity_score(text1, text2)
        image_similarity = self.image_validator.similarity_score(image_path_1, image_path_2)
        # Weighted blend: text similarity carries 75% of the signal, image 25%.
        similarity_score = 0.75 * description_similarity + 0.25 * image_similarity

        if threshold is None:
            threshold = self.threshold
        # Cast to plain Python types so the result is JSON-serializable.
        label = bool(similarity_score >= threshold)

        if return_score:
            return {"score": float(similarity_score), "label": label}
        else:
            return {"label": label}
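
Usage note (not part of the upload): an end-to-end sketch, assuming the two files above sit in a model_factory/ package as the imports imply. All texts and image paths are placeholders.

# Hypothetical end-to-end sketch -- all inputs are placeholders.
validator = Update_Validator(threshold=0.7)
result = validator.validate(
    "Original product description.",
    "Updated product description.",
    "image_v1.jpg",
    "image_v2.jpg",
    return_score=True,
)
print(result)  # {'score': <float>, 'label': <bool>}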