DESUCLUB commited on
Commit
dba14a0
1 Parent(s): ffa13c1

Upload BLIPIntepret.py

Browse files
Files changed (1) hide show
  1. BLIPIntepret.py +10 -8
BLIPIntepret.py CHANGED
@@ -7,14 +7,14 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
7
  print(device)
8
 
9
  def init_BLIP(device):
 
10
  if device == 'cuda':
11
- bit_load = True
 
12
  else:
13
- bit_load = False
14
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
15
- model = Blip2ForConditionalGeneration.from_pretrained(
16
- "Salesforce/blip2-opt-2.7b", load_in_8bit= bit_load,torch_dtype=torch.float16, device_map = 'auto'
17
- )
18
  model.eval()
19
  if torch.__version__ >= "2":
20
  model = torch.compile(model)
@@ -33,8 +33,10 @@ def infer_BLIP2(model,processor,image,device):
33
  "Question: What emotion does the person or animal in the image feel? Answer:",
34
  ]
35
  for prompt in prompts:
36
- inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
37
-
 
 
38
  generated_ids = model.generate(**inputs)
39
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
40
  outputs+= prompt+generated_text+' '
 
7
  print(device)
8
 
9
  def init_BLIP(device):
10
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
11
  if device == 'cuda':
12
+ model = Blip2ForConditionalGeneration.from_pretrained(
13
+ "Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16, device_map = 'auto')
14
  else:
15
+ print('Using CPU model')
16
+ model = Blip2ForConditionalGeneration.from_pretrained( "Salesforce/blip2-opt-2.7b",device_map={"": device}, torch_dtype=torch.float32,low_cpu_mem_usage=True)
17
+
 
 
18
  model.eval()
19
  if torch.__version__ >= "2":
20
  model = torch.compile(model)
 
33
  "Question: What emotion does the person or animal in the image feel? Answer:",
34
  ]
35
  for prompt in prompts:
36
+ if device == 'cuda':
37
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
38
+ else:
39
+ inputs = processor(images=image, text=prompt, return_tensors="pt")
40
  generated_ids = model.generate(**inputs)
41
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
42
  outputs+= prompt+generated_text+' '