azeus committed
Commit 94c8c98 · 1 Parent(s): a92e324

fix loading issue

Files changed (1): app.py (+26 -18)
app.py CHANGED

@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import gc
 
@@ -24,9 +24,25 @@ class SpacesHaikuGenerator:
         "Reflective": "Write a contemplative haiku about {name}, focusing on {traits}"
     }
 
-    @st.cache_resource
     def load_model(self, model_name):
-        """Load model with caching for Streamlit."""
+        """Load model with proper caching."""
+
+        @st.cache_resource
+        def _load_model_cached(_model_name):
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.models[_model_name],
+                trust_remote_code=True
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                self.models[_model_name],
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+            if torch.cuda.is_available():
+                model = model.to("cuda")
+            return model, tokenizer
+
         if self.current_model != model_name:
             # Clear previous model
             if self.loaded_model is not None:
@@ -35,22 +51,10 @@ class SpacesHaikuGenerator:
                 torch.cuda.empty_cache()
                 gc.collect()
 
-            # Load new model
-            self.loaded_tokenizer = AutoTokenizer.from_pretrained(
-                self.models[model_name],
-                trust_remote_code=True
-            )
-            self.loaded_model = AutoModelForCausalLM.from_pretrained(
-                self.models[model_name],
-                trust_remote_code=True,
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True
-            )
+            # Load new model using cached function
+            self.loaded_model, self.loaded_tokenizer = _load_model_cached(model_name)
             self.current_model = model_name
 
-            if torch.cuda.is_available():
-                self.loaded_model = self.loaded_model.to("cuda")
-
     def generate_haiku(self, name, traits, model_name, style):
         """Generate a free-form haiku using the selected model."""
         self.load_model(model_name)
@@ -103,7 +107,11 @@ def main():
     st.write("Create unique AI-generated haikus about characters")
 
     # Initialize generator
-    generator = SpacesHaikuGenerator()
+    @st.cache_resource
+    def get_generator():
+        return SpacesHaikuGenerator()
+
+    generator = get_generator()
 
     # Input fields
     col1, col2 = st.columns([1, 2])
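The core of the fix: st.cache_resource keys its cache on the decorated function's code plus its hashable arguments, and returns the same loaded object on every rerun instead of reloading the checkpoint. Parameters whose names begin with an underscore, like _model_name above, are excluded from the cache key; that prefix is Streamlit's escape hatch for arguments it cannot hash. A minimal standalone sketch of the cached-loader pattern, assuming a module-level function and an illustrative repo id ("gpt2" is a placeholder, not necessarily one of this Space's models):

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@st.cache_resource
def load_model_cached(model_id: str):
    """Load a model/tokenizer pair once per model_id; later reruns reuse it."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # same dtype choice as the diff above
        low_cpu_mem_usage=True,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")  # sketch assumes a single-GPU box, as the diff does
    return model, tokenizer

# Placeholder usage; the Space maps a display name to a repo id instead.
model, tokenizer = load_model_cached("gpt2")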
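The change in main() applies the same decorator to a zero-argument factory, turning the generator into a per-process singleton: the object is built on the first run, and every later rerun and session receives the same instance, not a copy, so mutations persist. A small sketch of that pattern, with an illustrative class standing in for SpacesHaikuGenerator:

import streamlit as st

class ExpensiveService:
    """Stand-in for any object that is slow to construct."""
    def __init__(self):
        self.runs = 0  # imagine model registries or warm caches being set up here

@st.cache_resource
def get_service() -> ExpensiveService:
    # Executed once per server process; subsequent calls return the cached instance.
    return ExpensiveService()

service = get_service()
service.runs += 1  # the increment survives reruns because the instance is shared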