isayahc commited on
Commit
152e491
1 Parent(s): 95b3333

now users can embed all of the iamges from a pdf

Browse files
innovation_pathfinder_ai/vector_store/chroma_vector_store.py CHANGED
@@ -28,7 +28,7 @@ from innovation_pathfinder_ai.utils.utils import (
28
  )
29
 
30
  from innovation_pathfinder_ai.utils.image_processing.image_processing import (
31
- caption_image
32
  )
33
 
34
  from typing import List, Optional, NoReturn
@@ -280,6 +280,58 @@ def add_image_to_vector_store(
280
  )
281
 
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  if __name__ == "__main__":
285
 
 
28
  )
29
 
30
  from innovation_pathfinder_ai.utils.image_processing.image_processing import (
31
+ caption_image, extract_images_from_pdf
32
  )
33
 
34
  from typing import List, Optional, NoReturn
 
280
  )
281
 
282
 
283
+ def add_images_to_vector_store(
284
+ collection_name:str,
285
+ pdf_location:str,
286
+ pdf_images_location:str,
287
+ vector_store_client:chromadb.PersistentClient)-> NoReturn:
288
+
289
+ meta_data = extract_images_from_pdf(
290
+ pdf_path=pdf_location,
291
+ output_folder=pdf_images_location,
292
+ )
293
+
294
+ embedding_function = SentenceTransformerEmbeddings(
295
+ model_name=os.getenv("EMBEDDING_MODEL"),
296
+ )
297
+
298
+ client = chromadb.PersistentClient(
299
+ path=persist_directory,
300
+ )
301
+
302
+ collection = client.get_or_create_collection(
303
+ name=collection_name,
304
+ )
305
+
306
+ captioned_images = []
307
+ for i in meta_data:
308
+ temp = caption_image(i["image_location"])
309
+ captioned_images.append(temp)
310
+
311
+
312
+ dict_with_added_captions = []
313
+ for d, caption in zip(meta_data, captioned_images):
314
+ d["image_caption"] = caption[0]['generated_text']
315
+ dict_with_added_captions.append(d)
316
+
317
+
318
+ doc_list = []
319
+ for i in dict_with_added_captions:
320
+ temp = Document(
321
+ page_content=i['image_caption'],
322
+ metadata=i
323
+ )
324
+ doc_list.append(temp)
325
+
326
+ collection.add(
327
+ ids=[generate_uuid() for i in doc_list], # give each document a uuid
328
+ documents=[i.page_content for i in doc_list], # contents of document
329
+ embeddings=[embedding_function.embed_query(i.page_content) for i in doc_list],
330
+ metadatas=[i.metadata for i in doc_list], # type: ignore
331
+ )
332
+
333
+
334
+
335
 
336
  if __name__ == "__main__":
337