sunbal7 committed on
Commit
bf18560
·
verified ·
1 Parent(s): 7c8d482

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -7,21 +7,37 @@ from sentence_transformers import SentenceTransformer
7
  import fitz # PyMuPDF for better PDF extraction
8
  from langchain_text_splitters import RecursiveCharacterTextSplitter
9
 
 
10
  # Configuration
11
  MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
12
- EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
13
- CHUNK_SIZE = 512
14
- CHUNK_OVERLAP = 64
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Initialize session state
18
- if "docs" not in st.session_state:
19
- st.session_state.docs = []
20
- if "index" not in st.session_state:
21
- st.session_state.index = None
22
-
23
- # Model loading with better error handling
24
  @st.cache_resource
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def load_models():
26
  try:
27
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
7
  import fitz # PyMuPDF for better PDF extraction
8
  from langchain_text_splitters import RecursiveCharacterTextSplitter
9
 
10
+
11
  # Configuration
12
  MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
 
 
 
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
 
 
 
 
 
 
 
15
  @st.cache_resource
16
+ def load_model():
17
+ try:
18
+ # Load with explicit configuration
19
+ tokenizer = AutoTokenizer.from_pretrained(
20
+ MODEL_NAME,
21
+ trust_remote_code=True,
22
+ revision="main"
23
+ )
24
+
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ MODEL_NAME,
27
+ device_map="auto" if DEVICE == "cuda" else None,
28
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
29
+ trust_remote_code=True,
30
+ revision="main",
31
+ low_cpu_mem_usage=True
32
+ )
33
+ return model, tokenizer
34
+ except Exception as e:
35
+ st.error(f"Model loading failed: {str(e)}")
36
+ st.stop()
37
+
38
+ model, tokenizer = load_model()
39
+
40
+
41
  def load_models():
42
  try:
43
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)