geekyrakshit committed on
Commit
789b57f
·
1 Parent(s): 49cde8e

add: alias to document loader artifacts and datasets + enable mps fallback for marker

Browse files
medrag_multi_modal/document_loader/image_loader/base_img_loader.py CHANGED
@@ -100,7 +100,11 @@ class BaseImageLoader(BaseTextLoader):
100
  await task
101
 
102
  if wandb_artifact_name:
103
- artifact = wandb.Artifact(name=wandb_artifact_name, type="dataset")
 
 
 
 
104
  artifact.add_dir(local_path=image_save_dir)
105
  artifact.save()
106
  rich.print("Artifact saved and uploaded to wandb!")
 
100
  await task
101
 
102
  if wandb_artifact_name:
103
+ artifact = wandb.Artifact(
104
+ name=wandb_artifact_name,
105
+ type="dataset",
106
+ metadata={"loader_name": self.__class__.__name__},
107
+ )
108
  artifact.add_dir(local_path=image_save_dir)
109
  artifact.save()
110
  rich.print("Artifact saved and uploaded to wandb!")
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py CHANGED
@@ -6,6 +6,8 @@ from marker.models import load_all_models
6
 
7
  from .base_img_loader import BaseImageLoader
8
 
 
 
9
 
10
  class MarkerImageLoader(BaseImageLoader):
11
  """
 
6
 
7
  from .base_img_loader import BaseImageLoader
8
 
9
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
10
+
11
 
12
  class MarkerImageLoader(BaseImageLoader):
13
  """
medrag_multi_modal/document_loader/text_loader/base_text_loader.py CHANGED
@@ -131,6 +131,7 @@ class BaseTextLoader(ABC):
131
  async def process_page(page_idx):
132
  nonlocal processed_pages_counter
133
  page_data = await self.extract_page_data(page_idx, **kwargs)
 
134
  pages.append(page_data)
135
  rich.print(
136
  f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
 
131
  async def process_page(page_idx):
132
  nonlocal processed_pages_counter
133
  page_data = await self.extract_page_data(page_idx, **kwargs)
134
+ page_data["loader_name"] = self.__class__.__name__
135
  pages.append(page_data)
136
  rich.print(
137
  f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Dict
2
 
3
  from marker.convert import convert_single_pdf
@@ -5,6 +6,8 @@ from marker.models import load_all_models
5
 
6
  from .base_text_loader import BaseTextLoader
7
 
 
 
8
 
9
  class MarkerTextLoader(BaseTextLoader):
10
  """
 
1
+ import os
2
  from typing import Dict
3
 
4
  from marker.convert import convert_single_pdf
 
6
 
7
  from .base_text_loader import BaseTextLoader
8
 
9
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
10
+
11
 
12
  class MarkerTextLoader(BaseTextLoader):
13
  """