"""Custom Inference Endpoints handler for a multimodal (image + text) causal LM."""
import base64
import gc
import io
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

class EndpointHandler:
    def __init__(self, path: str = ""):
        # trust_remote_code is required: the checkpoint ships its own
        # processor and model code.
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Release memory held over from the previous request before
        # loading a new image.
        torch.cuda.empty_cache()
        gc.collect()

        payload = data.get("inputs", {})
        image_url = payload.get("image_url")
        image_data = payload.get("image")
        text_prompt = payload.get("text_prompt", "Describe this image.")

        # Accept either a URL or a base64-encoded image; the URL takes priority.
        if image_url:
            try:
                image = Image.open(requests.get(image_url, stream=True, timeout=30).raw)
            except Exception as e:
                return [{"error": f"Failed to load image from URL: {str(e)}"}]
        elif image_data:
            try:
                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
            except Exception as e:
                return [{"error": f"Failed to decode image data: {str(e)}"}]
        else:
            return [{"error": "No image_url or image data provided in inputs"}]

        # The processor expects RGB; normalize other modes (RGBA, L, P, ...).
        if image.mode != "RGB":
            image = image.convert("RGB")

        try:
            with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                model_inputs = self.processor.process(
                    images=[image],
                    text=text_prompt
                )
                # Move tensors to the model's device and add a batch dimension.
                model_inputs = {
                    k: v.to(self.model.device).unsqueeze(0)
                    for k, v in model_inputs.items()
                }
                output = self.model.generate_from_batch(
                    model_inputs,
                    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
                    tokenizer=self.processor.tokenizer
                )
            # Keep only the newly generated tokens, skipping the prompt.
            generated_tokens = output[0, model_inputs["input_ids"].size(1):]
            generated_text = self.processor.tokenizer.decode(
                generated_tokens, skip_special_tokens=True
            )
            torch.cuda.empty_cache()
            gc.collect()
            return [{"generated_text": generated_text}]
        except Exception as e:
            return [{"error": f"Error during generation: {str(e)}"}]