import inspect
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from openai import OpenAI
| """ |
| EXAMPLE OUTPUT: |
| |
| **************************************** |
| RUNNING QUERY: What's the weather for Paris, TX in fahrenheit? |
| |
| Agent Issued Step 1 |
| ---------------------------------------- |
| |
| Agent Issued Step 2 |
| ---------------------------------------- |
| |
| Agent Issued Step 3 |
| ---------------------------------------- |
| AGENT MESSAGE: The current weather in Paris, TX is 85 degrees fahrenheit. It is partly cloudy, with highs in the 90s. |
| Conversation Complete |
| |
| |
| **************************************** |
| RUNNING QUERY: Who won the most recent PGA? |
| |
| Agent Issued Step 1 |
| ---------------------------------------- |
| |
| Agent Issued Step 2 |
| ---------------------------------------- |
| AGENT MESSAGE: I'm sorry, but I don't have the ability to provide sports information. I can help you with weather and location data. Is there anything else I can assist you with? |
| Conversation Complete |
| """ |
|
|
@dataclass
class WeatherConfig:
    """Configuration for the OpenAI client and the agent loop.

    ``api_key`` and ``api_base`` are handed to the OpenAI client constructor,
    ``model`` selects the chat model (when left as ``None`` the agent picks
    one from the endpoint's model list), and ``max_steps`` caps the number of
    model calls made for a single query.

    Raises:
        ValueError: if ``max_steps`` is less than 1.
    """

    # API key handed to the OpenAI client.
    api_key: str = ""
    # Base URL of the OpenAI-compatible endpoint.
    api_base: str = ""
    # Chat model id; None means "choose one when the agent starts".
    model: Optional[str] = None
    # Upper bound on model calls in one conversation.
    max_steps: int = 5

    def __post_init__(self) -> None:
        # With a non-positive budget the agent loop would never call the model.
        if self.max_steps < 1:
            raise ValueError(f"max_steps must be at least 1, got {self.max_steps!r}")
|
|
class WeatherTools:
    """Collection of available tools/functions for the weather agent.

    Each tool is a static method whose keyword arguments match the JSON
    schema advertised for it in ``ToolRegistry.tool_schemas``.
    """

    # Known cities keyed by (casefolded city name, upper-case state code).
    # Longitudes west of Greenwich are negative (all of these are in the US).
    _COORDINATES = {
        ("dallas", "TX"): (32.7767, -96.7970),
        ("san francisco", "CA"): (37.7749, -122.4194),
        ("paris", "TX"): (33.6609, -95.5555),
    }

    @staticmethod
    def get_current_weather(latitude: List[float], longitude: List[float], unit: str) -> str:
        """Return a canned weather report expressed in ``unit``.

        This is a stub: ``latitude`` and ``longitude`` are accepted so the
        signature matches the tool schema, but they are not used.
        """
        return f"The weather is 85 degrees {unit}. It is partly cloudy, with highs in the 90's."

    @staticmethod
    def get_geo_coordinates(city: str, state: str) -> str:
        """Describe the latitude/longitude of ``city`` in ``state``.

        Matching ignores surrounding whitespace and letter case. Cities not in
        the table get an explicit "not known" message instead of made-up
        coordinates, so the model does not fetch weather for (0, 0).
        """
        key = (city.strip().casefold(), state.strip().upper())
        coords = WeatherTools._COORDINATES.get(key)
        if coords is None:
            return f"No coordinates are known for {city}, {state}."
        lat, lon = coords
        return f"The coordinates for {city}, {state} are: latitude {lat}, longitude {lon}"

    @staticmethod
    def no_relevant_function(user_query_span: str) -> str:
        """Signal that no tool can answer ``user_query_span``.

        The span is accepted to match the tool schema; the reply is fixed.
        """
        return "No relevant function for your request was found. We will stop here."

    @staticmethod
    def chat(chat_string: str) -> None:
        """Show ``chat_string`` to the user by printing it to stdout."""
        # A single f-string avoids the extra space print() puts between arguments.
        print(f"AGENT MESSAGE: {chat_string}")
|
|
class ToolRegistry:
    """Registry of the agent's tools: their Python callables and JSON schemas.

    Both properties build a fresh object on every access, so a caller may
    mutate the result without affecting later lookups.
    """

    @property
    def available_functions(self) -> Dict[str, Callable[..., Any]]:
        """Map each tool name the model may call to its implementation."""
        return {
            "get_current_weather": WeatherTools.get_current_weather,
            "get_geo_coordinates": WeatherTools.get_geo_coordinates,
            "no_relevant_function": WeatherTools.no_relevant_function,
            "chat": WeatherTools.chat,
        }

    @property
    def tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the ``tools`` payload sent with each chat-completions request.

        Names and parameter names must stay in sync with
        ``available_functions`` and the ``WeatherTools`` signatures.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location. Use exact coordinates.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # "array" schemas need an "items" schema; OpenAI's
                            # tool-schema validation rejects arrays without one.
                            "latitude": {
                                "type": "array",
                                "items": {"type": "number"},
                                "description": "The latitude for the city.",
                            },
                            "longitude": {
                                "type": "array",
                                "items": {"type": "number"},
                                "description": "The longitude for the city.",
                            },
                            "unit": {
                                "type": "string",
                                "description": "The unit to fetch the temperature in",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["latitude", "longitude", "unit"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_geo_coordinates",
                    "description": "Get the latitude and longitude for a given city",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {"type": "string", "description": "The city to find coordinates for"},
                            "state": {"type": "string", "description": "The two-letter state abbreviation"},
                        },
                        "required": ["city", "state"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "no_relevant_function",
                    "description": "Call this when no other provided function can be called to answer the user query.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "user_query_span": {
                                "type": "string",
                                "description": "The part of the user_query that cannot be answered by any other function calls.",
                            },
                        },
                        "required": ["user_query_span"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "chat",
                    "description": "Call this tool when you want to chat with the user. The user won't see anything except for whatever you pass into this function.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "chat_string": {
                                "type": "string",
                                "description": "The string to send to the user to chat back to them.",
                            },
                        },
                        "required": ["chat_string"],
                    },
                },
            },
        ]
