ZhengPeng7 commited on
Commit
1352148
·
1 Parent(s): e000df2

Upgrade the weights loading method to avoid duplicated loading.

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -40,6 +40,7 @@ weights_path = 'BiRefNet'
40
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', weights_path)), trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
 
43
 
44
  usage_to_weights_file = {
45
  'General': 'BiRefNet',
@@ -53,17 +54,16 @@ usage_to_weights_file = {
53
 
54
  @spaces.GPU
55
  def predict(image, resolution, weights_file):
56
- global weights_path
57
  global birefnet
58
- if weights_file != weights_path:
59
  # Load BiRefNet with chosen weights
60
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
61
  print('Change weights to:', _weights_file)
62
- print('\t', weights_file, weights_path)
63
- birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
64
  birefnet.to(device)
65
  birefnet.eval()
66
- weights_path = weights_file
67
 
68
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
69
  # Image is a RGB numpy array.
 
40
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', weights_path)), trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
43
+ birefnet.weights_path = weights_path
44
 
45
  usage_to_weights_file = {
46
  'General': 'BiRefNet',
 
54
 
55
  @spaces.GPU
56
  def predict(image, resolution, weights_file):
 
57
  global birefnet
58
+ if birefnet.weights_path != weights_file:
59
  # Load BiRefNet with chosen weights
60
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
61
  print('Change weights to:', _weights_file)
62
+ print('\t', weights_file, birefnet.weights_path)
63
+ birefnet = birefnet.from_pretrained(_weights_file)
64
  birefnet.to(device)
65
  birefnet.eval()
66
+ birefnet.weights_path = weights_file
67
 
68
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
69
  # Image is a RGB numpy array.