venkat-srinivasan-nexusflow commited on
Commit
7d6e8c5
·
1 Parent(s): c88dbd1

Upload langdemo.py

Browse files

Migrate the previous V1 langchain demo to V2.

Files changed (1) hide show
  1. langdemo.py +147 -0
langdemo.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Literal, Union
2
+
3
+ import math
4
+
5
+ from langchain.tools.base import StructuredTool
6
+ from langchain.agents import (
7
+ Tool,
8
+ AgentExecutor,
9
+ LLMSingleActionAgent,
10
+ AgentOutputParser,
11
+ )
12
+ from langchain.schema import AgentAction, AgentFinish, OutputParserException
13
+ from langchain.prompts import StringPromptTemplate
14
+ from langchain.llms import HuggingFaceTextGenInference
15
+ from langchain.chains import LLMChain
16
+
17
+
18
+ ##########################################################
19
+ # Step 1: Define the functions you want to articulate. ###
20
+ ##########################################################
21
+
22
+
23
+ def calculator(
24
+ input_a: float,
25
+ input_b: float,
26
+ operation: Literal["add", "subtract", "multiply", "divide"],
27
+ ):
28
+ """
29
+ Computes a calculation.
30
+
31
+ Args:
32
+ input_a (float) : Required. The first input.
33
+ input_b (float) : Required. The second input.
34
+ operation (string): The operation. Choices include: add to add two numbers, subtract to subtract two numbers, multiply to multiply two numbers, and divide to divide them.
35
+ """
36
+ match operation:
37
+ case "add":
38
+ return input_a + input_b
39
+ case "subtract":
40
+ return input_a - input_b
41
+ case "multiply":
42
+ return input_a * input_b
43
+ case "divide":
44
+ return input_a / input_b
45
+
46
+
47
+ def cylinder_volume(radius, height):
48
+ """
49
+ Calculate the volume of a cylinder.
50
+
51
+ Parameters:
52
+ - radius (float): The radius of the base of the cylinder.
53
+ - height (float): The height of the cylinder.
54
+
55
+ Returns:
56
+ - float: The volume of the cylinder.
57
+ """
58
+ if radius < 0 or height < 0:
59
+ raise ValueError("Radius and height must be non-negative.")
60
+
61
+ volume = math.pi * (radius**2) * height
62
+ return volume
63
+
64
+
65
+ #############################################################
66
+ # Step 2: Let's define some utils for building the prompt ###
67
+ #############################################################
68
+
69
+
70
+ RAVEN_PROMPT = """
71
+ {raven_tools}
72
+ User Query: Question: {input}
73
+
74
+ Please pick a function from the above options that best answers the user query and fill in the appropriate arguments.<human_end>"""
75
+
76
+
77
+ # Set up a prompt template
78
+ class RavenPromptTemplate(StringPromptTemplate):
79
+ # The template to use
80
+ template: str
81
+ # The list of tools available
82
+ tools: List[Tool]
83
+
84
+ def format(self, **kwargs) -> str:
85
+ prompt = "<human>:\n"
86
+ for tool in self.tools:
87
+ func_signature, func_docstring = tool.description.split(" - ", 1)
88
+ prompt += f'\nOPTION:\n<func_start>def {func_signature}<func_end>\n<docstring_start>\n"""\n{func_docstring}\n"""\n<docstring_end>\n'
89
+ kwargs["raven_tools"] = prompt
90
+ return self.template.format(**kwargs).replace("{{", "{").replace("}}", "}")
91
+
92
+
93
+ class RavenOutputParser(AgentOutputParser):
94
+ def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
95
+ # Check if agent should finish
96
+ if "Call:" in llm_output:
97
+ return AgentFinish(
98
+ return_values={
99
+ "output": llm_output.strip()
100
+ .replace("Call:", "")
101
+ .strip()
102
+ },
103
+ log=llm_output,
104
+ )
105
+ else:
106
+ raise OutputParserException(f"Could not parse LLM output: `{llm_output}`")
107
+
108
+
109
+ ##################################################
110
+ # Step 3: Build the agent with these utilities ###
111
+ ##################################################
112
+
113
+
114
+ inference_server_url = "https://rjmy54al17scvxjr.us-east-1.aws.endpoints.huggingface.cloud"
115
+ assert (
116
+ inference_server_url is not "<YOUR ENDPOINT URL>"
117
+ ), "Please provide your own HF inference endpoint URL!"
118
+
119
+ llm = HuggingFaceTextGenInference(
120
+ inference_server_url=inference_server_url,
121
+ temperature=0.001,
122
+ max_new_tokens=400,
123
+ do_sample=False,
124
+ )
125
+ tools = [
126
+ StructuredTool.from_function(calculator),
127
+ StructuredTool.from_function(cylinder_volume),
128
+ ]
129
+ raven_prompt = RavenPromptTemplate(
130
+ template=RAVEN_PROMPT, tools=tools, input_variables=["input"]
131
+ )
132
+ llm_chain = LLMChain(llm=llm, prompt=raven_prompt)
133
+ output_parser = RavenOutputParser()
134
+ agent = LLMSingleActionAgent(
135
+ llm_chain=llm_chain,
136
+ output_parser=output_parser,
137
+ stop=["<bot_end>"],
138
+ allowed_tools=tools,
139
+ )
140
+ agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
141
+
142
+ call = agent_chain.run(
143
+ "I have a cake that is about 3 centimenters high and 200 centimeters in radius. How much cake do I have?"
144
+ )
145
+ print(eval(call))
146
+ call = agent_chain.run("What is 1+10?")
147
+ print(eval(call))