paloma99 committed
Commit df7209b · verified · 1 Parent(s): bb0d917

Update app.py

Files changed (1): app.py +108 -7
app.py CHANGED
@@ -2,10 +2,43 @@ import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 import torch
 import theme
-import chatbot
 
 theme = theme.Theme()
+
+import os
+import sys
+sys.path.append('../..')
+
+# langchain
+from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.prompts import (
+    PromptTemplate,
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+)
+from langchain.chains import RetrievalQA, LLMChain, ConversationalRetrievalChain
+from langchain.schema import StrOutputParser
+from langchain.schema.runnable import Runnable
+from langchain.schema.runnable.config import RunnableConfig
+from langchain.vectorstores import Chroma
+from langchain.memory import ConversationBufferMemory
+from langchain.document_loaders import PyPDFDirectoryLoader
+
+from langchain_community.llms import HuggingFaceHub
+
+from pydantic import BaseModel
+import shutil
+
 # Cell 1: Image Classification Model
 image_pipeline = pipeline(task="image-classification", model="guillen/vit-basura-test1")
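This first hunk only pulls in the LangChain RAG stack; the second hunk below wires it together. As a quick sanity check that the embedding model resolves, a minimal sketch (not part of the commit; assumes langchain and sentence-transformers are installed):

from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-small')
vector = embeddings.embed_query("How do I recycle glass bottles?")
print(len(vector))  # gte-small embeddings should be 384-dimensional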
 
 
@@ -23,14 +56,82 @@ image_gradio_app = gr.Interface(
 
 # Cell 2: Chatbot Model
 
-def qa_response(user_message, chat_history, context):
-    response = qa_chain.predict(user_message, chat_history, context=context)
-    return response
+# Load the reference PDFs and split them into chunks
+loader = PyPDFDirectoryLoader('pdfs')
+data = loader.load()
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=500,
+    chunk_overlap=70,
+    length_function=len
+)
+docs = text_splitter.split_documents(data)
+
+# Define the embedding model
+embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-small')
+
+# Create the vector database, removing old database files if any
+persist_directory = 'docs/chroma/'
+shutil.rmtree(persist_directory, ignore_errors=True)
+vectordb = Chroma.from_documents(
+    documents=docs,
+    embedding=embeddings,
+    persist_directory=persist_directory
+)
+
+# Define the retriever
+retriever = vectordb.as_retriever(search_type="mmr")
+
+template = """
+Your name is AngryGreta and you are a recycling chatbot with the objective to answer questions from the user in English or Spanish /
+Use the following pieces of context to answer the question if the question is related to recycling /
+No more than two chunks of context /
+Answer in the same language as the question /
+Always say "thanks for asking!" at the end of the answer /
+If the context is not relevant, please answer the question by using your own knowledge about the topic.
+
+context: {context}
+question: {question}
+"""
+
+# Create the chat prompt templates
+system_prompt = SystemMessagePromptTemplate.from_template(template)
+qa_prompt = ChatPromptTemplate(
+    messages=[
+        system_prompt,
+        MessagesPlaceholder(variable_name="chat_history"),
+        HumanMessagePromptTemplate.from_template("{question}")
+    ]
+)
+
+llm = HuggingFaceHub(
+    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+    task="text-generation",
+    model_kwargs={
+        "max_new_tokens": 1024,
+        "top_k": 30,
+        "temperature": 0.1,
+        "repetition_penalty": 1.03,
+    },
+)
+
+memory = ConversationBufferMemory(
+    llm=llm,
+    memory_key="chat_history",
+    input_key='question',
+    output_key='answer',
+    return_messages=True
+)
+
+qa_chain = ConversationalRetrievalChain.from_llm(
+    llm=llm,
+    memory=memory,
+    retriever=retriever,
+    verbose=True,
+    combine_docs_chain_kwargs={'prompt': qa_prompt},
+    get_chat_history=lambda h: h,
+    rephrase_question=False,
+    output_key='answer'
+)
+
+def chat_interface(question, history):
+    # Memory tracks chat_history, so only the question is passed in
+    result = qa_chain.invoke({"question": question})
+    return result['answer']  # the chain returns a dict keyed by output_key
 
 chatbot_gradio_app = gr.ChatInterface(
-    fn=qa_response,
-    title="Green Greta",
-    theme=theme
+    fn=chat_interface,
+    title='Green Greta'
 )
 
  # Combine both interfaces into a single app
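The diff ends at the comment that combines both interfaces, so the combined launcher itself is not shown in this hunk. A minimal sketch of one plausible continuation (gr.TabbedInterface and the tab names are assumptions, not taken from the commit):

# Hypothetical continuation; not part of this commit
app = gr.TabbedInterface(
    [image_gradio_app, chatbot_gradio_app],
    tab_names=["Image Classification", "Green Greta"],
    theme=theme
)
app.launch()

Under this sketch, a question typed into the chat tab flows through chat_interface, which calls qa_chain.invoke({"question": question}) and returns result['answer'].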