TommasoBB commited on
Commit
ad49360
·
verified ·
1 Parent(s): c2a3074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -21
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import gradio as gr
3
  from gradio_client import file
4
  import requests
@@ -16,10 +18,59 @@ from langchain_core.messages import HumanMessage
16
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
17
 
18
  # --- Models ---
19
- # Vision model for image analysis / OCR
20
- vision_model = ApiModel(model_id="FireRedTeam/FireRed-OCR", max_new_tokens=2048, temperature=0.3)
21
- math_model = ApiModel(model_id="Qwen/Qwen2.5-Math-1.5B", max_new_tokens=2048, temperature=0.3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  #define the state
24
  class AgentState(TypedDict):
25
  question: str
@@ -129,17 +180,54 @@ Return a JSON object with the following fields:
129
  "transcribed_text": "All text visible in the image transcribed here."
130
  }}"""
131
 
132
- # Multimodal message: the vision model receives both text and image
133
- messages = [
134
- HumanMessage(content=[
135
- {"type": "text", "text": prompt_text},
136
- {"type": "image_url", "image_url": {"url": image_data_uri}}
137
- ])
138
- ]
139
- # Use the dedicated vision model (FireRed-OCR) for image analysis
140
- response = vision_model.invoke(messages)
141
- image_description = response.get("image_description", "")
142
- transcribed_text = response.get("transcribed_text", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  print(f"Image description: {image_description[:100]}...")
144
  print(f"Transcribed text: {transcribed_text[:100]}...")
145
  new_messages = state.get("messages", []) + [
@@ -184,7 +272,7 @@ Return a JSON object with the following field:
184
  }}"""
185
  messages = [HumanMessage(content=prompt)]
186
  response = model.invoke(messages)
187
- extracted_info = response.get("extracted_info", "")
188
  print(f"Extracted file info: {extracted_info[:100]}...")
