"""Gradio demo for question answering over long PDF documents.

A PDF (uploaded or fetched from a URL) is rendered page by page, optional
layout detection extracts tables and figures, a retriever scores the pages
against the question, and a vision-language model answers using the text and
figure/table crops of the top-scoring pages.
"""

import base64
import hashlib
import io
import json
import os
import tempfile
import time
from collections import OrderedDict as CollectionsOrderedDict
from pathlib import Path
from threading import Thread
from typing import Iterator, Optional, List, Union, OrderedDict

import fitz
import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from colpali_engine import ColPali, ColPaliProcessor
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from swift.llm import (
    ModelType,
    get_model_tokenizer,
    get_default_template_type,
    get_template,
    inference,
    inference_stream,
)
from tqdm import tqdm
from transformers import (
    Qwen2VLForConditionalGeneration,
    PreTrainedTokenizer,
    Qwen2VLProcessor,
    TextIteratorStreamer,
    AutoTokenizer,
)
from ultralytics import YOLO
from ultralytics.engine.results import Results

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# M-Longdoc: A Benchmark For Multimodal Super-Long Document Understanding And A Retrieval-Aware Tuning Framework

This Space demonstrates the multimodal long document understanding model with 7B parameters fine-tuned for texts, tables, and figures. Feel free to play with it, or duplicate to run generations without a queue!

