Drag2121 committed on
Commit fe312ba · 1 Parent(s): f98b266
Files changed (1)
  1. app.py +7 -15
app.py CHANGED
@@ -7,6 +7,8 @@ from langchain_community.llms import Ollama
 from langchain_core.messages import HumanMessage
 import logging
 from functools import lru_cache
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -16,7 +18,8 @@ MODEL_NAME = 'gemma2:2b'
 
 @lru_cache()
 def get_llm():
-    return Ollama(model=MODEL_NAME)
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+    return Ollama(model=MODEL_NAME, callback_manager=callback_manager)
 
 class Question(BaseModel):
     text: str
@@ -45,20 +48,9 @@ async def ask_question_stream(question: Question):
 
        async def generate():
            full_response = ""
-           buffer = ""
            async for chunk in llm.astream(question.text):
-               buffer += chunk
-               words = re.findall(r'\S+|\s+', buffer)
-
-               for word in words[:-1]:
-                   full_response += word
-                   yield word
-
-               buffer = words[-1] if words else ""
-
-           if buffer:
-               full_response += buffer
-               yield buffer
+               full_response += chunk
+               yield chunk
 
            # Log the full response after streaming is complete
            logger.info(f"Full streamed response: {full_response}")
@@ -67,7 +59,7 @@ async def ask_question_stream(question: Question):
    except Exception as e:
        logger.error(f"Error in /ask_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
-
+
 @app.on_event("startup")
 async def startup_event():
    logger.info(f"Starting up with model: {MODEL_NAME}")