189
  new_messages = state.get("messages", []) + [
190
  {"role": "system", "content": "Read and extract information from the attached file."},
@@ -202,7 +290,7 @@ def handle_math(state: AgentState) -> str:
202
  print(f"Agent is handling a math problem: {question[:50]}...")
203
  messages = [HumanMessage(content=f"Solve the following math problem step by step:\n\n{question}")]
204
  response = math_model.invoke(messages)
205
- solution = response.get("solution", "")
206
  print(f"Math solution: {solution[:100]}...")
207
  new_messages = state.get("messages", []) + [
208
  {"role": "system", "content": "Handle the question if classified as a math problem."},
@@ -236,10 +324,9 @@ Context gathered:
236
  """
237
  messages = [HumanMessage(content=prompt)]
238
  # Use the general model for final answer synthesis
239
- general_model = ApiModel(model_id="Qwen3.5-35B-A3B", max_new_tokens=2048, temperature=0.3)
240
- response = general_model.invoke(messages)
241
- raw_response = response.content if hasattr(response, 'content') else str(response)
242
-
243
  # Extract the final answer after "FINAL ANSWER:" if present
244
  if "FINAL ANSWER:" in raw_response:
245
  final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
@@ -299,7 +386,6 @@ class BasicAgent:
299
  self.image_reader = tools.ImageReaderTool()
300
  self.web_search = tools.WebSearchTool()
301
  self.tools = [self.file_reader, self.image_reader, self.web_search]
302
- self.vision_model = vision_model # FireRedTeam/FireRed-OCR for image tasks
303
  print("Agent initialized.")
304
 
305
  def __call__(self, question: str, task_id: str = "", file_name: str = "") -> str:
 
1
  import os
2
+ import base64
3
+ from io import BytesIO
4
  import gradio as gr
5
  from gradio_client import file
6
  import requests
 
18
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
 
20
  # --- Models ---
21
+ def _build_hf_model(model_name: str) -> HfApiModel:
22
+ """Build HfApiModel across versions that expect repo_id or model_id."""
23
+ try:
24
+ return HfApiModel(repo_id=model_name, max_new_tokens=2048, temperature=0.3)
25
+ except TypeError:
26
+ return HfApiModel(model_id=model_name, max_new_tokens=2048, temperature=0.3)
27
+
28
+
29
+ # Text/math models via smolagents
30
+ model = _build_hf_model("Qwen3.5-35B-A3B")
31
+ math_model = _build_hf_model("Qwen/Qwen2.5-Math-1.5B")
32
+
33
+ # FireRed OCR (Transformers) loaded lazily to avoid startup crashes
34
+ _fire_red_model = None
35
+ _fire_red_processor = None
36
+
37
+
38
+ def _load_fire_red_ocr():
39
+ """Lazy-load FireRed OCR model and processor using Transformers."""
40
+ global _fire_red_model, _fire_red_processor
41
+ if _fire_red_model is not None and _fire_red_processor is not None:
42
+ return _fire_red_model, _fire_red_processor
43
+
44
+ import torch
45
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
46
 
47
+ _fire_red_model = Qwen3VLForConditionalGeneration.from_pretrained(
48
+ "FireRedTeam/FireRed-OCR",
49
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
50
+ device_map="auto",
51
+ )
52
+ _fire_red_processor = AutoProcessor.from_pretrained("FireRedTeam/FireRed-OCR")
53
+ return _fire_red_model, _fire_red_processor
54
+
55
+
56
+ def _extract_text_from_response(response: Any) -> str:
57
+ """Normalize model responses into plain text."""
58
+ if response is None:
59
+ return ""
60
+ if isinstance(response, str):
61
+ return response
62
+ if isinstance(response, dict):
63
+ for key in ("content", "answer", "output", "text", "solution", "extracted_info"):
64
+ if key in response and response[key] is not None:
65
+ return str(response[key])
66
+ return str(response)
67
+ content = getattr(response, "content", None)
68
+ if content is not None:
69
+ return str(content)
70
+ return str(response)
71
+
72
+
73
+
74
  #define the state
75
  class AgentState(TypedDict):
76
  question: str
 
180
  "transcribed_text": "All text visible in the image transcribed here."
181
  }}"""
182
 
183
+
184
+
185
+ try:
186
+ # Decode base64 data URI into bytes/PIL image
187
+ _, b64_data = image_data_uri.split(",", 1)
188
+ image_bytes = base64.b64decode(b64_data)
189
+ from PIL import Image
190
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
191
+
192
+ ocr_model, ocr_processor = _load_fire_red_ocr()
193
+
194
+ messages = [
195
+ {
196
+ "role": "user",
197
+ "content": [
198
+ {"type": "image", "image": image},
199
+ {"type": "text", "text": prompt_text},
200
+ ],
201
+ }
202
+ ]
203
+
204
+ text = ocr_processor.apply_chat_template(
205
+ messages,
206
+ tokenize=False,
207
+ add_generation_prompt=True,
208
+ )
209
+ inputs = ocr_processor(
210
+ text=[text],
211
+ images=[image],
212
+ return_tensors="pt",
213
+ padding=True,
214
+ )
215
+ inputs = {k: v.to(ocr_model.device) for k, v in inputs.items()}
216
+
217
+ generated_ids = ocr_model.generate(**inputs, max_new_tokens=2048)
218
+ prompt_len = inputs["input_ids"].shape[1]
219
+ generated_trimmed = generated_ids[:, prompt_len:]
220
+ output_text = ocr_processor.batch_decode(
221
+ generated_trimmed,
222
+ skip_special_tokens=True,
223
+ clean_up_tokenization_spaces=False,
224
+ )
225
+ ocr_text = output_text[0].strip() if output_text else ""
226
+ except Exception as e:
227
+ ocr_text = f"OCR error: {e}"
228
+
229
+ image_description = ocr_text
230
+ transcribed_text = ocr_text
231
  print(f"Image description: {image_description[:100]}...")
232
  print(f"Transcribed text: {transcribed_text[:100]}...")
233
  new_messages = state.get("messages", []) + [
 
272
  }}"""
273
  messages = [HumanMessage(content=prompt)]
274
  response = model.invoke(messages)
275
+ extracted_info = _extract_text_from_response(response)
276
  print(f"Extracted file info: {extracted_info[:100]}...")
277
  new_messages = state.get("messages", []) + [
278
  {"role": "system", "content": "Read and extract information from the attached file."},
 
290
  print(f"Agent is handling a math problem: {question[:50]}...")
291
  messages = [HumanMessage(content=f"Solve the following math problem step by step:\n\n{question}")]
292
  response = math_model.invoke(messages)
293
+ solution = _extract_text_from_response(response)
294
  print(f"Math solution: {solution[:100]}...")
295
  new_messages = state.get("messages", []) + [
296
  {"role": "system", "content": "Handle the question if classified as a math problem."},
 
324
  """
325
  messages = [HumanMessage(content=prompt)]
326
  # Use the general model for final answer synthesis
327
+ response = model.invoke(messages)
328
+ raw_response = _extract_text_from_response(response)
329
+
 
330
  # Extract the final answer after "FINAL ANSWER:" if present
331
  if "FINAL ANSWER:" in raw_response:
332
  final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
 
386
  self.image_reader = tools.ImageReaderTool()
387
  self.web_search = tools.WebSearchTool()
388
  self.tools = [self.file_reader, self.image_reader, self.web_search]
 
389
  print("Agent initialized.")
390
 
391
  def __call__(self, question: str, task_id: str = "", file_name: str = "") -> str: