luanpoppe committed · Commit 3b3d8b9 · 1 Parent(s): 7ea334e

feat: adding support for Hugging Face models, while keeping ChatGPT as the default

endpoint_teste/serializer.py CHANGED

@@ -16,4 +16,5 @@ class TesteSerializer(serializers.Serializer):
 class PDFUploadSerializer(serializers.Serializer):
     file = serializers.FileField()
     system_prompt = serializers.CharField(required=True)
-    user_message = serializers.CharField(required=True)
+    user_message = serializers.CharField(required=True)
+    model = serializers.CharField(required=False)
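Because `model` is declared with `required=False` and no `default`, Django REST Framework leaves the key out of `validated_data` entirely when the client omits it, so the `serializer.validated_data['model']` lookup in `getPDF` below raises a `KeyError` for requests that don't send a model. A minimal sketch of the safer lookup (hypothetical view context, assuming a configured Django project):

    # Sketch (inside a DRF view): with required=False and no default, "model"
    # is absent from validated_data when the client omits it, so
    # validated_data["model"] raises KeyError; .get() returns None instead.
    serializer = PDFUploadSerializer(data=request.data)
    serializer.is_valid(raise_exception=True)
    model = serializer.validated_data.get("model")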
endpoint_teste/views.py CHANGED

@@ -9,6 +9,7 @@ from rest_framework.response import Response
 
 from langchain_backend.main import get_llm_answer
 from .serializer import TesteSerializer
+from langchain_huggingface import HuggingFaceEndpoint
 
 class EndpointTesteViewSet(viewsets.ModelViewSet):
     """Mostrará todas as tarefas"""
@@ -31,19 +32,27 @@ def getTeste(request):
         "Resposta": resposta_llm
     })
     if request.method == "GET":
-        hugging_face_token = os.environ.get("hugging_face_token")
-        API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B"
-        headers = {"Authorization": "Bearer " + hugging_face_token}
-        def query(payload):
-            response = requests.post(API_URL, headers=headers, json=payload)
-            return response.json()
+        # hugging_face_token = os.environ.get("hugging_face_token")
+        # API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B"
+        # headers = {"Authorization": "Bearer " + hugging_face_token}
+        # def query(payload):
+        #     response = requests.post(API_URL, headers=headers, json=payload)
+        #     return response.json()
 
-        output = query({
-            "inputs": "Can you please let us know more details about your something I don't know",
-        })
-        print('output: ', output)
-        print('output: ', dir(output))
-        return Response(output)
+        # output = query({
+        #     "inputs": "Can you please let us know more details about your something I don't know",
+        # })
+        # print('output: ', output)
+        # print('output: ', dir(output))
+        llm = HuggingFaceEndpoint(
+            repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+            task="text-generation",
+            max_new_tokens=100,
+            do_sample=False,
+        )
+        result = llm.invoke("Hugging Face is")
+        print('result: ', result)
+        return Response(result)
 
 @api_view(["POST"])
 def getPDF(request):
@@ -72,7 +81,12 @@ def getPDF(request):
             temp_file.write(chunk)
         temp_file_path = temp_file.name # Get the path of the temporary file
     print('temp_file_path: ', temp_file_path)
-    resposta_llm = get_llm_answer(data["system_prompt"], data["user_message"], temp_file_path)
+
+    resposta_llm = None
+    if serializer.validated_data['model']:
+        resposta_llm = get_llm_answer(data["system_prompt"], data["user_message"], temp_file_path, model=serializer.validated_data['model'])
+    else:
+        resposta_llm = get_llm_answer(data["system_prompt"], data["user_message"], temp_file_path)
 
     os.remove(temp_file_path)
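Unlike the commented-out `requests` version, `HuggingFaceEndpoint` takes no headers dict; it reads the token from the `HUGGINGFACEHUB_API_TOKEN` environment variable or from an explicit parameter. A minimal sketch of the same call with the credential spelled out (assumes `langchain-huggingface` is installed and the token has access to the gated Llama 3 repo; the env var name reuses the one from the commented-out code):

    import os
    from langchain_huggingface import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
        task="text-generation",
        max_new_tokens=100,
        do_sample=False,
        # Passing the token explicitly instead of relying on the environment.
        huggingfacehub_api_token=os.environ.get("hugging_face_token"),
    )
    print(llm.invoke("Hugging Face is"))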
 
langchain_backend/main.py CHANGED

@@ -4,13 +4,17 @@ from langchain.chains import create_retrieval_chain
 
 os.environ.get("OPENAI_API_KEY")
 
-def get_llm_answer(system_prompt, user_prompt, pdf_url):
+def get_llm_answer(system_prompt, user_prompt, pdf_url, model):
     pages = None
     if pdf_url:
         pages = getPDF(pdf_url)
     else:
         pages = getPDF()
     retriever = create_retriever(pages)
-    rag_chain = create_retrieval_chain(retriever, create_prompt_llm_chain(system_prompt))
+    rag_chain = None
+    if model:
+        rag_chain = create_retrieval_chain(retriever, create_prompt_llm_chain(system_prompt, model))
+    else:
+        rag_chain = create_retrieval_chain(retriever, create_prompt_llm_chain(system_prompt))
     results = rag_chain.invoke({"input": user_prompt})
     return results
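Note that `model` is now a required positional parameter of `get_llm_answer`, so any caller still passing three arguments raises a `TypeError`. A backward-compatible sketch with a keyword default, relying on the module's existing helpers (not part of the commit):

    # Sketch: a default of None keeps three-argument callers working; the
    # if/else mirrors the commit's dispatch.
    def get_llm_answer(system_prompt, user_prompt, pdf_url, model=None):
        pages = getPDF(pdf_url) if pdf_url else getPDF()
        retriever = create_retriever(pages)
        if model:
            chain = create_prompt_llm_chain(system_prompt, model)
        else:
            chain = create_prompt_llm_chain(system_prompt)
        rag_chain = create_retrieval_chain(retriever, chain)
        return rag_chain.invoke({"input": user_prompt})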
langchain_backend/utils.py CHANGED

@@ -28,8 +28,8 @@ def create_retriever(documents):
 
     return retriever
 
-def create_prompt_llm_chain(system_prompt):
-    model = ChatOpenAI(model="gpt-4o-mini")
+def create_prompt_llm_chain(system_prompt, model="gpt-4o-mini"):
+    model = ChatOpenAI(model=model)
 
     system_prompt = system_prompt + "\n\n" + "{context}"
     prompt = ChatPromptTemplate.from_messages(
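With the keyword default, `gpt-4o-mini` stays the fallback, which also makes the caller-side if/else in `main.py` optional. Note, though, that the override is still instantiated through `ChatOpenAI`, so as written only OpenAI model names will work here; routing a Hugging Face repo id through this function would need a different chat class, such as the `HuggingFaceEndpoint` imported in `views.py`. Usage sketch (prompts and model names illustrative):

    # Default: ChatGPT remains the standard model.
    chain_default = create_prompt_llm_chain("Answer using the given context.")
    # Explicit override; must be an OpenAI model name as written.
    chain_custom = create_prompt_llm_chain("Answer using the given context.",
                                           model="gpt-4o")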
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