🔎 For more details about the project, check out the [paper](https://arxiv.org/pdf/2411.06176).
"""

LICENSE = """

---
As a derivate work of [Llama-3-8b-chat](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta, this demo is governed by the original [license](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) and [acceptable use policy](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/USE_POLICY.md).
"""


class MultimodalSample(BaseModel):
    """A single question with its answer, prediction and retrieval metadata."""

    question: str
    answer: str
    category: str
    evidence_pages: List[int] = []
    raw_output: str = ""
    pred: str = ""
    source: str = ""
    annotator: str = ""
    generator: str = ""
    retrieved_pages: List[int] = []


class MultimodalObject(BaseModel):
    """A piece of page content (text and/or a base64-encoded PNG image)."""

    id: str = ""
    page: int = 0
    text: str = ""
    image_string: str = ""
    snippet: str = ""
    score: float = 0.0
    source: str = ""
    category: str = ""

    def get_image(self) -> Optional[Image.Image]:
        """Decode ``image_string`` into an image, or return None if it is empty."""
        if self.image_string:
            return convert_text_to_image(self.image_string)
        return None

    @classmethod
    def from_image(cls, image: Image.Image, **kwargs):
        """Build an object whose ``image_string`` is the encoded ``image``."""
        return cls(image_string=convert_image_to_text(image), **kwargs)


class ObjectDetector(BaseModel, arbitrary_types_allowed=True):
    """Interface for detectors that extract objects from a page image."""

    def run(self, image: Image.Image) -> List[MultimodalObject]:
        raise NotImplementedError()


class YoloDetector(ObjectDetector):
    """Layout detector backed by a YOLO checkpoint fetched from the HF Hub."""

    repo_id: str = "DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet"
    filename: str = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
    local_dir: str = "data/yolo"
    client: Optional[YOLO] = None

    def load(self):
        """Download the checkpoint if it is missing and create the client once."""
        if self.client is None:
            weights = Path(self.local_dir, self.filename)
            if not weights.exists():
                hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.filename,
                    local_dir=self.local_dir,
                )
            self.client = YOLO(weights)

    def save_image(self, image: Image.Image) -> str:
        """Write ``image`` as PNG into ``local_dir``, named by its content hash."""
        text = convert_image_to_text(image)
        hash_id = hashlib.md5(text.encode()).hexdigest()
        path = Path(self.local_dir, f"{hash_id}.png")
        image.save(path)
        return str(path)

    @staticmethod
    def extract_subimage(image: Image.Image, box: List[float]) -> Image.Image:
        """Crop ``image`` to ``box`` (left, upper, right, lower), rounded to ints."""
        return image.crop((round(box[0]), round(box[1]), round(box[2]), round(box[3])))

    def run(self, image: Image.Image) -> List[MultimodalObject]:
        """Detect objects in ``image`` and return one cropped object per box."""
        self.load()
        path = self.save_image(image)
        results: List[Results] = self.client(source=[path])
        assert len(results) == 1
        result = results[0]

        objects = []
        for i, label_id in enumerate(result.boxes.cls):
            box: List[float] = result.boxes.xyxy[i].tolist()
            subimage = self.extract_subimage(image, box)
            objects.append(
                MultimodalObject(
                    image_string=convert_image_to_text(subimage),
                    category=result.names[label_id.item()],
                    score=result.boxes.conf[i].item(),
                )
            )
        return objects


class MultimodalPage(BaseModel):
    """One document page: extracted text, full-page image and detected objects."""

    number: int
    objects: List[MultimodalObject]
    text: str
    image_string: str
    source: str
    score: float = 0.0

    def get_tables_and_figures(self) -> List[MultimodalObject]:
        """Return the objects whose category is "Table" or "Picture"."""
        return [o for o in self.objects if o.category in ["Table", "Picture"]]

    def get_full_image(self) -> Image.Image:
        """Decode the full-page image."""
        return convert_text_to_image(self.image_string)

    @classmethod
    def from_text(cls, text: str):
        """Build a text-only page (number 0, no objects, no image)."""
        return cls(text=text, number=0, objects=[], image_string="", source="")

    @classmethod
    def from_image(cls, image: Image.Image):
        """Build an image-only page (number 0, no objects, no text)."""
        return cls(
            image_string=convert_image_to_text(image),
            number=0,
            objects=[],
            text="",
            source="",
        )


class MultimodalDocument(BaseModel):
    """An ordered collection of pages."""

    pages: List[MultimodalPage]

    def get_page(self, i: int) -> MultimodalPage:
        """Return the unique page whose ``number`` equals ``i``."""
        pages = [p for p in self.pages if p.number == i]
        assert len(pages) == 1
        return pages[0]

    @classmethod
    def load_from_pdf(
        cls, path: str, dpi: int = 150, detector: Optional[ObjectDetector] = None
    ):
        """Render every page of the PDF at ``path`` into a MultimodalPage.

        Each page keeps its extracted text and a rendered image; when a
        ``detector`` is given, its detections are attached with the page
        number (1-based) and source path filled in.
        """
        zoom = dpi / 72  # 72 is the default DPI
        matrix = fitz.Matrix(zoom, zoom)
        pages = []
        # The context manager closes the PDF handle once rendering is done.
        with fitz.open(path) as doc:
            for i, page in enumerate(tqdm(doc.pages(), desc=path)):
                text = page.get_text()
                pix = page.get_pixmap(matrix=matrix)
                image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

                objects = []
                if detector:
                    objects = detector.run(image)
                    for o in objects:
                        o.page, o.source = i + 1, path

                pages.append(
                    MultimodalPage(
                        number=i + 1,
                        objects=objects,
                        text=text,
                        image_string=convert_image_to_text(image),
                        source=path,
                    )
                )
        return cls(pages=pages)

    @classmethod
    def load(cls, path: str):
        """Load a document from a JSON-lines file (one page per line)."""
        with open(path, encoding="utf-8") as f:
            pages = [MultimodalPage(**json.loads(line)) for line in f]
        return cls(pages=pages)

    def save(self, path: str):
        """Write the pages to ``path`` as JSON lines, creating parent folders."""
        Path(path).parent.mkdir(exist_ok=True, parents=True)
        with open(path, "w", encoding="utf-8") as f:
            for o in self.pages:
                print(o.model_dump_json(), file=f)

    def get_domain(self) -> str:
        """Classify the document by the file name of its first page's source."""
        filename = Path(self.pages[0].source).name
        if filename.startswith("NYSE"):
            return "Financial\nReport"
        # Slices instead of indexing so short names cannot raise IndexError.
        if filename[:4].isdigit() and filename[4:5] == "." and filename[5:6].isdigit():
            return "Academic\nPaper"
        return "Technical\nManuals"


class MultimodalRetriever(BaseModel, arbitrary_types_allowed=True):
    """Interface for retrievers that assign a relevance score to each page."""

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        raise NotImplementedError

    @staticmethod
    def get_top_pages(doc: MultimodalDocument, k: int) -> List[int]:
        """Return the numbers of the top-``k`` scoring pages in document order.

        Pages tied with the k-th best score are all included, so the result
        may hold more than ``k`` entries. An empty document or ``k <= 0``
        yields an empty list.
        """
        if k <= 0 or not doc.pages:
            return []
        best_scores = sorted((p.score for p in doc.pages), reverse=True)[:k]
        threshold = best_scores[-1]
        return [p.number for p in doc.pages if p.score >= threshold]


class ColpaliRetriever(MultimodalRetriever):
    """Page retriever that scores page images against a query with ColPali."""

    path: str = "vidore/colpali-v1.2"
    model: Optional[ColPali] = None
    processor: Optional[ColPaliProcessor] = None
    device: str = "cuda"
    cache: OrderedDict[str, torch.Tensor] = CollectionsOrderedDict()
    # Maximum number of encoded documents kept in ``cache``.
    cache_size: int = 100

    def load(self):
        """Load the model and processor on first use."""
        if self.model is None:
            self.model = ColPali.from_pretrained(
                self.path, torch_dtype=torch.bfloat16, device_map=self.device
            )
            self.model = self.model.eval()
            self.processor = ColPaliProcessor.from_pretrained(self.path)

    def encode_document(self, doc: MultimodalDocument) -> torch.Tensor:
        """Return the page embeddings of ``doc``, cached by a hash of its JSON."""
        hash_id = hashlib.md5(doc.model_dump_json().encode()).hexdigest()
        if hash_id in self.cache:
            return self.cache[hash_id]

        images = [page.get_full_image() for page in doc.pages]
        batch_size = 8
        ds: List[torch.Tensor] = []
        for i in tqdm(range(0, len(images), batch_size), desc="Encoding document"):
            batch = self.processor.process_images(images[i : i + batch_size])
            with torch.no_grad():
                # noinspection PyTypeChecker
                ds.append(self.model(**batch.to(self.device)).cpu())

        lengths = [x.shape[1] for x in ds]
        if len(set(lengths)) != 1:
            print("Warning: Inconsistent lengths from colqwen", set(lengths))
            assert "colqwen" in self.path
            # Truncate every batch to the shortest length so they can be concatenated.
            for i, x in enumerate(ds):
                ds[i] = x[:, : min(lengths), :]

        self.cache[hash_id] = torch.cat(ds)
        # Evict oldest entries only after inserting, so a lookup never drops
        # the entry it is about to return.
        while len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)
        return self.cache[hash_id]

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        """Return a deep copy of ``doc`` with each page's ``score`` set for ``query``."""
        doc = doc.model_copy(deep=True)
        self.load()
        ds = self.encode_document(doc)
        with torch.no_grad():
            # noinspection PyTypeChecker
            qs = self.model(**self.processor.process_queries([query]).to(self.device))

        # Drop only the query axis: a bare squeeze() would also collapse the
        # page axis for single-page documents and break len() below.
        # noinspection PyTypeChecker
        scores = self.processor.score_multi_vector(qs.cpu(), ds).squeeze(0)
        assert len(scores) == len(doc.pages)
        for i, page in enumerate(doc.pages):
            page.score = scores[i].item()
        return doc


class DummyRetriever(MultimodalRetriever):
    """Retriever stand-in that scores each page by its position in the document."""

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        doc = doc.model_copy(deep=True)
        for i, page in enumerate(doc.pages):
            page.score = i
        return doc


def convert_image_to_text(image: Image.Image) -> str:
    """Encode ``image`` as a base64 string of its PNG bytes."""
    # This is also how OpenAI encodes images: https://platform.openai.com/docs/guides/vision
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        data = output.getvalue()
    return base64.b64encode(data).decode("utf-8")


def convert_text_to_image(text: str) -> Image.Image:
    """Decode a base64 string produced by ``convert_image_to_text``."""
    data = base64.b64decode(text.encode("utf-8"))
    return Image.open(io.BytesIO(data))


def save_image(image: Image.Image, folder: str) -> str:
    """Save ``image`` as PNG in ``folder`` under its content hash; return the path."""
    image_hash = hashlib.md5(image.tobytes()).hexdigest()
    path = Path(folder, f"{image_hash}.png")
    path.parent.mkdir(exist_ok=True, parents=True)
    if not path.exists():
        image.save(path)
    return str(path)


def resize_image(image: Image.Image, max_size: int) -> Image.Image:
    """Shrink ``image`` so its longer side is ``max_size``, keeping aspect ratio.

    Images already within ``max_size`` on both sides are returned unchanged.
    """
    # Same as modeling.py resize_image
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    if width > height:
        new_width = max_size
        new_height = round(height * max_size / width)
    else:
        new_height = max_size
        new_width = round(width * max_size / height)
    return image.resize((new_width, new_height), Image.LANCZOS)


class EvalModel(BaseModel, arbitrary_types_allowed=True):
    """Interface for models that answer a mixed list of text and image inputs."""

    engine: str
    timeout: int = 60
    temperature: float = 0.0
    max_output_tokens: int = 512

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        raise NotImplementedError

    def run_many(self, inputs: List[Union[str, Image.Image]], num: int) -> List[str]:
        raise NotImplementedError


class SwiftQwenModel(EvalModel):
    """Qwen2-VL served through the ms-swift inference helpers."""

    # https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Multi-Modal/qwen2-vl-best-practice.md
    path: str = ""
    model: Optional[Qwen2VLForConditionalGeneration] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    engine: str = ModelType.qwen2_vl_7b_instruct
    image_size: int = 768
    image_dir: str = "data/qwen_images"

    def load(self):
        """Load model and tokenizer on first use."""
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = get_model_tokenizer(
                self.engine,
                torch.bfloat16,
                model_kwargs={"device_map": "auto"},
                model_id_or_path=self.path or None,
            )

    def _prepare(self, inputs: List[Union[str, Image.Image]]):
        """Load the model and return the (template, query) pair for ``inputs``.

        Images are resized, saved under ``image_dir`` and referenced by path
        ahead of the text parts, which are joined by blank lines.
        """
        self.load()
        template_type = get_default_template_type(self.engine)
        self.model.generation_config.max_new_tokens = self.max_output_tokens
        template = get_template(template_type, self.tokenizer)

        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = []
        for x in inputs:
            if isinstance(x, Image.Image):
                path = save_image(resize_image(x, self.image_size), self.image_dir)
                # NOTE(review): the bare path is concatenated straight into the
                # query; confirm whether image wrapper tags are expected here.
                content.append(f"{path}")
        content.append(text)
        return template, "".join(content)

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Return the full response for ``inputs``."""
        template, query = self._prepare(inputs)
        response, history = inference(self.model, template, query)
        return response

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        """Yield the response incrementally; each item is only the new text."""
        template, query = self._prepare(inputs)
        generator = inference_stream(self.model, template, query)
        print_idx = 0
        print(f"query: {query}\nresponse: ", end="")
        for response, history in generator:
            delta = response[print_idx:]
            print(delta, end="", flush=True)
            print_idx = len(response)
            yield delta


class QwenModel(EvalModel):
    """Qwen2-VL served through Hugging Face transformers."""

    path: str = "models/qwen"
    engine: str = "Qwen/Qwen2-VL-7B-Instruct"
    model: Optional[Qwen2VLForConditionalGeneration] = None
    processor: Optional[Qwen2VLProcessor] = None
    tokenizer: Optional[AutoTokenizer] = None
    device: str = "cuda"
    image_size: int = 768
    lora_path: str = ""

    def load(self):
        """Load model, tokenizer and processor on first use and seed torch."""
        if self.model is None:
            # Prefer the local checkpoint when present, else the hub id.
            path = self.path if os.path.exists(self.path) else self.engine
            print(dict(load_path=path))
            # noinspection PyTypeChecker
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                path, torch_dtype="auto", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.engine)

            if self.lora_path:
                print("Loading LORA from", self.lora_path)
                self.model.load_adapter(self.lora_path)

            self.model = self.model.to(self.device).eval()
            self.processor = Qwen2VLProcessor.from_pretrained(self.engine)
            torch.manual_seed(0)
            torch.cuda.manual_seed_all(0)

    def make_messages(self, inputs: List[Union[str, Image.Image]]) -> List[dict]:
        """Build a single user message: resized images first, then the joined text."""
        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = [
            dict(
                type="image",
                image=f"data:image;base64,{convert_image_to_text(resize_image(x, self.image_size))}",
            )
            for x in inputs
            if isinstance(x, Image.Image)
        ]
        content.append(dict(type="text", text=text))
        return [dict(role="user", content=content)]

    def _prepare_inputs(self, inputs: List[Union[str, Image.Image]]):
        """Load the model and turn ``inputs`` into processor tensors on ``device``."""
        self.load()
        messages = self.make_messages(inputs)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # noinspection PyTypeChecker
        return self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Generate and return the decoded response for ``inputs``."""
        model_inputs = self._prepare_inputs(inputs)
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **model_inputs, max_new_tokens=self.max_output_tokens
            )

        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        """Yield the response as it grows; each item is the full text so far."""
        model_inputs = self._prepare_inputs(inputs)
        streamer = TextIteratorStreamer(
            self.tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generate_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=self.max_output_tokens,
        )
        # Generation runs in a worker thread while this generator drains the streamer.
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for chunk in streamer:
            outputs.append(chunk)
            yield "".join(outputs)
        t.join()


