Commit e27c656 · 1 Parent(s): 5362145

Updated app to handle both text and image embeddings

Files changed (1)
  1. app.py +56 -35
app.py CHANGED
@@ -3,55 +3,76 @@ from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
 from PIL import Image
 import requests
 import torch
-###hey

 # Load the FashionCLIP processor and model
 processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
 model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip")

-# Define the function to process image and text
-def process_image_and_text(product_title, image_url):
+# Define the function to process both text and image inputs
+def generate_embeddings(input_text=None, input_image_url=None):
     try:
-        # Fetch and process the image
-        response = requests.get(image_url, stream=True)
-        response.raise_for_status()
-        image = Image.open(response.raw)
-
-        # Prepare inputs for the model
-        inputs = processor(
-            text=[product_title],
-            images=image,
-            return_tensors="pt",
-            padding=True
-        )
-
-        # Perform inference
-        with torch.no_grad():
-            outputs = model(**inputs)
-
-        # Extract similarity score and embeddings
-        similarity_score = outputs.logits_per_image[0].item()
-        text_embedding = outputs.logits_per_text.cpu().numpy().tolist()
-        image_embedding = outputs.logits_per_image.cpu().numpy().tolist()
-
-        return {
-            "similarity_score": similarity_score,
-            "text_embedding": text_embedding,
-            "image_embedding": image_embedding
-        }
+        if input_image_url:
+            # Process image with accompanying text
+            response = requests.get(input_image_url, stream=True)
+            response.raise_for_status()
+            image = Image.open(response.raw)
+
+            # Use a default text if none is provided
+            if not input_text:
+                input_text = "this is an image"
+
+            # Prepare inputs for the model
+            inputs = processor(
+                text=[input_text],
+                images=image,
+                return_tensors="pt",
+                padding=True
+            )
+
+            with torch.no_grad():
+                outputs = model(**inputs)
+
+            image_embedding = outputs.logits_per_image.cpu().numpy().tolist()
+            return {
+                "type": "image_embedding",
+                "input_image_url": input_image_url,
+                "input_text": input_text,
+                "embedding": image_embedding
+            }
+
+        elif input_text:
+            # Process text input only
+            inputs = processor(
+                text=[input_text],
+                images=None,
+                return_tensors="pt",
+                padding=True
+            )
+            with torch.no_grad():
+                outputs = model(**inputs)
+
+            text_embedding = outputs.logits_per_text.cpu().numpy().tolist()
+            return {
+                "type": "text_embedding",
+                "input_text": input_text,
+                "embedding": text_embedding
+            }
+        else:
+            return {"error": "Please provide either a text query or an image URL."}
+
     except Exception as e:
         return {"error": str(e)}

 # Create the Gradio interface
 interface = gr.Interface(
-    fn=process_image_and_text,
+    fn=generate_embeddings,
     inputs=[
-        gr.Textbox(label="Product Title", placeholder="e.g., ring for men"),
-        gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg")
+        gr.Textbox(label="Text Query (Optional)", placeholder="e.g., red dress (used with image or for text embedding)"),
+        gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg (used with or without text query)")
     ],
     outputs="json",
-    title="FashionCLIP API",
-    description="Provide a product title and an image URL to compute similarity score and embeddings."
+    title="FashionCLIP Combined Embedding API",
+    description="Provide a text query and/or an image URL to compute embeddings for vector search."
 )

 # Launch the app
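
A note for anyone building on this commit: in CLIP-style models such as FashionCLIP, logits_per_image and logits_per_text are image-text similarity logits, not feature vectors, so the values the new endpoint returns as "embeddings" are match scores rather than vectors suitable for a vector index. The text-only branch will also typically raise an error, since a CLIP forward pass requires pixel_values. Actual embeddings come from the model's get_text_features and get_image_features methods. A minimal sketch, assuming the same patrickjohncyh/fashion-clip checkpoint loaded as a CLIPModel; the helper names here are illustrative, not part of the committed app:

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPModel

processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")

def text_embedding(text):
    # Tokenize the query and project it into the shared CLIP space.
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    # L2-normalize so cosine similarity reduces to a dot product.
    features = features / features.norm(dim=-1, keepdim=True)
    return features[0].cpu().numpy().tolist()

def image_embedding(image_url):
    # Fetch the image and project it into the same space.
    image = Image.open(requests.get(image_url, stream=True).raw)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    features = features / features.norm(dim=-1, keepdim=True)
    return features[0].cpu().numpy().tolist()

With both vectors L2-normalized, cosine similarity is a plain dot product, which is what most vector stores expect; the similarity score the old endpoint reported corresponds to this dot product scaled by the model's learned logit scale.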