# internals/pipelines/img_to_text.py
# Source: CM2000112 repository, uploaded by jayparmr ("Upload 118 files"),
# commit 19b3da3 (1.07 kB).
import logging
import re

import torch
from torchvision import transforms
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image
class Image2Text:
    """Caption images with the Salesforce BLIP (large) captioning model.

    Usage: call :meth:`load` once to fetch the processor and fp16 model and
    move the model to CUDA, then call :meth:`process` for each image URL.
    """

    # Strips "</x>"-style tags, newlines and the literal "[SEP]" token from a
    # decoded caption. It is compiled once and written as a raw string
    # because the previous non-raw literal contained the invalid escape
    # "\[", which Python 3.12+ reports as a SyntaxWarning.
    _CLEANUP_RE = re.compile(r"</.>|\n|\[SEP\]")

    def load(self):
        """Load the BLIP captioning processor and fp16 model onto CUDA."""
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
        ).to("cuda")

    def process(self, imageUrl: str) -> str:
        """Return a text caption for the image at ``imageUrl``.

        :meth:`load` must have been called first; it sets ``self.processor``
        and ``self.model``.

        Args:
            imageUrl: Location of the image, passed to ``download_image``.

        Returns:
            The first decoded caption with "</x>" tags, newlines and
            "[SEP]" tokens removed.
        """
        # NOTE(review): the fixed 512x512 resize ignores aspect ratio, and the
        # BLIP processor may resize again itself. Confirm this step is needed.
        image = download_image(imageUrl).resize((512, 512))
        # Move the processor outputs to CUDA in fp16 to match the dtype the
        # model was loaded with in load().
        inputs = self.processor(image, return_tensors="pt").to(
            "cuda", torch.float16
        )
        # Greedy decoding. top_p only affects sampling (do_sample=True), so
        # the former top_p=0.9 was a no-op and is no longer passed.
        output_ids = self.model.generate(
            **inputs, do_sample=False, max_length=128
        )
        captions = self.processor.batch_decode(output_ids)
        # The raw decoder output is diagnostic only, so log it rather than
        # printing it to stdout.
        logging.getLogger(__name__).debug("BLIP raw captions: %s", captions)
        return self._clean_caption(captions[0])

    @classmethod
    def _clean_caption(cls, text: str) -> str:
        """Remove "</x>" tags, newlines and "[SEP]" tokens from ``text``."""
        return cls._CLEANUP_RE.sub("", text)