ZhengPeng7 commited on
Commit
8bcc400
·
1 Parent(s): 3218aa5

Update the model switch part.

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -36,8 +36,8 @@ class ImagePreprocessor():
36
 
37
 
38
  from transformers import AutoModelForImageSegmentation
39
- model_path = 'zhengpeng7/BiRefNet'
40
- birefnet = AutoModelForImageSegmentation.from_pretrained(model_path, trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
43
 
@@ -46,10 +46,13 @@ birefnet.eval()
46
  # images = [image_1, image_2]
47
  @spaces.GPU
48
  def predict(image, resolution, weights_file):
49
- # Load BiRefNet with chosen weights
50
- birefnet = AutoModelForImageSegmentation.from_pretrained(weights_file, trust_remote_code=True)
51
- birefnet.to(device)
52
- birefnet.eval()
 
 
 
53
 
54
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
55
  # Image is a RGB numpy array.
 
36
 
37
 
38
  from transformers import AutoModelForImageSegmentation
39
+ weights_path = 'zhengpeng7/BiRefNet'
40
+ birefnet = AutoModelForImageSegmentation.from_pretrained(weights_path, trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
43
 
 
46
  # images = [image_1, image_2]
47
  @spaces.GPU
48
  def predict(image, resolution, weights_file):
49
+ global weights_path
50
+ if weights_file != weights_path:
51
+ # Load BiRefNet with chosen weights
52
+ birefnet = AutoModelForImageSegmentation.from_pretrained(weights_file if weights_file is not None else 'zhengpeng7/BiRefNet', trust_remote_code=True)
53
+ birefnet.to(device)
54
+ birefnet.eval()
55
+ # weights_path = weights_file
56
 
57
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
58
  # Image is a RGB numpy array.