|
|
class WeatherAgent:
    """Main agent class that runs the model/tool conversation loop."""

    def __init__(self, config: WeatherConfig):
        """Set up the OpenAI client and the tool registry.

        Args:
            config: Connection and loop settings. If ``config.model`` is empty
                it is overwritten (the caller's object is mutated) with the id
                of the first model the endpoint lists.

        Raises:
            RuntimeError: if no model is configured and the endpoint lists none.
        """
        self.config = config
        # An empty api_base is not a usable URL; passing None lets the client
        # apply its own default endpoint (OPENAI_BASE_URL or the public API,
        # per the openai-python docs). api_key is passed through unchanged to
        # preserve existing behaviour.
        self.client = OpenAI(api_key=config.api_key, base_url=config.api_base or None)
        self.tools = ToolRegistry()
        self.messages: List[Dict[str, Any]] = []

        if not config.model:
            models = self.client.models.list()
            if not models.data:
                raise RuntimeError("No model configured and the endpoint lists no models.")
            self.config.model = models.data[0].id

    def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
        """Convert an SDK tool-call object into the plain dict stored in history."""
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    def _invoke_tool(self, function_name: str, raw_arguments: Optional[str]) -> Any:
        """Decode ``raw_arguments`` (a JSON object string) and call the named tool.

        Returns the tool's result, or an ``"Error: ..."`` string explaining why
        the call could not be made, so the model can see and correct it.
        """
        function = self.tools.available_functions.get(function_name)
        if function is None:
            return f"Error: unknown tool '{function_name}'."
        try:
            # Treat missing/empty arguments as an empty JSON object.
            arguments = json.loads(raw_arguments or "{}")
        except json.JSONDecodeError as err:
            return f"Error: arguments for '{function_name}' are not valid JSON ({err})."
        if not isinstance(arguments, dict):
            return f"Error: arguments for '{function_name}' must be a JSON object."
        try:
            # Validate against the signature first, so a TypeError raised inside
            # the tool itself is not mistaken for bad arguments.
            inspect.signature(function).bind(**arguments)
        except TypeError as err:
            return f"Error: invalid arguments for '{function_name}': {err}"
        return function(**arguments)

    def process_tool_calls(self, message) -> None:
        """Execute every tool call on ``message`` and append one ``tool`` reply each.

        ``message`` is the assistant message returned by the SDK. Each entry of
        ``message.tool_calls`` receives exactly one reply keyed by its id, even
        when the call fails, so no tool call is left unanswered.
        """
        for tool_call in message.tool_calls:
            function_name = tool_call.function.name
            function_response = self._invoke_tool(function_name, tool_call.function.arguments)
            self.messages.append({
                "role": "tool",
                "content": json.dumps(function_response),
                "tool_call_id": tool_call.id,
                "name": function_name,
            })

    def run_conversation(self, initial_query: str) -> None:
        """Answer ``initial_query`` by alternating model calls and tool execution.

        Each step sends the full history and tool schemas to the model. A reply
        with tool calls is recorded, its tools are run, and the loop continues;
        a reply without tool calls ends the conversation. At most
        ``config.max_steps`` model calls are made.
        """
        self.messages = [
            {"role": "system", "content": "Make sure to use the chat() function to provide the final answer to the user."},
            {"role": "user", "content": initial_query},
        ]

        print("\n" * 5)
        print("*" * 40)
        print(f"RUNNING QUERY: {initial_query}")

        for step in range(self.config.max_steps):
            response = self.client.chat.completions.create(
                messages=self.messages,
                model=self.config.model,
                tools=self.tools.tool_schemas,
                temperature=0.0,
            )
            message = response.choices[0].message

            if not message.tool_calls:
                print("Conversation Complete")
                break

            print(f"\nAgent Issued Step {step + 1}")
            print("-" * 40)

            self.messages.append({
                "role": "assistant",
                # Keep the raw content (text or None); json.dumps would turn None
                # into the literal text "null" and quote real text.
                "content": message.content,
                "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls],
            })
            self.process_tool_calls(message)
        else:
            # Runs only when the budget is exhausted without a break, i.e. not
            # when the conversation completed on the final step.
            print("Maximum steps reached")
|
|
def main():
    """Demo entry point: run two sample queries through a single agent.

    Uses a default ``WeatherConfig()``; the same agent instance handles both
    queries, one after the other.
    """
    demo_queries = (
        "What's the weather for Paris, TX in fahrenheit?",
        "Who won the most recent PGA?",
    )
    agent = WeatherAgent(WeatherConfig())
    for query in demo_queries:
        agent.run_conversation(query)


if __name__ == "__main__":
    main()