# Hugging Face Space status: Running on Zero
import base64
import hashlib
import io
import json
import os
import tempfile
from collections import OrderedDict as CollectionsOrderedDict
from pathlib import Path
from threading import Thread
from typing import Iterator, Optional, List, Union, OrderedDict
from urllib.parse import urlparse

# NOTE(review): keep `spaces` imported before torch; ZeroGPU reports an error if
# CUDA is initialised before the `spaces` package is imported.
import fitz
import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from colpali_engine import ColPali, ColPaliProcessor
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from swift.llm import (
    ModelType,
    get_model_tokenizer,
    get_default_template_type,
    get_template,
    inference,
    inference_stream,
)
from tqdm import tqdm
from transformers import (
    Qwen2VLForConditionalGeneration,
    PreTrainedTokenizer,
    Qwen2VLProcessor,
    TextIteratorStreamer,
    AutoTokenizer,
)
from ultralytics import YOLO
from ultralytics.engine.results import Results
# Generation limits. MAX_INPUT_TOKEN_LENGTH can be overridden from the environment.
# NOTE(review): none of these three constants is referenced in the rest of this file.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown rendered at the top of the demo page. Blank lines separate the Markdown
# paragraphs so they do not collapse into one.
DESCRIPTION = """\
# M-LongDoc: A Benchmark For Multimodal Super-Long Document Understanding And A Retrieval-Aware Tuning Framework

This Space demonstrates the multimodal long document understanding model with 7B parameters fine-tuned for texts, tables, and figures. Feel free to play with it, or duplicate to run generations without a queue!

🔎 For more details about the project, check out the [paper](https://arxiv.org/pdf/2411.06176).
"""

# Footer rendered at the bottom of the page. The blank line after <p/> ends the
# HTML block so that "---" is rendered as a horizontal rule.
# NOTE(review): this text credits Llama-3-8B, but the GPU path loads
# "Qwen/Qwen2-VL-7B-Instruct" (QwenModel.engine). Confirm which license applies.
LICENSE = """
<p/>

---
As a derivative work of [Llama-3-8b-chat](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta,
this demo is governed by the original [license](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) and [acceptable use policy](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/USE_POLICY.md).
"""
class MultimodalSample(BaseModel):
    """A single question with its reference answer, model output and retrieval trace.

    Field declaration order is preserved because it determines the serialized
    field order.
    """

    # The question, its reference answer and a category label.
    question: str
    answer: str
    category: str
    # Page numbers holding the supporting evidence (pydantic copies list defaults per instance).
    evidence_pages: List[int] = []
    # Model output fields; presumably the unparsed generation and the parsed prediction.
    raw_output: str = ""
    pred: str = ""
    # Provenance metadata.
    source: str = ""
    annotator: str = ""
    generator: str = ""
    # Page numbers selected by the retriever (filled in by `generate`).
    retrieved_pages: List[int] = []
class MultimodalObject(BaseModel):
    """A region of a page (e.g. a detected table or picture) stored as base64 PNG text.

    ``category`` and ``score`` are set by an object detector. ``page`` and
    ``source`` are filled in by ``MultimodalDocument.load_from_pdf``.
    """

    id: str = ""
    page: int = 0
    text: str = ""
    image_string: str = ""  # base64-encoded PNG produced by convert_image_to_text
    snippet: str = ""
    score: float = 0.0
    source: str = ""
    category: str = ""

    def get_image(self) -> Optional[Image.Image]:
        """Decode ``image_string`` into a PIL image, or return None when it is empty."""
        if not self.image_string:
            return None
        return convert_text_to_image(self.image_string)

    @classmethod
    def from_image(cls, image: Image.Image, **kwargs):
        """Build an object whose ``image_string`` encodes *image*.

        Args:
            image: Region image to store.
            **kwargs: Any other field values (e.g. ``category``, ``score``).
        """
        return cls(image_string=convert_image_to_text(image), **kwargs)
