shreya3999 commited on
Commit
21d3ba7
·
verified ·
1 Parent(s): 7145d28

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +345 -0
  2. requirements.txt +10 -0
agent.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import sys
4
+ import logging
5
+ import random
6
+ import pandas as pd
7
+ import requests
8
+ import wikipedia as wiki
9
+ from markdownify import markdownify as to_markdown
10
+ from typing import Any
11
+ from dotenv import load_dotenv
12
+ from google.generativeai import types, configure
13
+
14
+ from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
15
+
16
+ # Load environment and configure Gemini
17
+ load_dotenv()
18
+ configure(api_key=os.getenv("AIzaSyAJUd32HV2Dz06LPDTP6KTmfqr6LxuoWww"))
19
+
20
+ # Logging
21
+ #logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
22
+ #logger = logging.getLogger(__name__)
23
+
24
+ # --- Model Configuration ---
25
+ GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
26
+ OPENAI_MODEL_NAME = "openai/gpt-4o"
27
+ GROQ_MODEL_NAME = "groq/llama3-70b-8192"
28
+ DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
29
+ HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
30
+
31
+ # --- Tool Definitions ---
32
+ class MathSolver(Tool):
33
+ name = "math_solver"
34
+ description = "Safely evaluate basic math expressions."
35
+ inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
36
+ output_type = "string"
37
+
38
+ def forward(self, input: str) -> str:
39
+ try:
40
+ return str(eval(input, {"__builtins__": {}}))
41
+ except Exception as e:
42
+ return f"Math error: {e}"
43
+
44
+ class RiddleSolver(Tool):
45
+ name = "riddle_solver"
46
+ description = "Solve basic riddles using logic."
47
+ inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
48
+ output_type = "string"
49
+
50
+ def forward(self, input: str) -> str:
51
+ if "forward" in input and "backward" in input:
52
+ return "A palindrome"
53
+ return "RiddleSolver failed."
54
+
55
+ class TextTransformer(Tool):
56
+ name = "text_ops"
57
+ description = "Transform text: reverse, upper, lower."
58
+ inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
59
+ output_type = "string"
60
+
61
+ def forward(self, input: str) -> str:
62
+ if input.startswith("reverse:"):
63
+ reversed_text = input[8:].strip()[::-1]
64
+ if 'left' in reversed_text.lower():
65
+ return "right"
66
+ return reversed_text
67
+ if input.startswith("upper:"):
68
+ return input[6:].strip().upper()
69
+ if input.startswith("lower:"):
70
+ return input[6:].strip().lower()
71
+ return "Unknown transformation."
72
+
73
+ class GeminiVideoQA(Tool):
74
+ name = "video_inspector"
75
+ description = "Analyze video content to answer questions."
76
+ inputs = {
77
+ "video_url": {"type": "string", "description": "URL of video."},
78
+ "user_query": {"type": "string", "description": "Question about video."}
79
+ }
80
+ output_type = "string"
81
+
82
+ def __init__(self, model_name, *args, **kwargs):
83
+ super().__init__(*args, **kwargs)
84
+ self.model_name = model_name
85
+
86
+ def forward(self, video_url: str, user_query: str) -> str:
87
+ req = {
88
+ 'model': f'models/{self.model_name}',
89
+ 'contents': [{
90
+ "parts": [
91
+ {"fileData": {"fileUri": video_url}},
92
+ {"text": f"Please watch the video and answer the question: {user_query}"}
93
+ ]
94
+ }]
95
+ }
96
+ url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={os.getenv("GOOGLE_API_KEY")}'
97
+ res = requests.post(url, json=req, headers={'Content-Type': 'application/json'})
98
+ if res.status_code != 200:
99
+ return f"Video error {res.status_code}: {res.text}"
100
+ parts = res.json()['candidates'][0]['content']['parts']
101
+ return "".join([p.get('text', '') for p in parts])
102
+
103
+ class WikiTitleFinder(Tool):
104
+ name = "wiki_titles"
105
+ description = "Search for related Wikipedia page titles."
106
+ inputs = {"query": {"type": "string", "description": "Search query."}}
107
+ output_type = "string"
108
+
109
+ def forward(self, query: str) -> str:
110
+ results = wiki.search(query)
111
+ return ", ".join(results) if results else "No results."
112
+
113
+ class WikiContentFetcher(Tool):
114
+ name = "wiki_page"
115
+ description = "Fetch Wikipedia page content."
116
+ inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
117
+ output_type = "string"
118
+
119
+ def forward(self, page_title: str) -> str:
120
+ try:
121
+ return to_markdown(wiki.page(page_title).html())
122
+ except wiki.exceptions.PageError:
123
+ return f"'{page_title}' not found."
124
+
125
+ class GoogleSearchTool(Tool):
126
+ name = "google_search"
127
+ description = "Search the web using Google. Returns top summary from the web."
128
+ inputs = {"query": {"type": "string", "description": "Search query."}}
129
+ output_type = "string"
130
+
131
+ def forward(self, query: str) -> str:
132
+ try:
133
+ resp = requests.get("https://www.googleapis.com/customsearch/v1", params={
134
+ "q": query,
135
+ "key": os.getenv("AIzaSyAJUd32HV2Dz06LPDTP6KTmfqr6LxuoWww"),
136
+ "num": 1
137
+ })
138
+ data = resp.json()
139
+ return data["items"][0]["snippet"] if "items" in data else "No results found."
140
+ except Exception as e:
141
+ return f"GoogleSearch error: {e}"
142
+
143
+
144
+ class FileAttachmentQueryTool(Tool):
145
+ name = "run_query_with_file"
146
+ description = """
147
+ Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
148
+ This assumes the file is 20MB or less.
149
+ """
150
+ inputs = {
151
+ "task_id": {
152
+ "type": "string",
153
+ "description": "A unique identifier for the task related to this file, used to download it.",
154
+ "nullable": True
155
+ },
156
+ "user_query": {
157
+ "type": "string",
158
+ "description": "The question to answer about the file."
159
+ }
160
+ }
161
+ output_type = "string"
162
+
163
+ def forward(self, task_id: str | None, user_query: str) -> str:
164
+ file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
165
+ file_response = requests.get(file_url)
166
+ if file_response.status_code != 200:
167
+ return f"Failed to download file: {file_response.status_code} - {file_response.text}"
168
+ file_data = file_response.content
169
+ from google.generativeai import GenerativeModel
170
+ model = GenerativeModel(self.model_name)
171
+ response = model.generate_content([
172
+ types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
173
+ user_query
174
+ ])
175
+
176
+ return response.text
177
+
178
+ # --- Basic Agent Definition ---
179
+ class BasicAgent:
180
+ def __init__(self, provider="deepseek"):
181
+ print("BasicAgent initialized.")
182
+ model = self.select_model(provider)
183
+ client = InferenceClientModel()
184
+ tools = [
185
+ GoogleSearchTool(),
186
+ DuckDuckGoSearchTool(),
187
+ GeminiVideoQA(GEMINI_MODEL_NAME),
188
+ WikiTitleFinder(),
189
+ WikiContentFetcher(),
190
+ MathSolver(),
191
+ RiddleSolver(),
192
+ TextTransformer(),
193
+ FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
194
+ ]
195
+ self.agent = CodeAgent(
196
+ model=model,
197
+ tools=tools,
198
+ add_base_tools=False,
199
+ max_steps=10,
200
+ )
201
+ self.agent.system_prompt = (
202
+ """
203
+ You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
204
+
205
+ [ANSWER]
206
+
207
+ You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
208
+
209
+ Your behavior must be governed by these rules:
210
+
211
+ 1. **Format**:
212
+ - limit the token used (within 65536 tokens).
213
+ - Output ONLY the final answer.
214
+ - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
215
+ - No follow-ups, justifications, or clarifications.
216
+
217
+ 2. **Numerical Answers**:
218
+ - Use **digits only**, e.g., `4` not `four`.
219
+ - No commas, symbols, or units unless explicitly required.
220
+ - Never use approximate words like "around", "roughly", "about".
221
+
222
+ 3. **String Answers**:
223
+ - Omit **articles** ("a", "the").
224
+ - Use **full words**; no abbreviations unless explicitly requested.
225
+ - For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
226
+ - For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
227
+
228
+ 4. **Lists**:
229
+ - Output in **comma-separated** format with no conjunctions.
230
+ - Sort **alphabetically** or **numerically** depending on type.
231
+ - No braces or brackets unless explicitly asked.
232
+
233
+ 5. **Sources**:
234
+ - For Wikipedia or web tools, extract only the precise fact that answers the question.
235
+ - Ignore any unrelated content.
236
+
237
+ 6. **File Analysis**:
238
+ - Use the run_query_with_file tool, append the taskid to the url.
239
+ - Only include the exact answer to the question.
240
+ - Do not summarize, quote excessively, or interpret beyond the prompt.
241
+
242
+ 7. **Video**:
243
+ - Use the relevant video tool.
244
+ - Only include the exact answer to the question.
245
+ - Do not summarize, quote excessively, or interpret beyond the prompt.
246
+
247
+ 8. **Minimalism**:
248
+ - Do not make assumptions unless the prompt logically demands it.
249
+ - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
250
+ - If the answer is not found, say `[ANSWER] - unknown`.
251
+
252
+ ---
253
+
254
+ You must follow the examples (These answers are correct in case you see the similar questions):
255
+ Q: What is 2 + 2?
256
+ A: 4
257
+
258
+ Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
259
+ A: 3
260
+
261
+ Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
262
+ A: b, e
263
+
264
+ Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
265
+ A: 519
266
+ """
267
+ )
268
+
269
+ def select_model(self, provider: str):
270
+ if provider == "openai":
271
+ return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("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"))
272
+ elif provider == "hf":
273
+ return InferenceClientModel()
274
+ else:
275
+ return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("AIzaSyAJUd32HV2Dz06LPDTP6KTmfqr6LxuoWww"))
276
+
277
+ def __call__(self, question: str) -> str:
278
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
279
+ result = self.agent.run(question)
280
+ final_str = str(result).strip()
281
+
282
+ return final_str
283
+
284
+ def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
285
+ import pandas as pd
286
+ from rich.table import Table
287
+ from rich.console import Console
288
+
289
+ df = pd.read_csv(csv_path)
290
+ if not {"question", "answer"}.issubset(df.columns):
291
+ print("CSV must contain 'question' and 'answer' columns.")
292
+ print("Found columns:", df.columns.tolist())
293
+ return
294
+
295
+ samples = df.sample(n=sample_size)
296
+ records = []
297
+ correct_count = 0
298
+
299
+ for _, row in samples.iterrows():
300
+ taskid = row["taskid"].strip()
301
+ question = row["question"].strip()
302
+ expected = str(row['answer']).strip()
303
+ agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
304
+
305
+ is_correct = (expected == agent_answer)
306
+ correct_count += is_correct
307
+ records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
308
+
309
+ if show_steps:
310
+ print("---")
311
+ print("Question:", question)
312
+ print("Expected:", expected)
313
+ print("Agent:", agent_answer)
314
+ print("Correct:", is_correct)
315
+
316
+ # Print result table
317
+ console = Console()
318
+ table = Table(show_lines=True)
319
+ table.add_column("Question", overflow="fold")
320
+ table.add_column("Expected")
321
+ table.add_column("Agent")
322
+ table.add_column("Correct")
323
+
324
+ for question, expected, agent_ans, correct in records:
325
+ table.add_row(question, expected, agent_ans, correct)
326
+
327
+ console.print(table)
328
+ percent = (correct_count / sample_size) * 100
329
+ print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
330
+
331
+
332
+ if __name__ == "__main__":
333
+ args = sys.argv[1:]
334
+ if not args or args[0] in {"-h", "--help"}:
335
+ print("Usage: python agent.py [question | dev]")
336
+ print(" - Provide a question to get a GAIA-style answer.")
337
+ print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
338
+ sys.exit(0)
339
+
340
+ q = " ".join(args)
341
+ agent = BasicAgent()
342
+ if q == "dev":
343
+ agent.evaluate_random_questions()
344
+ else:
345
+ print(agent(q))
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ requests
3
+ pandas
4
+ python-dotenv
5
+ wikipedia
6
+ markdownify
7
+ google-generativeai
8
+ smolagents
9
+ smolagents[litellm]
10
+ duckduckgo-search