azeus committed
Commit 94c8c98 · Parent(s): a92e324

fix loading issue
app.py
CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import gc
 
@@ -24,9 +24,25 @@ class SpacesHaikuGenerator:
             "Reflective": "Write a contemplative haiku about {name}, focusing on {traits}"
         }
 
-    @st.cache_resource
     def load_model(self, model_name):
-        """Load model with caching"""
+        """Load model with proper caching."""
+
+        @st.cache_resource
+        def _load_model_cached(_model_name):
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.models[_model_name],
+                trust_remote_code=True
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                self.models[_model_name],
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+            if torch.cuda.is_available():
+                model = model.to("cuda")
+            return model, tokenizer
+
         if self.current_model != model_name:
             # Clear previous model
             if self.loaded_model is not None:
@@ -35,22 +51,10 @@ class SpacesHaikuGenerator:
                 torch.cuda.empty_cache()
                 gc.collect()
 
-            # Load new model
-            self.loaded_tokenizer = AutoTokenizer.from_pretrained(
-                self.models[model_name],
-                trust_remote_code=True
-            )
-            self.loaded_model = AutoModelForCausalLM.from_pretrained(
-                self.models[model_name],
-                trust_remote_code=True,
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True
-            )
+            # Load new model using cached function
+            self.loaded_model, self.loaded_tokenizer = _load_model_cached(model_name)
             self.current_model = model_name
 
-            if torch.cuda.is_available():
-                self.loaded_model = self.loaded_model.to("cuda")
-
     def generate_haiku(self, name, traits, model_name, style):
         """Generate a free-form haiku using the selected model."""
         self.load_model(model_name)
@@ -103,7 +107,11 @@ def main():
     st.write("Create unique AI-generated haikus about characters")
 
     # Initialize generator
-    generator = SpacesHaikuGenerator()
+    @st.cache_resource
+    def get_generator():
+        return SpacesHaikuGenerator()
+
+    generator = get_generator()
 
     # Input fields
     col1, col2 = st.columns([1, 2])
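
For reference, the caching idiom this commit moves toward is roughly the following. This is a minimal standalone sketch, not code from the repo: the function name load_model_cached and the "gpt2" model id are illustrative, and the loader is written at module level with a hashable string argument so that st.cache_resource keeps one cached model/tokenizer pair per model id across Streamlit reruns.

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@st.cache_resource
def load_model_cached(model_id: str):
    # Runs once per model_id; later Streamlit reruns reuse the cached pair
    # instead of reloading the weights from disk.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, tokenizer

model, tokenizer = load_model_cached("gpt2")  # illustrative model id

In the committed code the same decorator is also applied to get_generator() in main(), so the SpacesHaikuGenerator instance (and whatever model it currently holds) survives Streamlit's script reruns rather than being rebuilt on every interaction.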