class ObjectDetector(BaseModel, arbitrary_types_allowed=True):
    """Interface for layout detectors that split a page image into regions."""

    def run(self, image: Image.Image) -> List[MultimodalObject]:
        """Return the regions detected in *image*. Subclasses must override this."""
        raise NotImplementedError
class YoloDetector(ObjectDetector):
    """Layout detector backed by a YOLO checkpoint downloaded from the Hugging Face Hub.

    The checkpoint (``repo_id``/``filename``) is cached under ``local_dir``. Each page
    image passed to :meth:`run` is also written there, because the YOLO client is
    invoked with a file path.
    """

    repo_id: str = "DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet"
    filename: str = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
    local_dir: str = "data/yolo"
    client: Optional[YOLO] = None

    def load(self):
        """Download the checkpoint if it is missing and create the YOLO client (once)."""
        if self.client is not None:
            return
        weights = Path(self.local_dir, self.filename)
        if not weights.exists():
            hf_hub_download(
                repo_id=self.repo_id,
                filename=self.filename,
                local_dir=self.local_dir,
            )
        self.client = YOLO(weights)

    def save_image(self, image: Image.Image) -> str:
        """Write *image* as PNG under ``local_dir``, named by a content hash, and return the path."""
        text = convert_image_to_text(image)
        hash_id = hashlib.md5(text.encode()).hexdigest()
        path = Path(self.local_dir, f"{hash_id}.png")
        # Don't rely on load()/hf_hub_download having created the directory.
        path.parent.mkdir(parents=True, exist_ok=True)
        image.save(path)
        return str(path)

    @staticmethod
    def extract_subimage(image: Image.Image, box: List[float]) -> Image.Image:
        """Crop *image* to an (x1, y1, x2, y2) box, rounding coordinates to pixels."""
        return image.crop((round(box[0]), round(box[1]), round(box[2]), round(box[3])))

    def run(self, image: Image.Image) -> List[MultimodalObject]:
        """Detect layout regions in *image* and return each one as a cropped object.

        Each object carries the detector's class name as ``category`` and its
        confidence as ``score``.

        Raises:
            RuntimeError: If the YOLO client does not return exactly one result.
        """
        self.load()
        path = self.save_image(image)
        results: List[Results] = self.client(source=[path])
        if len(results) != 1:
            raise RuntimeError(f"Expected one YOLO result for one image, got {len(results)}")
        result = results[0]
        boxes = result.boxes
        objects = []
        for label_id, confidence, xyxy in zip(boxes.cls, boxes.conf, boxes.xyxy):
            box: List[float] = xyxy.tolist()
            subimage = self.extract_subimage(image, box)
            objects.append(
                MultimodalObject(
                    image_string=convert_image_to_text(subimage),
                    # `names` maps integer class ids to labels; `cls` holds them as floats.
                    category=result.names[int(label_id.item())],
                    score=confidence.item(),
                )
            )
        return objects
class MultimodalPage(BaseModel):
    """One document page: rendered image, extracted text and detected regions."""

    number: int  # page number; load_from_pdf numbers pages starting at 1
    objects: List[MultimodalObject]
    text: str
    image_string: str  # base64-encoded PNG of the whole page
    source: str  # path of the file the page came from
    score: float = 0.0  # relevance score written by a retriever

    def get_tables_and_figures(self) -> List[MultimodalObject]:
        """Return the detected regions whose category is "Table" or "Picture"."""
        return [o for o in self.objects if o.category in ("Table", "Picture")]

    def get_full_image(self) -> Image.Image:
        """Decode the full-page image (fails if ``image_string`` is empty)."""
        return convert_text_to_image(self.image_string)

    @classmethod
    def from_text(cls, text: str):
        """Build a text-only page with number 0 and no image or objects."""
        return cls(text=text, number=0, objects=[], image_string="", source="")

    @classmethod
    def from_image(cls, image: Image.Image):
        """Build an image-only page with number 0 and no text or objects."""
        return cls(
            image_string=convert_image_to_text(image),
            number=0,
            objects=[],
            text="",
            source="",
        )
class MultimodalDocument(BaseModel):
    """A document as a list of pages. It can be loaded from a PDF or a JSONL dump."""

    pages: List[MultimodalPage]

    def get_page(self, i: int) -> MultimodalPage:
        """Return the unique page whose ``number`` equals *i*.

        Raises:
            ValueError: If zero or several pages have that number.
        """
        pages = [p for p in self.pages if p.number == i]
        if len(pages) != 1:
            raise ValueError(f"Expected exactly one page numbered {i}, found {len(pages)}")
        return pages[0]

    @classmethod
    def load_from_pdf(
        cls, path: str, dpi: int = 150, detector: Optional[ObjectDetector] = None
    ):
        """Render each PDF page to an image, extract its text and optionally detect regions.

        Args:
            path: Path of the PDF file.
            dpi: Rendering resolution; PDF user space is 72 units per inch.
            detector: Optional layout detector. Its regions are attached to each page
                and tagged with the page number and *path*.

        Returns:
            A document whose pages are numbered from 1 in file order.
        """
        zoom = dpi / 72  # 72 is the default DPI
        matrix = fitz.Matrix(zoom, zoom)
        pages = []
        # Context manager so the PDF handle is released even if rendering fails.
        with fitz.open(path) as doc:
            for i, page in enumerate(tqdm(doc.pages(), desc=path, total=doc.page_count)):
                number = i + 1
                text = page.get_text()
                pix = page.get_pixmap(matrix=matrix)
                image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
                objects = detector.run(image) if detector else []
                for o in objects:
                    o.page, o.source = number, path
                pages.append(
                    MultimodalPage(
                        number=number,
                        objects=objects,
                        text=text,
                        image_string=convert_image_to_text(image),
                        source=path,
                    )
                )
        return cls(pages=pages)

    @classmethod
    def load(cls, path: str):
        """Read a document written by :meth:`save` (one JSON page per line)."""
        with open(path, encoding="utf-8") as f:
            pages = [MultimodalPage(**json.loads(line)) for line in f if line.strip()]
        return cls(pages=pages)

    def save(self, path: str):
        """Write the pages to *path* as JSON Lines, creating parent directories."""
        Path(path).parent.mkdir(exist_ok=True, parents=True)
        # Explicit UTF-8: page text may contain non-ASCII characters.
        with open(path, "w", encoding="utf-8") as f:
            for page in self.pages:
                f.write(page.model_dump_json() + "\n")

    def get_domain(self) -> str:
        """Guess the document's domain from the first page's source filename.

        Names starting with "NYSE" are financial reports. Names shaped like
        ``DDDD.D...`` (presumably arXiv ids) are academic papers. Anything else is
        a technical manual. The labels contain ``<br>`` for display.
        """
        filename = Path(self.pages[0].source).name
        if filename.startswith("NYSE"):
            return "Financial<br>Report"
        if (
            len(filename) > 5
            and filename[:4].isdigit()
            and filename[4] == "."
            and filename[5].isdigit()
        ):
            return "Academic<br>Paper"
        return "Technical<br>Manuals"
class MultimodalRetriever(BaseModel, arbitrary_types_allowed=True):
    """Interface for page retrievers that score each page of a document against a query."""

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        """Return a copy of *doc* with each page's ``score`` set. Subclasses override this."""
        raise NotImplementedError

    @staticmethod
    def get_top_pages(doc: MultimodalDocument, k: int) -> List[int]:
        """Return the numbers of the top-*k* scoring pages, kept in document order.

        Pages tied with the k-th best score are all included, so more than *k*
        numbers can be returned. Returns an empty list when *k* <= 0 or *doc*
        has no pages.
        """
        if k <= 0 or not doc.pages:
            return []
        ranked = sorted(doc.pages, key=lambda p: p.score, reverse=True)
        threshold = ranked[:k][-1].score
        return [p.number for p in doc.pages if p.score >= threshold]
class ColpaliRetriever(MultimodalRetriever):
    """Retriever that scores page images against the query with ColPali multi-vector embeddings.

    Page embeddings are cached per document, keyed by an MD5 of the document's JSON,
    in a least-recently-used cache holding at most ``cache_size`` documents.
    """

    path: str = "vidore/colpali-v1.2"
    model: Optional[ColPali] = None
    processor: Optional[ColPaliProcessor] = None
    device: str = "cuda"
    cache: OrderedDict[str, torch.Tensor] = CollectionsOrderedDict()
    cache_size: int = 100

    def load(self):
        """Load the model (in eval mode) and its processor on first use."""
        if self.model is None:
            self.model = ColPali.from_pretrained(
                self.path, torch_dtype=torch.bfloat16, device_map=self.device
            )
            self.model = self.model.eval()
            self.processor = ColPaliProcessor.from_pretrained(self.path)

    def encode_document(self, doc: MultimodalDocument) -> torch.Tensor:
        """Embed every page image of *doc*, using the LRU cache when possible.

        Returns the per-batch embeddings concatenated along dim 0 (one row per page).
        """
        hash_id = hashlib.md5(doc.model_dump_json().encode()).hexdigest()
        cached = self.cache.get(hash_id)
        if cached is not None:
            self.cache.move_to_end(hash_id)  # mark as most recently used
            return cached

        images = [page.get_full_image() for page in doc.pages]
        batch_size = 8
        ds: List[torch.Tensor] = []
        for i in tqdm(range(0, len(images), batch_size), desc="Encoding document"):
            batch = self.processor.process_images(images[i : i + batch_size])
            with torch.no_grad():
                # noinspection PyTypeChecker
                ds.append(self.model(**batch.to(self.device)).cpu())

        lengths = [x.shape[1] for x in ds]
        if len(set(lengths)) != 1:
            # Only variable-resolution ColQwen models are expected to give ragged
            # token counts; truncate every batch to the shortest so they concatenate.
            print("Warning: Inconsistent lengths from colqwen", set(lengths))
            assert "colqwen" in self.path
            shortest = min(lengths)
            ds = [x[:, :shortest, :] for x in ds]

        embeddings = torch.cat(ds)
        self.cache[hash_id] = embeddings
        while len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)  # evict the least recently used document
        return embeddings

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        """Return a deep copy of *doc* with each page's ``score`` set to its ColPali score.

        Raises:
            RuntimeError: If the number of scores does not match the number of pages.
        """
        doc = doc.model_copy(deep=True)
        self.load()
        ds = self.encode_document(doc)
        with torch.no_grad():
            # noinspection PyTypeChecker
            qs = self.model(**self.processor.process_queries([query]).to(self.device))
        # One query, so the score matrix has a single row. reshape(-1) (unlike
        # squeeze) keeps a 1-D tensor even for a single-page document.
        # noinspection PyTypeChecker
        scores = self.processor.score_multi_vector(qs.cpu(), ds).reshape(-1)
        if len(scores) != len(doc.pages):
            raise RuntimeError(
                f"Got {len(scores)} scores for {len(doc.pages)} pages"
            )
        for page, score in zip(doc.pages, scores):
            page.score = score.item()
        return doc
class DummyRetriever(MultimodalRetriever):
    """Stand-in retriever used when no GPU is available. It scores pages by position."""

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        """Return a deep copy of *doc* where the page at 0-based index ``i`` gets score ``i``.

        Later pages therefore rank higher. *query* is ignored.
        """
        scored = doc.copy(deep=True)
        for position in range(len(scored.pages)):
            scored.pages[position].score = position
        return scored
def convert_image_to_text(image: Image.Image) -> str:
    """Encode *image* as PNG and return the bytes as a base64 string.

    This is also how OpenAI encodes images: https://platform.openai.com/docs/guides/vision
    """
    # The annotation is Image.Image: `Image` alone is the PIL module, not a type.
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        data = buffer.getvalue()
    return base64.b64encode(data).decode("utf-8")
