import base64
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipTokenizer

class InstructBlipHandler:
    """Inference handler wrapping an InstructBLIP model and tokenizer.

    Expects request payloads of the form
    ``{"image": <base64-encoded bytes>, "text": <prompt string>}`` and
    returns the decoded generations as a list of strings.
    """

    def __init__(self, model, tokenizer):
        # model: any object exposing ``generate(**inputs)``.
        # tokenizer: any object exposing ``__call__(text, return_tensors=...)``
        # and ``batch_decode(outputs, skip_special_tokens=...)``.
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, input_data):
        """Run the full preprocess -> generate -> postprocess pipeline.

        input_data: dict with "image" (base64 string) and "text" (prompt).
        Returns the list of decoded output strings.
        """
        inputs = self.preprocess(input_data)
        outputs = self.model.generate(**inputs)
        return self.postprocess(outputs)

    def preprocess(self, input_data):
        """Build the model's keyword inputs from a raw request payload.

        Raises KeyError if "image" or "text" is missing, and
        binascii.Error if the image payload is not valid base64.
        """
        image_data = input_data["image"]
        text_prompt = input_data["text"]

        # BUG FIX: the original did ``torch.tensor(base64.b64decode(...))``,
        # which raises TypeError because torch.tensor cannot consume a bytes
        # object. Decode into a mutable buffer and view it as a flat uint8
        # tensor, keeping the leading batch dimension from ``unsqueeze(0)``.
        raw = bytearray(base64.b64decode(image_data))
        image = torch.frombuffer(raw, dtype=torch.uint8).unsqueeze(0)
        # NOTE(review): this assumes "image" already carries raw pixel data.
        # If it is an encoded file (PNG/JPEG), ``pixel_values`` should instead
        # come from an image processor (e.g. InstructBlipProcessor) — confirm
        # against the caller.

        text_inputs = self.tokenizer(text_prompt, return_tensors="pt")
        return {
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
            "pixel_values": image,
        }

    def postprocess(self, outputs):
        """Decode generated token ids into plain strings."""
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Load the model at import time and build the module-level serving handler.
# NOTE(review): transformers does not define `InstructBlipTokenizer`; the
# companion preprocessing class for InstructBLIP is `InstructBlipProcessor`
# (or `AutoProcessor`), so the import at the top of this file and the lookup
# below are expected to fail. Fixing this also requires adapting
# `InstructBlipHandler.preprocess`, because a processor's `__call__` takes
# `images=`/`text=` keyword arguments — the two changes must land together.
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl")
tokenizer = InstructBlipTokenizer.from_pretrained("Salesforce/instructblip-flan-t5-xl")
handler = InstructBlipHandler(model, tokenizer)