KvrParaskevi committed on
Commit
c1e8c7e
·
verified ·
1 Parent(s): ec4d267

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +6 -2
chatbot.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  from langchain.memory import ConversationBufferMemory
3
  from langchain.chains import ConversationChain
4
  import langchain.globals
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import streamlit as st
7
 
8
  my_model_id = os.getenv('MODEL_REPO_ID', 'Default Value')
@@ -10,8 +10,12 @@ token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
10
 
11
  @st.cache_resource
12
  def load_model():
 
 
 
 
13
  tokenizer = AutoTokenizer.from_pretrained(my_model_id)
14
- model = AutoModelForCausalLM.from_pretrained(my_model_id)
15
 
16
  return tokenizer,model
17
 
 
2
  from langchain.memory import ConversationBufferMemory
3
  from langchain.chains import ConversationChain
4
  import langchain.globals
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
  import streamlit as st
7
 
8
  my_model_id = os.getenv('MODEL_REPO_ID', 'Default Value')
 
10
 
11
  @st.cache_resource
12
  def load_model():
13
+ quantization_config = BitsAndBytesConfig(
14
+ load_in_8bit=True,
15
+ # bnb_4bit_compute_dtype=torch.bfloat16
16
+ )
17
  tokenizer = AutoTokenizer.from_pretrained(my_model_id)
18
+ model = AutoModelForCausalLM.from_pretrained(my_model_id, device_map="auto",quantization_config=quantization_config)
19
 
20
  return tokenizer,model
21