def convert_text_to_image(text: str) -> Image.Image:
    """Decode a base64-encoded image string (e.g. from convert_image_to_text) into a PIL image."""
    # The annotation is Image.Image: `Image` alone is the PIL module, not a type.
    data = base64.b64decode(text.encode("utf-8"))
    return Image.open(io.BytesIO(data))
def save_image(image: Image.Image, folder: str) -> str:
    """Save *image* as PNG in *folder* under a content-derived name and return the path.

    The name hashes the image mode, size and raw pixel bytes. Identical images map
    to the same file, which is written only once. Images that share pixel bytes but
    differ in shape or mode (e.g. 2x8 vs 4x4 grayscale) no longer collide.
    """
    digest = hashlib.md5()
    digest.update(f"{image.mode}:{image.width}x{image.height}:".encode())
    digest.update(image.tobytes())
    path = Path(folder, f"{digest.hexdigest()}.png")
    path.parent.mkdir(exist_ok=True, parents=True)
    if not path.exists():
        image.save(path)
    return str(path)
def resize_image(image: Image.Image, max_size: int) -> Image.Image:
    """Downscale *image* so its longer side is *max_size*, keeping the aspect ratio.

    Images already within *max_size* on both sides are returned unchanged.
    """
    # Same as modeling.py resize_image
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    if width > height:
        new_width = max_size
        # max(1, ...) stops very thin images from rounding to a zero-sized side,
        # which PIL's resize rejects.
        new_height = max(1, round(height * max_size / width))
    else:
        new_height = max_size
        new_width = max(1, round(width * max_size / height))
    return image.resize((new_width, new_height), Image.LANCZOS)
class EvalModel(BaseModel, arbitrary_types_allowed=True):
    """Base class for generators whose inputs are a list of strings and PIL images."""

    engine: str  # model identifier; its meaning depends on the subclass
    # NOTE(review): timeout and temperature are not read by any subclass in this file.
    timeout: int = 60
    temperature: float = 0.0
    max_output_tokens: int = 512  # subclasses pass this as max_new_tokens

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Generate a single response for *inputs*."""
        raise NotImplementedError

    def run_many(self, inputs: List[Union[str, Image.Image]], num: int) -> List[str]:
        """Generate *num* responses for *inputs*."""
        raise NotImplementedError
class SwiftQwenModel(EvalModel):
    """Qwen2-VL generator driven through ms-swift's ``inference``/``inference_stream``.

    Each image input is resized, saved under ``image_dir`` and referenced by an
    ``<img>path</img>`` tag. The tags are placed before the text inputs, which are
    joined with blank lines.
    """

    # https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Multi-Modal/qwen2-vl-best-practice.md
    path: str = ""
    model: Optional[Qwen2VLForConditionalGeneration] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    engine: str = ModelType.qwen2_vl_7b_instruct
    image_size: int = 768
    image_dir: str = "data/qwen_images"

    def load(self):
        """Load model and tokenizer once, from ``path`` if set, else from ``engine``."""
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = get_model_tokenizer(
                self.engine,
                torch.bfloat16,
                model_kwargs={"device_map": "auto"},
                model_id_or_path=self.path or None,
            )

    def _prepare(self, inputs: List[Union[str, Image.Image]]):
        """Load the model, apply ``max_output_tokens`` and build ``(template, query)``."""
        self.load()
        template_type = get_default_template_type(self.engine)
        # Mutates the shared generation config; later calls overwrite it again.
        self.model.generation_config.max_new_tokens = self.max_output_tokens
        template = get_template(template_type, self.tokenizer)
        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = []
        for x in inputs:
            if isinstance(x, Image.Image):
                path = save_image(resize_image(x, self.image_size), self.image_dir)
                content.append(f"<img>{path}</img>")
        content.append(text)
        return template, "".join(content)

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Generate the full response for *inputs*."""
        template, query = self._prepare(inputs)
        response, _history = inference(self.model, template, query)
        return response

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        """Stream the response for *inputs*, echoing it to stdout.

        NOTE(review): this yields only the newly generated delta at each step,
        whereas QwenModel.run_stream and DummyModel.run_stream yield the cumulative text.
        """
        template, query = self._prepare(inputs)
        generator = inference_stream(self.model, template, query)
        print(f"query: {query}\nresponse: ", end="")
        printed = 0
        for response, _history in generator:
            delta = response[printed:]
            print(delta, end="", flush=True)
            printed = len(response)
            yield delta
