bgaspra committed on
Commit
107b2a4
·
verified ·
1 Parent(s): 8bbea5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -31,16 +31,19 @@ text_embedding_cache = {}
31
 
32
  def get_image_embedding(image):
33
  try:
 
34
  inputs = processor(
35
  images=image,
36
- text=[""], # Florence-2 requires both image and text inputs
 
37
  return_tensors="pt"
38
  ).to(device, torch_dtype)
39
 
40
  with torch.no_grad():
 
41
  outputs = model(**inputs)
42
- # Get the image embeddings from the last hidden states
43
- image_embeddings = outputs.last_hidden_state[:, 0, :] # Take CLS token
44
  return image_embeddings.cpu().numpy()
45
  except Exception as e:
46
  print(f"Error in get_image_embedding: {str(e)}")
@@ -51,15 +54,25 @@ def get_text_embedding(text):
51
  if text in text_embedding_cache:
52
  return text_embedding_cache[text]
53
 
 
54
  inputs = processor(
55
  text=text,
56
- images=None,
57
  return_tensors="pt"
58
  ).to(device, torch_dtype)
59
 
 
 
 
 
 
 
 
 
 
60
  with torch.no_grad():
61
  outputs = model(**inputs)
62
- text_embeddings = outputs.last_hidden_state[:, 0, :] # Take CLS token
63
 
64
  embedding = text_embeddings.cpu().numpy()
65
  text_embedding_cache[text] = embedding
 
31
 
32
  def get_image_embedding(image):
33
  try:
34
+ # Process image and add dummy text input
35
  inputs = processor(
36
  images=image,
37
+ text="Describe this image", # Adding a default text prompt
38
+ padding=True,
39
  return_tensors="pt"
40
  ).to(device, torch_dtype)
41
 
42
  with torch.no_grad():
43
+ # Get model outputs
44
  outputs = model(**inputs)
45
+ # Extract image features from the cross-attention layers
46
+ image_embeddings = outputs.last_hidden_state.mean(dim=1)
47
  return image_embeddings.cpu().numpy()
48
  except Exception as e:
49
  print(f"Error in get_image_embedding: {str(e)}")
 
54
  if text in text_embedding_cache:
55
  return text_embedding_cache[text]
56
 
57
+ # Process text with proper input formatting
58
  inputs = processor(
59
  text=text,
60
+ padding=True,
61
  return_tensors="pt"
62
  ).to(device, torch_dtype)
63
 
64
+ # Add required decoder input ids
65
+ inputs['decoder_input_ids'] = model.generate(
66
+ **inputs,
67
+ max_length=1,
68
+ return_dict_in_generate=True,
69
+ output_hidden_states=True,
70
+ early_stopping=True
71
+ ).sequences
72
+
73
  with torch.no_grad():
74
  outputs = model(**inputs)
75
+ text_embeddings = outputs.last_hidden_state.mean(dim=1)
76
 
77
  embedding = text_embeddings.cpu().numpy()
78
  text_embedding_cache[text] = embedding