ZhengPeng7 commited on
Commit
b0bc43c
·
1 Parent(s): d967d62

Upgrade the weights loading method to avoid duplicated loading.

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -56,14 +56,16 @@ birefnet.weights_path = weights_path
56
  def predict(image, resolution, weights_file):
57
  global birefnet
58
  if birefnet.weights_path != weights_file:
 
 
59
  # Load BiRefNet with chosen weights
60
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
61
  print('Change weights to:', _weights_file)
62
- print('\t', weights_file, birefnet.weights_path)
63
  birefnet = birefnet.from_pretrained(_weights_file)
64
  birefnet.to(device)
65
  birefnet.eval()
66
  birefnet.weights_path = weights_file
 
67
 
68
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
69
  # Image is a RGB numpy array.
 
56
  def predict(image, resolution, weights_file):
57
  global birefnet
58
  if birefnet.weights_path != weights_file:
59
+ print('*' * 10)
60
+ print('\t1: ', weights_file, birefnet.weights_path)
61
  # Load BiRefNet with chosen weights
62
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
63
  print('Change weights to:', _weights_file)
 
64
  birefnet = birefnet.from_pretrained(_weights_file)
65
  birefnet.to(device)
66
  birefnet.eval()
67
  birefnet.weights_path = weights_file
68
+ print('\t2: ', weights_file, birefnet.weights_path)
69
 
70
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
71
  # Image is a RGB numpy array.