class QwenModel(EvalModel):
    """Qwen2-VL generator using Hugging Face transformers (the GPU model in this app).

    Weights come from ``path`` when that directory exists, otherwise from the Hub
    id ``engine``. An optional LoRA adapter is loaded from ``lora_path``.
    """

    path: str = "models/qwen"
    engine: str = "Qwen/Qwen2-VL-7B-Instruct"
    model: Optional[Qwen2VLForConditionalGeneration] = None
    processor: Optional[Qwen2VLProcessor] = None
    tokenizer: Optional[AutoTokenizer] = None
    device: str = "cuda"
    image_size: int = 768
    lora_path: str = ""

    def load(self):
        """Load model, tokenizer and processor on first call and seed torch's RNGs."""
        if self.model is not None:
            return
        path = self.path if os.path.exists(self.path) else self.engine
        print(dict(load_path=path))
        # noinspection PyTypeChecker
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path, torch_dtype="auto", device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.engine)
        if self.lora_path:
            print("Loading LORA from", self.lora_path)
            self.model.load_adapter(self.lora_path)
        self.model = self.model.to(self.device).eval()
        self.processor = Qwen2VLProcessor.from_pretrained(self.engine)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)

    def make_messages(self, inputs: List[Union[str, Image.Image]]) -> List[dict]:
        """Build a single user chat message: base64 data-URI images, then the joined text."""
        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = [
            dict(
                type="image",
                image=f"data:image;base64,{convert_image_to_text(resize_image(x, self.image_size))}",
            )
            for x in inputs
            if isinstance(x, Image.Image)
        ]
        content.append(dict(type="text", text=text))
        return [dict(role="user", content=content)]

    def _prepare_inputs(self, inputs: List[Union[str, Image.Image]]):
        """Load the model and convert *inputs* into processor tensors on ``device``."""
        self.load()
        messages = self.make_messages(inputs)
        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # noinspection PyTypeChecker
        return self.processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Generate and return the full response for *inputs* (prompt tokens stripped)."""
        model_inputs = self._prepare_inputs(inputs)
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **model_inputs, max_new_tokens=self.max_output_tokens
            )
        # Drop the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        """Stream the response for *inputs*, yielding the cumulative text so far.

        Generation runs in a background thread that feeds a TextIteratorStreamer.
        The streamer raises if no new text arrives within its 10 s timeout.
        """
        model_inputs = self._prepare_inputs(inputs)
        streamer = TextIteratorStreamer(
            self.tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generate_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=self.max_output_tokens,
        )
        worker = Thread(target=self.model.generate, kwargs=generate_kwargs)
        worker.start()
        pieces = []
        for piece in streamer:
            pieces.append(piece)
            yield "".join(pieces)
        # The streamer is exhausted once generation finishes; reap the thread.
        worker.join()
class DummyModel(EvalModel):
    """CPU stand-in that echoes its text inputs instead of running a model."""

    engine: str = ""

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        """Return the text inputs joined by spaces. Images are skipped (joining them raised TypeError)."""
        return " ".join(x for x in inputs if isinstance(x, str))

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        """Yield an echo message one word longer each step, pausing 50 ms after each."""
        import time

        text = " ".join(x for x in inputs if isinstance(x, str))
        num_images = sum(1 for x in inputs if isinstance(x, Image.Image))
        tokens = f"Hello this is your message: {text}, images: {num_images}".split()
        for end in range(1, len(tokens) + 1):
            yield " ".join(tokens[:end])
            time.sleep(0.05)
# Build the real GPU pipeline when CUDA is available. Otherwise fall back to
# lightweight stand-ins and add a warning to the page description.
if torch.cuda.is_available():
    model = QwenModel()
    model.load()
    detect_model = YoloDetector()
    detect_model.load()
    retriever = ColpaliRetriever()
    retriever.load()
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    model = DummyModel()
    detect_model = None
    retriever = DummyRetriever()
def get_file_path(file: gr.File = None, url: str = None) -> Optional[str]:
    """Return a local path to the PDF to process. An upload takes precedence over a URL.

    Args:
        file: Value of the ``gr.File`` component. This is either a path
            (str/PathLike, as passed by newer Gradio) or a tempfile-like object
            with a ``.name`` attribute (older Gradio).
        url: Optional URL of a PDF. Empty or whitespace-only strings, which is what
            a blank Textbox sends, mean "no URL".

    Returns:
        The local file path, or None when neither input was given.

    Raises:
        gr.Error: If the URL's response is not a PDF.
        requests.RequestException: If the download fails or returns an HTTP error.
    """
    if file is not None:
        if isinstance(file, (str, os.PathLike)):
            return os.fspath(file)
        # noinspection PyUnresolvedReferences
        return file.name

    url = (url or "").strip()
    if not url:
        return None

    response = requests.get(url, timeout=60)
    response.raise_for_status()
    content_type = response.headers.get("Content-Type", "")
    # Some servers send PDFs as application/octet-stream, so also sniff the magic bytes.
    if "application/pdf" not in content_type and not response.content.startswith(b"%PDF"):
        raise gr.Error(
            f"The URL did not return a PDF (Content-Type: {content_type or 'unknown'})."
        )

    # Use only the URL path's last segment: this drops the query string and any
    # directories, and falls back to a default when the URL ends with "/".
    filename = Path(urlparse(url).path).name or "document.pdf"
    if not filename.lower().endswith(".pdf"):
        filename += ".pdf"
    save_path = Path(tempfile.mkdtemp(), filename)
    save_path.write_bytes(response.content)
    return str(save_path)
# On ZeroGPU, a GPU is attached only while a @spaces.GPU-decorated function runs.
# Outside ZeroGPU the decorator has no effect.
@spaces.GPU
def generate(
    query: str, file: gr.File = None, url: str = None, top_k=5
) -> Iterator[str]:
    """Answer *query*, optionally grounded in a PDF given as an upload or a URL.

    When a PDF is available, the ``top_k`` pages ranked highest by ``retriever``
    provide the context: each page's extracted text plus its detected tables and
    pictures. Yields whatever ``model.run_stream`` yields (cumulative text for
    QwenModel and DummyModel).
    """
    sample = MultimodalSample(question=query, answer="", category="")
    instruction = f"Answer the following question in 200 words or less: {sample.question}"
    context = []

    path = get_file_path(file, url)
    if path is not None:
        doc = MultimodalDocument.load_from_pdf(path, detector=detect_model)
        output = retriever.run(sample.question, doc)
        ranked = sorted(output.pages, key=lambda p: p.score, reverse=True)
        sample.retrieved_pages = sorted(p.number for p in ranked[:top_k])
        selected = set(sample.retrieved_pages)
        for p in doc.pages:
            if p.number not in selected:
                continue
            if p.text:
                context.append(p.text)
            for obj in p.get_tables_and_figures():
                image = obj.get_image()
                if image is not None:  # models ignore non-str/non-image inputs anyway
                    context.append(image)

    inputs = ["Context:", *context, instruction]
    yield from model.run_stream(inputs)
# Page layout: the description, a duplicate button, then the PDF upload on the left
# with the URL/message/submit column beside it, followed by the streamed response
# box and the license footer.
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
        with gr.Column():
            url_input = gr.Textbox(label="Enter PDF URL (optional)")
            text_input = gr.Textbox(label="Enter your message", lines=3)
            submit_button = gr.Button("Submit")

    result = gr.Textbox(label="Response", lines=10)
    # Argument order must match generate(query, file, url).
    generate_inputs = [text_input, pdf_upload, url_input]
    submit_button.click(fn=generate, inputs=generate_inputs, outputs=result)

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()