class DummyModel(EvalModel):
    """Model stand-in that echoes its inputs; used when no GPU is available."""

    engine: str = ""

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Return the text inputs joined by spaces (images are ignored)."""
        return " ".join(x for x in inputs if isinstance(x, str))

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        """Yield a canned message word by word, each item being the text so far."""
        text = " ".join([x for x in inputs if isinstance(x, str)])
        num_images = sum(1 for x in inputs if isinstance(x, Image.Image))
        tokens = f"Hello this is your message: {text}, images: {num_images}".split()
        for i in range(len(tokens)):
            yield " ".join(tokens[: i + 1])
            time.sleep(0.05)


if not torch.cuda.is_available():
    DESCRIPTION += "\n\n\nRunning on CPU 🥶 This demo does not work on CPU.\n\n"

# Real models need a GPU; otherwise fall back to the dummy implementations.
if torch.cuda.is_available():
    model = QwenModel()
    model.load()
    detect_model = YoloDetector()
    detect_model.load()
    retriever = ColpaliRetriever()
    retriever.load()
else:
    model = DummyModel()
    detect_model = None
    retriever = DummyRetriever()


def get_file_path(file: gr.File = None, url: str = None) -> Optional[str]:
    """Resolve the PDF to use: the uploaded file if given, else a downloaded URL.

    Returns the local path, or None when neither input is provided or the URL
    response is not declared as ``application/pdf``. HTTP errors raise via
    ``raise_for_status``.
    """
    if file is not None:
        # Accept both objects exposing ``.name`` and plain path strings.
        # noinspection PyUnresolvedReferences
        return str(getattr(file, "name", file))

    # An empty textbox arrives as "" rather than None, so test truthiness.
    url = (url or "").strip()
    if not url:
        return None

    response = requests.get(url, timeout=60)
    response.raise_for_status()
    if "application/pdf" not in response.headers.get("Content-Type", ""):
        return None

    # Fall back to a fixed name when the URL ends with "/" or only has a query.
    filename = url.split("/")[-1].split("?")[0] or "document.pdf"
    save_path = Path(tempfile.mkdtemp(), filename)
    with open(save_path, "wb") as f:
        f.write(response.content)
    return str(save_path)


@spaces.GPU
def generate(
    query: str, file: gr.File = None, url: str = None, top_k=5
) -> Iterator[str]:
    """Stream an answer to ``query``, grounded in the given PDF when available.

    With a PDF, the ``top_k`` best-scoring pages contribute their text and
    their table/figure crops as context; without one, only the question is
    sent to the model.
    """
    sample = MultimodalSample(question=query, answer="", category="")
    question_prompt = (
        f"Answer the following question in 200 words or less: {sample.question}"
    )
    path = get_file_path(file, url)

    if path is not None:
        doc = MultimodalDocument.load_from_pdf(path, detector=detect_model)
        output = retriever.run(sample.question, doc)
        sorted_pages = sorted(output.pages, key=lambda p: p.score, reverse=True)
        sample.retrieved_pages = sorted([p.number for p in sorted_pages][:top_k])

        retrieved = set(sample.retrieved_pages)
        context = []
        for p in doc.pages:
            if p.number in retrieved:
                if p.text:
                    context.append(p.text)
                context.extend(o.get_image() for o in p.get_tables_and_figures())

        inputs = ["Context:", *context, question_prompt]
    else:
        inputs = ["Context:", question_prompt]

    for text in model.run_stream(inputs):
        yield text


with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use", elem_id="duplicate-button"
    )

    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
        with gr.Column():
            url_input = gr.Textbox(label="Enter PDF URL (optional)")
            text_input = gr.Textbox(label="Enter your message", lines=3)
            submit_button = gr.Button("Submit")

    result = gr.Textbox(label="Response", lines=10)

    submit_button.click(
        generate, inputs=[text_input, pdf_upload, url_input], outputs=result
    )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()