Commit · 0a4529c
Parent(s): first commit

Note: this view is limited to 50 files because the commit contains too many changes.
- .env.example +81 -0
- .gitignore +38 -0
- Dockerfile +36 -0
- README.md +726 -0
- app.py +2224 -0
- chunking/__init__.py +0 -0
- chunking/adaptive_selector.py +465 -0
- chunking/base_chunker.py +405 -0
- chunking/fixed_chunker.py +376 -0
- chunking/hierarchical_chunker.py +280 -0
- chunking/llamaindex_chunker.py +240 -0
- chunking/overlap_manager.py +421 -0
- chunking/semantic_chunker.py +607 -0
- chunking/token_counter.py +520 -0
- config/__init__.py +0 -0
- config/logging_config.py +301 -0
- config/models.py +697 -0
- config/settings.py +247 -0
- docs/API.md +904 -0
- docs/ARCHITECTURE.md +882 -0
- document_parser/__init__.py +0 -0
- document_parser/docx_parser.py +425 -0
- document_parser/ocr_engine.py +781 -0
- document_parser/parser_factory.py +534 -0
- document_parser/pdf_parser.py +908 -0
- document_parser/txt_parser.py +336 -0
- document_parser/zip_handler.py +716 -0
- embeddings/__init__.py +0 -0
- embeddings/batch_processor.py +365 -0
- embeddings/bge_embedder.py +351 -0
- embeddings/embedding_cache.py +331 -0
- embeddings/model_loader.py +327 -0
- evaluation/ragas_evaluator.py +667 -0
- frontend/index.html +0 -0
- generation/__init__.py +0 -0
- generation/citation_formatter.py +378 -0
- generation/general_responder.py +155 -0
- generation/llm_client.py +396 -0
- generation/prompt_builder.py +542 -0
- generation/query_classifier.py +247 -0
- generation/response_generator.py +880 -0
- generation/temperature_controller.py +430 -0
- generation/token_manager.py +431 -0
- ingestion/__init__.py +0 -0
- ingestion/async_coordinator.py +414 -0
- ingestion/progress_tracker.py +518 -0
- ingestion/router.py +516 -0
- requirements.txt +135 -0
- retrieval/__init__.py +0 -0
- retrieval/citation_tracker.py +452 -0
.env.example
ADDED
# --- HF Spaces Deployment Settings ---

# --- Application Settings ---
APP_NAME=AI Universal Knowledge Ingestion System
APP_VERSION=1.0.0
PORT=7860
HOST=0.0.0.0
DEBUG=false
MAX_FILE_SIZE_MB=100
MAX_BATCH_FILES=5

# LLM Provider Selection
OLLAMA_ENABLED=false
USE_OPENAI=true

# OpenAI API Key (set this in HF Space Secrets tab)
OPENAI_API_KEY=sk-your-actual-key-here
OPENAI_MODEL=gpt-4o-mini

# --- Generation Parameters ---
DEFAULT_TEMPERATURE=0.1
TOP_P=0.9
MAX_TOKENS=1000
CONTEXT_WINDOW=8192

# --- Embedding Settings ---
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
EMBEDDING_DIMENSION=384
EMBEDDING_DEVICE=cpu
EMBEDDING_BATCH_SIZE=16

# --- Chunking Settings ---
FIXED_CHUNK_SIZE=384
FIXED_CHUNK_OVERLAP=20
FIXED_CHUNK_STRATEGY=fixed

# --- Retrieval Settings ---
TOP_K_RETRIEVE=10
TOP_K_FINAL=5
FAISS_NPROBE=16
VECTOR_WEIGHT=0.6
BM25_WEIGHT=0.4
BM25_K1=1.5
BM25_B=0.75
ENABLE_RERANKING=true
RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2

# --- Storage Settings ---
VECTOR_STORE_DIR=./data/vector_store
METADATA_DB_PATH=./data/metadata.db
AUTO_BACKUP=false
BACKUP_DIR=./data/backups

# --- Cache Settings ---
ENABLE_CACHE=true
CACHE_TYPE=memory
CACHE_TTL=3600
CACHE_MAX_SIZE=500

# --- Logging Settings ---
LOG_LEVEL=INFO
LOG_DIR=./logs

# --- Performance Settings ---
MAX_WORKERS=2
ASYNC_BATCH_SIZE=5

# --- RAGAS Evaluation ---
ENABLE_RAGAS=true
RAGAS_ENABLE_GROUND_TRUTH=false
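For reference, a minimal sketch of how variables like these could be loaded into typed settings. This is an illustration only: the repository's actual `config/settings.py` may be structured differently, and the use of `pydantic-settings` here is an assumption.

```python
# Illustrative sketch (assumption): load .env values into typed settings with pydantic-settings.
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    APP_NAME: str = "AI Universal Knowledge Ingestion System"
    PORT: int = 7860
    OLLAMA_ENABLED: bool = False
    USE_OPENAI: bool = True
    OPENAI_MODEL: str = "gpt-4o-mini"
    DEFAULT_TEMPERATURE: float = 0.1
    TOP_K_RETRIEVE: int = 10
    VECTOR_WEIGHT: float = 0.6
    BM25_WEIGHT: float = 0.4

@lru_cache()
def get_settings() -> Settings:
    # Cached so every module shares a single Settings instance
    return Settings()
```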
.gitignore
ADDED
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
env/
ENV/

# Data
data/
*.db
*.faiss
*.pkl

# DB
sqlite/

# Logs
logs/
*.log

# Environment
.env

# IDE
.vscode/
.idea/
*.swp

# OS
.DS_Store
Thumbs.db

# Test
notebooks/.ipynb_checkpoints/
__pycache__/
Dockerfile
ADDED
FROM python:3.10-slim

WORKDIR /app

# Install minimal system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    libmagic1 \
    file \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create necessary directories
RUN mkdir -p data/uploads data/vector_store data/backups logs

# Set environment variables
ENV HOST=0.0.0.0
ENV PORT=7860
ENV OLLAMA_ENABLED=false
ENV PYTHONUNBUFFERED=1

# Expose port for HF Spaces
EXPOSE 7860

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:7860/api/health', timeout=5)"

# Start application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "120"]
README.md
ADDED
---
title: QuerySphere
emoji: 🧠
colorFrom: indigo
colorTo: purple
sdk: docker
pinned: false
license: mit
app_port: 7860
short_description: RAG platform for document Q&A with zero API costs (local LLM) or cloud deployment with OpenAI integration.
---

<div align="center">

# QuerySphere: RAG platform for document Q&A with Knowledge Ingestion

[](https://www.python.org/downloads/)
[](https://fastapi.tiangolo.com/)
[](https://opensource.org/licenses/MIT)

> **Enterprise-Grade RAG Platform with Multi-Format Document Ingestion, Hybrid Retrieval, and Zero API Costs**

</div>

---

An MVP-grade Retrieval-Augmented Generation (RAG) system that enables organizations to unlock knowledge trapped across documents and archives while maintaining complete data privacy and eliminating costly API dependencies.

---

## 📋 Table of Contents

- [Overview](#-overview)
- [Key Features](#-key-features)
- [System Architecture](#-system-architecture)
- [Technology Stack](#-technology-stack)
- [Installation](#-installation)
- [Quick Start](#-quick-start)
- [Core Components](#-core-components)
- [API Documentation](#-api-documentation)
- [Configuration](#-configuration)
- [RAGAS Evaluation](#-ragas-evaluation)
- [Troubleshooting](#-troubleshooting)
- [License](#-license)

---

## 🎯 Overview

The AI Universal Knowledge Ingestion System addresses a critical enterprise pain point: **information silos that cost organizations 20% of employee productivity**. Unlike existing solutions (Humata AI, ChatPDF, NotebookLM) that charge $49/user/month and rely on expensive cloud LLM APIs, this system offers:

### **Core Value Propositions**

| Feature | Traditional Solutions | Our System |
|---------|----------------------|------------|
| **Privacy** | Cloud-based (data leaves premises) | 100% on-premise processing |
| **Cost** | $49-99/user/month + API fees | Zero API costs (local inference) |
| **Input Types** | PDF only | PDF, DOCX, TXT, ZIP archives |
| **Quality Metrics** | Black box (no visibility) | RAGAS evaluation with detailed metrics |
| **Retrieval** | Vector-only | Hybrid (Vector + BM25 + Reranking) |
| **Chunking** | Fixed size | Adaptive (3 strategies) |

### **Market Context**

- **$8.5B** projected enterprise AI search market by 2027
- **85%** of enterprises actively adopting AI-powered knowledge management
- **Growing regulatory demands** for on-premise, privacy-compliant solutions

---

## ✨ Key Features

### **1. Multi-Format Document Ingestion**
- **Supported Formats**: PDF, DOCX, TXT
- **Archive Processing**: ZIP files up to 2GB with recursive extraction
- **Batch Upload**: Process multiple documents simultaneously
- **OCR Support**: Extract text from scanned documents and images (PaddleOCR or EasyOCR)

### **2. Intelligent Document Processing**
- **Adaptive Chunking**: Automatically selects the optimal strategy based on document size
  - Fixed-size chunks (< 50K tokens): 512 tokens with 50 overlap
  - Semantic chunks (50K-500K tokens): Section-aware splitting
  - Hierarchical chunks (> 500K tokens): Parent-child structure
- **Metadata Extraction**: Title, author, date, page numbers, section headers

### **3. Hybrid Retrieval System**
```mermaid
graph LR
    A[User Query] --> B[Query Embedding]
    A --> C[Keyword Analysis]
    B --> D[Vector Search<br/>FAISS]
    C --> E[BM25 Search]
    D --> F[Reciprocal Rank Fusion<br/>60% Vector + 40% BM25]
    E --> F
    F --> G[Cross-Encoder Reranking]
    G --> H[Top-K Results]
```

- **Vector Search**: FAISS with BGE embeddings (384-dim)
- **Keyword Search**: BM25 with optimized parameters (k1=1.5, b=0.75)
- **Fusion Methods**: Weighted, Reciprocal Rank Fusion (RRF), CombSum
- **Reranking**: Cross-encoder for precision boost

### **4. Local LLM Generation**
- **Ollama Integration**: Zero-cost inference with Mistral-7B or LLaMA-2
- **Adaptive Temperature**: Context-aware generation parameters
- **Citation Tracking**: Automatic source attribution with validation
- **Streaming Support**: Token-by-token response generation

### **5. RAGAS Quality Assurance**
- **Real-Time Evaluation**: Answer relevancy, faithfulness, context precision/recall
- **Automatic Metrics**: Computed for every query-response pair
- **Analytics Dashboard**: Track quality trends over time
- **Export Capability**: Download evaluation data for analysis
- **Session Statistics**: Aggregate metrics across conversation sessions

---

## 🗂️ System Architecture

### **High-Level Architecture**

```mermaid
graph TB
    subgraph "Frontend Layer"
        A[Web UI<br/>HTML/CSS/JS]
    end

    subgraph "API Layer"
        B[FastAPI Gateway<br/>REST Endpoints]
    end

    subgraph "Ingestion Pipeline"
        C[Document Parser<br/>PDF/DOCX/TXT]
        D[Adaptive Chunker<br/>3 Strategies]
        E[Embedding Generator<br/>BGE-small-en-v1.5]
    end

    subgraph "Storage Layer"
        F[FAISS Vector DB<br/>~10M vectors]
        G[BM25 Keyword Index]
        H[SQLite Metadata<br/>Documents & Chunks]
        I[LRU Cache<br/>Embeddings]
    end

    subgraph "Retrieval Engine"
        J[Hybrid Retriever<br/>Vector + BM25]
        K[Cross-Encoder<br/>Reranker]
        L[Context Assembler]
    end

    subgraph "Generation Engine"
        M[Ollama LLM<br/>Mistral-7B]
        N[Prompt Builder]
        O[Citation Formatter]
    end

    subgraph "Evaluation Engine"
        P[RAGAS Evaluator<br/>Quality Metrics]
    end

    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
    E --> G
    D --> H
    E --> I

    B --> J
    J --> F
    J --> G
    J --> K
    K --> L

    L --> N
    N --> M
    M --> O
    O --> A

    M --> P
    P --> H
    P --> A
```

### **Why This Architecture?**

#### **Modular Design**
Each component is independent and replaceable:
- **Parser**: Swap PDF libraries without affecting chunking
- **Embedder**: Change from BGE to OpenAI embeddings with a config update
- **LLM**: Switch from Ollama to the OpenAI API seamlessly

#### **Separation of Concerns**
```
Ingestion → Storage → Retrieval → Generation → Evaluation
```
Each stage has clear inputs/outputs and a single responsibility.

#### **Performance Optimization**
- **Async Processing**: Non-blocking I/O for uploads and LLM calls
- **Batch Operations**: Embed 32 chunks simultaneously
- **Local Caching**: LRU cache for query embeddings and frequent retrievals
- **Indexing**: FAISS ANN for O(log n) search vs O(n) brute force (see the sketch below)

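To make the indexing point concrete, here is a minimal FAISS sketch over 384-dimensional embeddings. It uses a flat inner-product index for brevity and random vectors as stand-ins; the project's actual `vector_store/index_builder.py` may use an ANN index type (IVF/HNSW) and real chunk embeddings.

```python
# Minimal FAISS sketch (illustration only): index 384-dim vectors and search.
import faiss
import numpy as np

dim = 384                                  # matches EMBEDDING_DIMENSION
index = faiss.IndexFlatIP(dim)             # exact inner-product search; ANN indexes scale sub-linearly

chunk_vectors = np.random.rand(1000, dim).astype("float32")
faiss.normalize_L2(chunk_vectors)          # normalized inner product == cosine similarity
index.add(chunk_vectors)

query = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, k=10)    # top-10 chunk ids with similarity scores
print(ids[0], scores[0])
```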
---

## 🔧 Technology Stack

### **Core Technologies**

| Component | Technology | Version | Why This Choice |
|-----------|-----------|---------|-----------------|
| **Backend** | FastAPI | 0.104+ | Async support, auto-docs, production-grade |
| **LLM** | Ollama (Mistral-7B) | Latest | Zero API costs, on-premise, 20-30 tokens/sec |
| **Embeddings** | BGE-small-en-v1.5 | 384-dim | SOTA quality, 10x faster than alternatives |
| **Vector DB** | FAISS | Latest | Battle-tested, 10x faster than ChromaDB |
| **Keyword Search** | BM25 (rank_bm25) | Latest | Fast probabilistic ranking |
| **Document Parsing** | PyPDF2, python-docx | Latest | Industry standard, reliable |
| **Chunking** | LlamaIndex | 0.9+ | Advanced semantic splitting |
| **Reranking** | Cross-Encoder | Latest | +15% accuracy, minimal latency |
| **Evaluation** | RAGAS | 0.1.9 | Automated RAG quality metrics |
| **Frontend** | Alpine.js | 3.x | Lightweight reactivity, no build step |
| **Database** | SQLite | 3.x | Zero-config, sufficient for metadata |
| **Caching** | In-Memory LRU | Python functools | Fast, no external dependencies |

### **Python Dependencies**

```
fastapi>=0.104.0
uvicorn>=0.24.0
ollama>=0.1.0
sentence-transformers>=2.2.2
faiss-cpu>=1.7.4
llama-index>=0.9.0
rank-bm25>=0.2.2
PyPDF2>=3.0.0
python-docx>=0.8.11
pydantic>=2.0.0
aiohttp>=3.9.0
tiktoken>=0.5.0
ragas==0.1.9
datasets==2.14.6
```

---

## 📦 Installation

### **Prerequisites**

- Python 3.10 or higher
- 8GB RAM minimum (16GB recommended)
- 10GB disk space for models and indexes
- Ollama installed ([https://ollama.ai](https://ollama.ai))

### **Step 1: Clone Repository**

```bash
git clone https://github.com/satyaki-mitra/docu-vault-ai.git
cd docu-vault-ai
```

### **Step 2: Create Virtual Environment**

```bash
# Using conda (recommended)
conda create -n rag_env python=3.10
conda activate rag_env

# Or using venv
python -m venv rag_env
source rag_env/bin/activate  # On Windows: rag_env\Scripts\activate
```

### **Step 3: Install Dependencies**

```bash
pip install -r requirements.txt
```

### **Step 4: Install Ollama and Model**

```bash
# Install Ollama (macOS)
brew install ollama

# Install Ollama (Linux)
curl https://ollama.ai/install.sh | sh

# Pull Mistral model
ollama pull mistral:7b

# Verify installation
ollama list
```

### **Step 5: Configure Environment**

```bash
# Copy example config
cp .env.example .env

# Edit configuration (optional)
nano .env
```

**Key Configuration Options:**

```bash
# LLM Settings
OLLAMA_MODEL=mistral:7b
DEFAULT_TEMPERATURE=0.1
CONTEXT_WINDOW=8192

# Retrieval Settings
VECTOR_WEIGHT=0.6
BM25_WEIGHT=0.4
ENABLE_RERANKING=True
TOP_K_RETRIEVE=10

# RAGAS Evaluation
ENABLE_RAGAS=True
RAGAS_ENABLE_GROUND_TRUTH=False
OPENAI_API_KEY=your_openai_api_key_here  # Required for RAGAS

# Performance
EMBEDDING_BATCH_SIZE=32
MAX_WORKERS=4
```

---

## 🚀 Quick Start

### **1. Start Ollama Server**

```bash
# Terminal 1: Start Ollama
ollama serve
```

### **2. Launch Application**

```bash
# Terminal 2: Start RAG system
python app.py
```

Output:
```
INFO:     Started server process [12345]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000
```

### **3. Access Web Interface**

Open browser to: **http://localhost:8000**

### **4. Upload Documents**

1. Click **"Upload Documents"**
2. Select PDF/DOCX/TXT files (or ZIP archives)
3. Click **"Start Building"**
4. Wait for indexing to complete (progress bar shows status)

### **5. Query Your Documents**

```
Query: "What are the key findings in the Q3 report?"

Response: The Q3 report highlights three key findings:
[1] Revenue increased 23% year-over-year to $45.2M,
[2] Customer acquisition costs decreased 15%, and
[3] Net retention rate reached 118% [1].

Sources:
[1] Q3_Financial_Report.pdf (Page 3, Executive Summary)

RAGAS Metrics:
- Answer Relevancy: 0.89
- Faithfulness: 0.94
- Context Utilization: 0.87
- Overall Score: 0.90
```

---

## 🧩 Core Components

### **1. Document Ingestion Pipeline**

```python
# High-level flow
Document Upload → Parse → Clean → Chunk → Embed → Index
```

**Adaptive Chunking Logic:**

```mermaid
graph TD
    A[Calculate Token Count] --> B{Tokens < 50K?}
    B -->|Yes| C[Fixed Chunking<br/>512 tokens, 50 overlap]
    B -->|No| D{Tokens < 500K?}
    D -->|Yes| E[Semantic Chunking<br/>Section-aware]
    D -->|No| F[Hierarchical Chunking<br/>Parent 2048, Child 512]
```

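A minimal sketch of the selection rule in the diagram above; the thresholds and parameters are taken from the diagram, and the actual `chunking/adaptive_selector.py` implementation may differ.

```python
# Sketch of the adaptive strategy selection shown above (illustrative only).
def select_chunking_strategy(token_count: int) -> dict:
    if token_count < 50_000:
        return {"strategy": "fixed", "chunk_size": 512, "overlap": 50}
    if token_count < 500_000:
        return {"strategy": "semantic", "section_aware": True}
    return {"strategy": "hierarchical", "parent_size": 2048, "child_size": 512}

print(select_chunking_strategy(42_000))   # -> fixed chunking parameters
```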
### **2. Hybrid Retrieval Engine**

**Retrieval Flow:**

```python
# Pseudocode
def hybrid_retrieve(query: str, top_k: int = 10):
    # Dual retrieval
    query_embedding = embedder.embed(query)
    vector_results = faiss_index.search(query_embedding, top_k * 2)
    bm25_results = bm25_index.search(query, top_k * 2)

    # Fusion (RRF)
    fused_results = reciprocal_rank_fusion(vector_results,
                                           bm25_results,
                                           weights = (0.6, 0.4))

    # Reranking
    reranked = cross_encoder.rerank(query, fused_results, top_k)

    return reranked
```

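The `reciprocal_rank_fusion` step referenced above can be sketched as a generic weighted RRF over ranked id lists. This is an illustration only (including the common `k=60` constant); the project's own fusion code may differ.

```python
# Generic weighted Reciprocal Rank Fusion sketch (illustrative).
def reciprocal_rank_fusion(vector_results, bm25_results, weights=(0.6, 0.4), k=60):
    scores = {}
    for weight, results in zip(weights, (vector_results, bm25_results)):
        for rank, doc_id in enumerate(results, start=1):
            # Each list contributes weight / (k + rank) for every document it ranks
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

print(reciprocal_rank_fusion(["c3", "c1", "c7"], ["c1", "c9", "c3"]))
```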
+
### **3. Response Generation**
|
| 436 |
+
|
| 437 |
+
**Temperature Control:**
|
| 438 |
+
|
| 439 |
+
```mermaid
|
| 440 |
+
graph LR
|
| 441 |
+
A[Query Type] --> B{Factual?}
|
| 442 |
+
B -->|Yes| C[Low Temp<br/>0.1-0.2]
|
| 443 |
+
B -->|No| D[Context Quality]
|
| 444 |
+
D -->|High| E[Medium Temp<br/>0.3-0.5]
|
| 445 |
+
D -->|Low| F[High Temp<br/>0.6-0.8]
|
| 446 |
+
```
|
| 447 |
+
|
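A minimal sketch of that decision logic, using the temperature bands from the diagram; the threshold on context quality is an assumption, and `generation/temperature_controller.py` may implement this differently.

```python
# Sketch of context-aware temperature selection from the diagram above (illustrative).
def select_temperature(is_factual: bool, context_score: float) -> float:
    if is_factual:
        return 0.1                                   # factual queries stay close to the context
    return 0.4 if context_score >= 0.7 else 0.7     # weaker context -> allow more latitude

print(select_temperature(is_factual=False, context_score=0.55))  # -> 0.7
```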
### **4. RAGAS Evaluation Module**

**Automatic Quality Assessment:**

```python
# After each query-response
ragas_result = ragas_evaluator.evaluate_single(query = user_query,
                                               answer = generated_answer,
                                               contexts = retrieved_chunks,
                                               retrieval_time_ms = retrieval_time,
                                               generation_time_ms = generation_time,
                                               )

# Metrics computed:
# - Answer Relevancy (0-1)
# - Faithfulness (0-1)
# - Context Utilization (0-1)
# - Context Relevancy (0-1)
# - Overall Score (weighted average)
```

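The overall score is described above as a weighted average of the individual metrics. A minimal sketch follows; the weights shown are assumptions for illustration, not the repository's actual values.

```python
# Illustrative weighted-average overall score; the weights are assumptions.
def overall_score(metrics: dict, weights: dict | None = None) -> float:
    weights = weights or {"answer_relevancy": 0.3, "faithfulness": 0.3,
                          "context_utilization": 0.2, "context_relevancy": 0.2}
    return sum(metrics[name] * w for name, w in weights.items())

print(round(overall_score({"answer_relevancy": 0.89, "faithfulness": 0.94,
                           "context_utilization": 0.87, "context_relevancy": 0.91}), 2))
```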
---

## 📚 API Documentation

### **Core Endpoints**

#### **1. Health Check**

```bash
GET /api/health
```

**Response:**
```json
{
  "status": "healthy",
  "timestamp": "2024-11-27T03:00:00",
  "components": {
    "vector_store": true,
    "llm": true,
    "embeddings": true,
    "retrieval": true
  }
}
```

#### **2. Upload Documents**

```bash
POST /api/upload
Content-Type: multipart/form-data

files: [file1.pdf, file2.docx]
```

#### **3. Start Processing**

```bash
POST /api/start-processing
```

#### **4. Query (Chat)**

```bash
POST /api/chat
Content-Type: application/json

{
  "message": "What are the revenue figures?",
  "session_id": "session_123"
}
```

**Response includes RAGAS metrics:**
```json
{
  "session_id": "session_123",
  "response": "Revenue for Q3 was $45.2M [1]...",
  "sources": [...],
  "metrics": {
    "retrieval_time": 245,
    "generation_time": 3100,
    "total_time": 3350
  },
  "ragas_metrics": {
    "answer_relevancy": 0.89,
    "faithfulness": 0.94,
    "context_utilization": 0.87,
    "context_relevancy": 0.91,
    "overall_score": 0.90
  }
}
```

#### **5. RAGAS Endpoints**
|
| 544 |
+
|
| 545 |
+
```bash
|
| 546 |
+
# Get evaluation history
|
| 547 |
+
GET /api/ragas/history
|
| 548 |
+
|
| 549 |
+
# Get session statistics
|
| 550 |
+
GET /api/ragas/statistics
|
| 551 |
+
|
| 552 |
+
# Clear evaluation history
|
| 553 |
+
POST /api/ragas/clear
|
| 554 |
+
|
| 555 |
+
# Export evaluation data
|
| 556 |
+
GET /api/ragas/export
|
| 557 |
+
|
| 558 |
+
# Get RAGAS configuration
|
| 559 |
+
GET /api/ragas/config
|
| 560 |
+
```
|
| 561 |
+
|
| 562 |
+
---
|
| 563 |
+
|
| 564 |
+
## ⚙️ Configuration
|
| 565 |
+
|
| 566 |
+
### **config/settings.py**
|
| 567 |
+
|
| 568 |
+
**Key Configuration Sections:**
|
| 569 |
+
|
| 570 |
+
#### **LLM Settings**
|
| 571 |
+
```python
|
| 572 |
+
OLLAMA_MODEL = "mistral:7b"
|
| 573 |
+
DEFAULT_TEMPERATURE = 0.1
|
| 574 |
+
MAX_TOKENS = 1000
|
| 575 |
+
CONTEXT_WINDOW = 8192
|
| 576 |
+
```
|
| 577 |
+
|
| 578 |
+
#### **RAGAS Settings**
|
| 579 |
+
```python
|
| 580 |
+
ENABLE_RAGAS = True
|
| 581 |
+
RAGAS_ENABLE_GROUND_TRUTH = False
|
| 582 |
+
RAGAS_METRICS = ["answer_relevancy",
|
| 583 |
+
"faithfulness",
|
| 584 |
+
"context_utilization",
|
| 585 |
+
"context_relevancy"
|
| 586 |
+
]
|
| 587 |
+
RAGAS_EVALUATION_TIMEOUT = 60
|
| 588 |
+
RAGAS_BATCH_SIZE = 10
|
| 589 |
+
```
|
| 590 |
+
|
| 591 |
+
#### **Caching Settings**
|
| 592 |
+
```python
|
| 593 |
+
ENABLE_EMBEDDING_CACHE = True
|
| 594 |
+
CACHE_MAX_SIZE = 1000 # LRU cache size
|
| 595 |
+
CACHE_TTL = 3600 # Time to live in seconds
|
| 596 |
+
```
|
| 597 |
+
|
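As the Technology Stack table notes, caching relies on a plain in-memory LRU from `functools`. A minimal sketch of caching query embeddings is below; the embedding call is a stand-in, and `embeddings/embedding_cache.py` may handle TTL and eviction differently.

```python
# Illustrative LRU cache for query embeddings using functools.
from functools import lru_cache

def embed(query: str) -> list[float]:
    # Stand-in for the real BGE embedder call
    return [float(len(query))] * 384

@lru_cache(maxsize=1000)                     # mirrors CACHE_MAX_SIZE
def cached_query_embedding(query: str) -> tuple[float, ...]:
    return tuple(embed(query))               # tuples are hashable, so results cache cleanly

cached_query_embedding("revenue figures")    # computed once
cached_query_embedding("revenue figures")    # served from the LRU cache
```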
---

## 📊 RAGAS Evaluation

### **What is RAGAS?**

RAGAS (Retrieval-Augmented Generation Assessment) is a framework for evaluating RAG systems using automated metrics. Our implementation provides real-time quality assessment for every query-response pair.

### **Metrics Explained**

| Metric | Definition | Target | Interpretation |
|--------|-----------|--------|----------------|
| **Answer Relevancy** | How well the answer addresses the question | > 0.85 | Measures usefulness to user |
| **Faithfulness** | Is the answer grounded in retrieved context? | > 0.90 | Prevents hallucinations |
| **Context Utilization** | How well the context is used in the answer | > 0.80 | Retrieval effectiveness |
| **Context Relevancy** | Are retrieved chunks relevant to the query? | > 0.85 | Search quality |
| **Overall Score** | Weighted average of all metrics | > 0.85 | System performance |

### **Using the Analytics Dashboard**

1. Navigate to the **Analytics & Quality** section
2. View the real-time RAGAS metrics table
3. Monitor session statistics (averages, trends)
4. Export evaluation data for offline analysis

### **Example Evaluation Output**

```
Query: "What were the Q3 revenue trends?"
Answer: "Q3 revenue increased 23% YoY to $45.2M..."

RAGAS Evaluation:
├─ Answer Relevancy: 0.89 ✓ (Good)
├─ Faithfulness: 0.94 ✓ (Excellent)
├─ Context Utilization: 0.87 ✓ (Good)
├─ Context Relevancy: 0.91 ✓ (Excellent)
└─ Overall Score: 0.90 ✓ (Excellent)

Performance:
├─ Retrieval Time: 245ms
├─ Generation Time: 3100ms
└─ Total Time: 3345ms
```

---

## 🔧 Troubleshooting

### **Common Issues**

#### **1. "RAGAS evaluation failed"**

**Cause:** OpenAI API key not configured

**Solution:**
```bash
# Add to .env file
OPENAI_API_KEY=your_openai_api_key_here

# Or disable RAGAS if not needed
ENABLE_RAGAS=False
```

#### **2. "Context assembly returning 0 chunks"**

**Cause:** Missing token counts in chunks

**Solution:** Already fixed in `context_assembler.py`; token counts are calculated on the fly if missing.

#### **3. "Slow query responses"**

**Solutions:**
- Enable embedding cache: `ENABLE_EMBEDDING_CACHE=True`
- Reduce retrieval count: `TOP_K_RETRIEVE=5`
- Disable reranking: `ENABLE_RERANKING=False`
- Use a quantized model for faster inference

#### **4. "RAGAS metrics not appearing"**

**Symptoms:** Chat responses lack quality metrics

**Solution:**
```python
# Verify RAGAS is enabled in settings
ENABLE_RAGAS = True

# Check OpenAI API key is valid
# View logs for RAGAS evaluation errors
tail -f logs/app.log | grep "RAGAS"
```

---

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

---

## 🙏 Acknowledgments

**Open Source Technologies:**
- [FastAPI](https://fastapi.tiangolo.com/) - Modern web framework
- [Ollama](https://ollama.ai/) - Local LLM inference
- [FAISS](https://github.com/facebookresearch/faiss) - Vector similarity search
- [LlamaIndex](https://www.llamaindex.ai/) - Document chunking
- [Sentence Transformers](https://www.sbert.net/) - Embedding models
- [RAGAS](https://github.com/explodinggradients/ragas) - RAG evaluation

**Research Papers:**
- Karpukhin et al. (2020) - Dense Passage Retrieval
- Robertson & Zaragoza (2009) - The Probabilistic Relevance Framework: BM25
- Lewis et al. (2020) - Retrieval-Augmented Generation
- Es et al. (2023) - RAGAS: Automated Evaluation of RAG

---

## 👤 Author

Satyaki Mitra | Data Scientist | Generative-AI Enthusiast

---

<div align="center">

**Built with ❤️ for the open-source community**

</div>
app.py
ADDED
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import os
|
| 3 |
+
import gc
|
| 4 |
+
import io
|
| 5 |
+
import csv
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
import signal
|
| 9 |
+
import atexit
|
| 10 |
+
import shutil
|
| 11 |
+
import asyncio
|
| 12 |
+
import logging
|
| 13 |
+
import uvicorn
|
| 14 |
+
import tempfile
|
| 15 |
+
import traceback
|
| 16 |
+
import threading
|
| 17 |
+
from typing import Set
|
| 18 |
+
from typing import Any
|
| 19 |
+
from typing import List
|
| 20 |
+
from typing import Dict
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Tuple
|
| 23 |
+
from fastapi import File
|
| 24 |
+
from fastapi import Form
|
| 25 |
+
from signal import SIGINT
|
| 26 |
+
from signal import SIGTERM
|
| 27 |
+
from pydantic import Field
|
| 28 |
+
from fastapi import FastAPI
|
| 29 |
+
from typing import Optional
|
| 30 |
+
from datetime import datetime
|
| 31 |
+
from datetime import timedelta
|
| 32 |
+
from fastapi import UploadFile
|
| 33 |
+
from pydantic import BaseModel
|
| 34 |
+
from fastapi import HTTPException
|
| 35 |
+
from config.models import PromptType
|
| 36 |
+
from config.models import ChatRequest
|
| 37 |
+
from config.models import LLMProvider
|
| 38 |
+
from utils.helpers import IDGenerator
|
| 39 |
+
from config.models import QueryRequest
|
| 40 |
+
from config.settings import get_settings
|
| 41 |
+
from config.models import RAGASStatistics
|
| 42 |
+
from config.models import RAGASExportData
|
| 43 |
+
from config.models import DocumentMetadata
|
| 44 |
+
from fastapi.responses import HTMLResponse
|
| 45 |
+
from fastapi.responses import FileResponse
|
| 46 |
+
from fastapi.responses import JSONResponse
|
| 47 |
+
from contextlib import asynccontextmanager
|
| 48 |
+
from utils.file_handler import FileHandler
|
| 49 |
+
from utils.validators import FileValidator
|
| 50 |
+
from fastapi.staticfiles import StaticFiles
|
| 51 |
+
from utils.error_handler import RAGException
|
| 52 |
+
from utils.error_handler import FileException
|
| 53 |
+
from config.models import RAGASEvaluationResult
|
| 54 |
+
from config.logging_config import setup_logging
|
| 55 |
+
from generation.llm_client import get_llm_client
|
| 56 |
+
from embeddings.bge_embedder import get_embedder
|
| 57 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 58 |
+
from ingestion.router import get_ingestion_router
|
| 59 |
+
from utils.validators import validate_upload_file
|
| 60 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 61 |
+
from vector_store.index_builder import get_index_builder
|
| 62 |
+
from document_parser.parser_factory import ParserFactory
|
| 63 |
+
from evaluation.ragas_evaluator import get_ragas_evaluator
|
| 64 |
+
from vector_store.metadata_store import get_metadata_store
|
| 65 |
+
from embeddings.embedding_cache import get_embedding_cache
|
| 66 |
+
from ingestion.progress_tracker import get_progress_tracker
|
| 67 |
+
from retrieval.hybrid_retriever import get_hybrid_retriever
|
| 68 |
+
from chunking.adaptive_selector import get_adaptive_selector
|
| 69 |
+
from retrieval.context_assembler import get_context_assembler
|
| 70 |
+
from document_parser.parser_factory import get_parser_factory
|
| 71 |
+
from chunking.adaptive_selector import AdaptiveChunkingSelector
|
| 72 |
+
from generation.response_generator import get_response_generator
|
| 73 |
+
from config.models import ProcessingStatus as ProcessingStatusEnum
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Setup logging and settings
|
| 77 |
+
settings = get_settings()
|
| 78 |
+
logger = setup_logging(log_level = settings.LOG_LEVEL,
|
| 79 |
+
log_dir = settings.LOG_DIR,
|
| 80 |
+
enable_console = True,
|
| 81 |
+
enable_file = True,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Global Cleanup Variables
|
| 86 |
+
_cleanup_registry : Set[str] = set()
|
| 87 |
+
_cleanup_lock = threading.RLock()
|
| 88 |
+
_is_cleaning = False
|
| 89 |
+
_cleanup_executor = ThreadPoolExecutor(max_workers = 2,
|
| 90 |
+
thread_name_prefix = "cleanup_",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Analytics Cache Structure
|
| 95 |
+
class AnalyticsCache:
|
| 96 |
+
"""
|
| 97 |
+
Cache for analytics data to avoid recalculating on every request
|
| 98 |
+
"""
|
| 99 |
+
def __init__(self, ttl_seconds: int = 30):
|
| 100 |
+
self.data = None
|
| 101 |
+
self.last_calculated = None
|
| 102 |
+
self.ttl_seconds = ttl_seconds
|
| 103 |
+
self.is_calculating = False
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def is_valid(self) -> bool:
|
| 107 |
+
"""
|
| 108 |
+
Check if cache is still valid
|
| 109 |
+
"""
|
| 110 |
+
if self.data is None or self.last_calculated is None:
|
| 111 |
+
return False
|
| 112 |
+
|
| 113 |
+
elapsed = (datetime.now() - self.last_calculated).total_seconds()
|
| 114 |
+
|
| 115 |
+
return (elapsed < self.ttl_seconds)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def update(self, data: Dict):
|
| 119 |
+
"""
|
| 120 |
+
Update cache with new data
|
| 121 |
+
"""
|
| 122 |
+
self.data = data
|
| 123 |
+
self.last_calculated = datetime.now()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get(self) -> Optional[Dict]:
|
| 127 |
+
"""
|
| 128 |
+
Get cached data if valid
|
| 129 |
+
"""
|
| 130 |
+
return self.data if self.is_valid() else None
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class CleanupManager:
    """
    Centralized cleanup manager with multiple redundancy layers
    """
    @staticmethod
    def register_resource(resource_id: str, cleanup_func, *args, **kwargs):
        """
        Register a resource for cleanup
        """
        with _cleanup_lock:
            _cleanup_registry.add(resource_id)

        # Register with atexit for process termination
        atexit.register(cleanup_func, *args, **kwargs)

        return resource_id

    @staticmethod
    def unregister_resource(resource_id: str):
        """
        Unregister a resource (if already cleaned up elsewhere)
        """
        with _cleanup_lock:
            if resource_id in _cleanup_registry:
                _cleanup_registry.remove(resource_id)

    @staticmethod
    async def full_cleanup(state: Optional['AppState'] = None) -> bool:
        """
        Perform full system cleanup with redundancy
        """
        global _is_cleaning

        with _cleanup_lock:
            if _is_cleaning:
                logger.warning("Cleanup already in progress")
                return False

            _is_cleaning = True

        try:
            logger.info("Starting comprehensive system cleanup...")

            # Layer 1: Memory cleanup
            success1 = await CleanupManager._cleanup_memory(state)

            # Layer 2: Disk cleanup (async to not block)
            success2 = await CleanupManager._cleanup_disk_async()

            # Layer 3: Component cleanup
            success3 = await CleanupManager._cleanup_components(state)

            # Layer 4: External resources
            success4 = CleanupManager._cleanup_external_resources()

            # Clear registry
            with _cleanup_lock:
                _cleanup_registry.clear()

            overall_success = all([success1, success2, success3, success4])

            if overall_success:
                logger.info("Comprehensive cleanup completed successfully")
            else:
                logger.warning("Cleanup completed with some failures")

            return overall_success

        except Exception as e:
            logger.error(f"Cleanup failed catastrophically: {e}", exc_info=True)
            return False

        finally:
            with _cleanup_lock:
                _is_cleaning = False

    @staticmethod
    async def _cleanup_memory(state: Optional['AppState']) -> bool:
        """
        Memory cleanup
        """
        try:
            if not state:
                logger.warning("No AppState provided for memory cleanup")
                return True

            # Session cleanup
            session_count = len(state.active_sessions)
            state.active_sessions.clear()
            state.config_overrides.clear()
            logger.info(f"Cleared {session_count} sessions from memory")

            # Document data cleanup
            doc_count = len(state.processed_documents)
            chunk_count = sum(len(chunks) for chunks in state.document_chunks.values())

            state.processed_documents.clear()
            state.document_chunks.clear()
            state.uploaded_files.clear()
            logger.info(f"Cleared {doc_count} documents ({chunk_count} chunks) from memory")

            # Performance data cleanup
            state.query_timings.clear()
            state.chunking_statistics.clear()

            # State reset
            state.is_ready = False
            state.processing_status = "idle"

            # Cache cleanup
            if hasattr(state, 'analytics_cache'):
                state.analytics_cache.data = None

            # Force garbage collection
            collected = gc.collect()
            logger.debug(f"🧹 Garbage collection freed {collected} objects")

            return True

        except Exception as e:
            logger.error(f"Memory cleanup failed: {e}")
            return False

    @staticmethod
    async def _cleanup_disk_async() -> bool:
        """
        Asynchronous disk cleanup
        """
        try:
            # Run in thread pool to avoid blocking
            loop = asyncio.get_event_loop()
            success = await loop.run_in_executor(_cleanup_executor, CleanupManager._cleanup_disk_sync)

            return success

        except Exception as e:
            logger.error(f"Async disk cleanup failed: {e}")
            return False

    @staticmethod
    def _cleanup_disk_sync() -> bool:
        """
        Synchronous disk cleanup
        """
        try:
            logger.info("Starting disk cleanup...")

            # Track what we clean
            cleaned_paths = list()

            # Vector store directory
            if settings.VECTOR_STORE_DIR.exists():
                vector_files = list(settings.VECTOR_STORE_DIR.glob("*"))
                for file in vector_files:
                    try:
                        if file.is_file():
                            file.unlink()
                            cleaned_paths.append(str(file))

                        elif file.is_dir():
                            shutil.rmtree(file)
                            cleaned_paths.append(str(file))

                    except Exception as e:
                        logger.warning(f"Failed to delete {file}: {e}")

                logger.info(f"Cleaned {len(cleaned_paths)} vector store files")

            # Upload directory (preserve directory structure)
            if settings.UPLOAD_DIR.exists():
                upload_files = list(settings.UPLOAD_DIR.glob("*"))
                for file in upload_files:
                    try:
                        if file.is_file():
                            file.unlink()
                            cleaned_paths.append(str(file))

                        elif file.is_dir():
                            shutil.rmtree(file)
                            cleaned_paths.append(str(file))

                    except Exception as e:
                        logger.warning(f"Failed to delete {file}: {e}")

                # Recreate empty directory
                settings.UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
                logger.info(f"Cleaned {len(upload_files)} uploaded files")

            # Metadata database
            metadata_path = Path(settings.METADATA_DB_PATH)
            if metadata_path.exists():
                try:
                    metadata_path.unlink(missing_ok=True)
                    cleaned_paths.append(str(metadata_path))
                    logger.info("Cleaned metadata database")

                except Exception as e:
                    logger.warning(f"Failed to delete metadata DB: {e}")

            # Backup directory
            if settings.BACKUP_DIR.exists():
                backup_files = list(settings.BACKUP_DIR.glob("*"))
                for file in backup_files:
                    try:
                        if file.is_file():
                            file.unlink()

                        elif file.is_dir():
                            shutil.rmtree(file)

                    except Exception:
                        pass  # Silently fail for backups

                logger.info(f"Cleaned {len(backup_files)} backup files")

            # Temp files cleanup
            CleanupManager._cleanup_temp_files()

            logger.info(f"Disk cleanup completed: {len(cleaned_paths)} items cleaned")

            return True

        except Exception as e:
            logger.error(f"Disk cleanup failed: {e}")
            return False

    @staticmethod
    def _cleanup_temp_files():
        """
        Clean up temporary files
        """
        temp_dir = tempfile.gettempdir()

        # Clean our specific temp files (if any)
        for pattern in ["rag_*", "faiss_*", "embedding_*"]:
            for file in Path(temp_dir).glob(pattern):
                try:
                    file.unlink(missing_ok=True)
                except Exception:
                    pass

    @staticmethod
    async def _cleanup_components(state: Optional['AppState']) -> bool:
        """
        Component-specific cleanup
        """
        try:
            if not state:
                return True

            components_cleaned = 0

            # Vector store components
            if state.index_builder:
                try:
                    state.index_builder.clear_indexes()
                    components_cleaned += 1

                except Exception as e:
                    logger.warning(f"Index builder cleanup failed: {e}")

            if state.metadata_store and hasattr(state.metadata_store, 'clear_all'):
                try:
                    state.metadata_store.clear_all()
                    components_cleaned += 1

                except Exception as e:
                    logger.warning(f"Metadata store cleanup failed: {e}")

            # RAGAS evaluator
            if state.ragas_evaluator and hasattr(state.ragas_evaluator, 'clear_history'):
                try:
                    state.ragas_evaluator.clear_history()
                    components_cleaned += 1

                except Exception as e:
                    logger.warning(f"RAGAS evaluator cleanup failed: {e}")

            logger.info(f"Cleaned {components_cleaned} components")
            return True

        except Exception as e:
            logger.error(f"Component cleanup failed: {e}")
            return False

    @staticmethod
    def _cleanup_external_resources() -> bool:
        """
        External resource cleanup
        """
        try:
            # Close database connections
            CleanupManager._close_db_connections()

            # Clean up thread pool
            _cleanup_executor.shutdown(wait = False)

            logger.info("External resources cleaned")
            return True

        except Exception as e:
            logger.error(f"External resource cleanup failed: {e}")
            return False

    @staticmethod
    def _close_db_connections():
        """
        Close any open database connections
        """
        try:
            # SQLite handles this automatically in most cases
            pass
        except Exception:
            pass

    @staticmethod
    def handle_signal(signum, frame):
        """
        Signal handler for graceful shutdown
        """
        global _is_cleaning

        # If already cleaning up, don't raise KeyboardInterrupt
        with _cleanup_lock:
            if _is_cleaning:
                logger.info(f"Signal {signum} received during cleanup - ignoring")
                return

        if (signum == signal.SIGINT):
            logger.info("Ctrl+C received - shutdown initiated")
            raise KeyboardInterrupt

        elif (signum == signal.SIGTERM):
            logger.info("SIGTERM received - shutdown initiated")
            # Just log; nothing else is scheduled here

        else:
            logger.info(f"Signal {signum} received")

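
# Illustrative sketch (not called anywhere in the app): registering an ad-hoc
# resource with CleanupManager so it is released at process exit. The scratch
# file path used here is hypothetical.
def _example_register_temp_resource() -> str:
    scratch = Path(tempfile.gettempdir()) / "rag_example_scratch.bin"
    resource_id = CleanupManager.register_resource(f"scratch_{scratch.name}",
                                                   lambda: scratch.unlink(missing_ok=True))
    return resource_id
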
# Global state manager
class AppState:
    """
    Manages application state and components
    """
    def __init__(self):
        self.is_ready = False
        self.processing_status = "idle"
        self.uploaded_files = list()
        self.active_sessions = dict()
        self.processed_documents = dict()
        self.document_chunks = dict()

        # Performance tracking
        self.query_timings = list()
        self.chunking_statistics = dict()

        # Core components
        self.file_handler = None
        self.parser_factory = None
        self.chunking_selector = None

        # Embeddings components
        self.embedder = None
        self.embedding_cache = None

        # Ingestion components
        self.ingestion_router = None
        self.progress_tracker = None

        # Vector store components
        self.index_builder = None
        self.metadata_store = None

        # Retrieval components
        self.hybrid_retriever = None
        self.context_assembler = None

        # Generation components
        self.response_generator = None
        self.llm_client = None

        # RAGAS component
        self.ragas_evaluator = None

        # Processing tracking
        self.current_processing = None
        self.processing_progress = {"status" : "idle",
                                    "current_step" : "Waiting",
                                    "progress" : 0,
                                    "processed" : 0,
                                    "total" : 0,
                                    "details" : {},
                                    }

        # Session-based configuration overrides
        self.config_overrides = dict()

        # Analytics cache
        self.analytics_cache = AnalyticsCache(ttl_seconds = 30)

        # System start time
        self.start_time = datetime.now()

        # Cleanup tracking
        self._cleanup_registered = False
        self._cleanup_resources = list()

        # Register with cleanup manager
        self._register_for_cleanup()

    def _register_for_cleanup(self):
        """
        Register this AppState instance for cleanup
        """
        if not self._cleanup_registered:
            resource_id = f"appstate_{id(self)}"

            CleanupManager.register_resource(resource_id, self._emergency_cleanup)
            self._cleanup_resources.append(resource_id)

            self._cleanup_registered = True

    def _emergency_cleanup(self):
        """
        Emergency cleanup if regular cleanup fails
        """
        try:
            logger.warning("Performing emergency cleanup...")

            # Brutal but effective memory clearing
            for attr in ['active_sessions', 'processed_documents', 'document_chunks', 'uploaded_files', 'query_timings', 'chunking_statistics']:
                if hasattr(self, attr):
                    getattr(self, attr).clear()

            # Nullify heavy objects
            self.index_builder = None
            self.metadata_store = None
            self.embedder = None

            logger.warning("Emergency cleanup completed")

        except Exception:
            # Last resort - don't crash during emergency cleanup
            pass

    async def graceful_shutdown(self):
        """
        Graceful shutdown procedure
        """
        logger.info("Starting graceful shutdown...")

        # Notify clients (if any WebSocket connections)
        await self._notify_clients()

        # Perform cleanup
        await CleanupManager.full_cleanup(self)

        # Unregister from cleanup manager
        for resource_id in self._cleanup_resources:
            CleanupManager.unregister_resource(resource_id)

        logger.info("Graceful shutdown completed")

    async def _notify_clients(self):
        """
        Notify connected clients of shutdown
        """
        # Placeholder for WebSocket notifications
        pass

    def add_query_timing(self, duration_ms: float):
        """
        Record query timing for analytics
        """
        self.query_timings.append((datetime.now(), duration_ms))
        # Keep only the last 1000 timings to prevent memory issues
        if (len(self.query_timings) > 1000):
            self.query_timings = self.query_timings[-1000:]

    def get_performance_metrics(self) -> Dict:
        """
        Calculate performance metrics from recorded timings
        """
        if not self.query_timings:
            return {"avg_response_time" : 0,
                    "min_response_time" : 0,
                    "max_response_time" : 0,
                    "total_queries" : 0,
                    "queries_last_hour" : 0,
                    }

        # Get recent timings (last hour)
        one_hour_ago = datetime.now() - timedelta(hours = 1)
        recent_timings = [t for t, _ in self.query_timings if (t > one_hour_ago)]

        # Calculate statistics
        durations = [duration for _, duration in self.query_timings]

        return {"avg_response_time" : int(sum(durations) / len(durations)),
                "min_response_time" : int(min(durations)) if durations else 0,
                "max_response_time" : int(max(durations)) if durations else 0,
                "total_queries" : len(self.query_timings),
                "queries_last_hour" : len(recent_timings),
                "p95_response_time" : int(sorted(durations)[int(len(durations) * 0.95)]) if (len(durations) > 10) else 0,
                }

    def get_chunking_statistics(self) -> Dict:
        """
        Get statistics about chunking strategies used
        """
        if not self.chunking_statistics:
            return {"primary_strategy" : "adaptive",
                    "total_chunks" : 0,
                    "avg_chunk_size" : 0,
                    "strategies_used" : {},
                    }

        total_chunks = sum(self.chunking_statistics.values())
        strategies = {k: v for k, v in self.chunking_statistics.items() if (v > 0)}

        return {"primary_strategy" : max(strategies.items(), key=lambda x: x[1])[0] if strategies else "adaptive",
                "total_chunks" : total_chunks,
                "strategies_used" : strategies,
                }

    def get_system_health(self) -> Dict:
        """
        Get comprehensive system health status
        """
        llm_healthy = self.llm_client.check_health() if self.llm_client else False
        vector_store_ready = self.is_ready

        # Check embedding model
        embedding_ready = self.embedder is not None

        # Check retrieval components
        retrieval_ready = (self.hybrid_retriever is not None and self.context_assembler is not None)

        # Determine overall status
        if all([llm_healthy, vector_store_ready, embedding_ready, retrieval_ready]):
            overall_status = "healthy"

        elif vector_store_ready and embedding_ready and retrieval_ready:
            # LLM issues but RAG works
            overall_status = "degraded"

        else:
            overall_status = "unhealthy"

        return {"overall" : overall_status,
                "llm" : llm_healthy,
                "vector_store" : vector_store_ready,
                "embeddings" : embedding_ready,
                "retrieval" : retrieval_ready,
                "generation" : self.response_generator is not None,
                }

    def get_system_information(self) -> Dict:
        """
        Get current system information
        """
        # Get chunking strategy
        chunking_strategy = "adaptive"

        if self.chunking_selector:
            try:
                # Try to get strategy from selector
                if (hasattr(self.chunking_selector, 'get_current_strategy')):
                    chunking_strategy = self.chunking_selector.get_current_strategy()

                elif (hasattr(self.chunking_selector, 'prefer_llamaindex')):
                    chunking_strategy = "llama_index" if self.chunking_selector.prefer_llamaindex else "adaptive"

            except Exception:
                pass

        # Get vector store status
        vector_store_status = "Not Ready"

        if self.is_ready:
            try:
                index_stats = self.index_builder.get_index_stats() if self.index_builder else {}
                total_chunks = index_stats.get('total_chunks_indexed', 0)

                if (total_chunks > 0):
                    vector_store_status = f"Ready ({total_chunks} chunks)"

                else:
                    vector_store_status = "Empty"

            except Exception:
                vector_store_status = "Ready"

        # Get model info
        current_model = settings.OLLAMA_MODEL
        embedding_model = settings.EMBEDDING_MODEL

        # Uptime
        uptime_seconds = (datetime.now() - self.start_time).total_seconds()

        return {"vector_store_status" : vector_store_status,
                "current_model" : current_model,
                "embedding_model" : embedding_model,
                "chunking_strategy" : chunking_strategy,
                "system_uptime_seconds" : int(uptime_seconds),
                "last_updated" : datetime.now().isoformat(),
                }

    def calculate_quality_metrics(self) -> Dict:
        """
        Calculate quality metrics for the system
        """
        total_queries = 0
        total_sources = 0
        source_counts = list()

        # Analyze session data
        for session_id, messages in self.active_sessions.items():
            total_queries += len(messages)

            for msg in messages:
                sources = len(msg.get('sources', []))
                total_sources += sources

                source_counts.append(sources)

        # Calculate averages
        avg_sources_per_query = total_sources / total_queries if total_queries > 0 else 0

        # Calculate metrics based on heuristics
        # These are simplified - for production, use RAGAS or a similar framework

        if (total_queries == 0):
            return {"answer_relevancy" : 0.0,
                    "faithfulness" : 0.0,
                    "context_precision" : 0.0,
                    "context_recall" : None,
                    "overall_score" : 0.0,
                    "confidence" : "low",
                    "metrics_available" : False
                    }

        # Heuristic calculations
        answer_relevancy = min(0.9, 0.7 + (avg_sources_per_query * 0.1))
        faithfulness = min(0.95, 0.8 + (avg_sources_per_query * 0.05))
        context_precision = min(0.85, 0.6 + (avg_sources_per_query * 0.1))

        # Overall score is a weighted average
        overall_score = (answer_relevancy * 0.4 + faithfulness * 0.3 + context_precision * 0.3)

        confidence = "high" if total_queries > 10 else ("medium" if (total_queries > 3) else "low")

        return {"answer_relevancy" : round(answer_relevancy, 3),
                "faithfulness" : round(faithfulness, 3),
                "context_precision" : round(context_precision, 3),
                "context_recall" : None,  # Requires ground truth
                "overall_score" : round(overall_score, 3),
                "avg_sources_per_query" : round(avg_sources_per_query, 2),
                "queries_with_sources" : sum(1 for count in source_counts if count > 0),
                "confidence" : confidence,
                "metrics_available" : True,
                "evaluation_note" : "Metrics are heuristic estimates. For accurate evaluation, use the RAGAS framework.",
                }

    def calculate_comprehensive_analytics(self) -> Dict:
        """
        Calculate comprehensive analytics data
        """
        # Performance metrics
        performance = self.get_performance_metrics()

        # System information
        system_info = self.get_system_information()

        # Quality metrics
        quality_metrics = self.calculate_quality_metrics()

        # Health status
        health_status = self.get_system_health()

        # Chunking statistics
        chunking_stats = self.get_chunking_statistics()

        # Document statistics
        total_docs = len(self.processed_documents)
        total_chunks = sum(len(chunks) for chunks in self.document_chunks.values())

        # Session statistics
        total_sessions = len(self.active_sessions)
        total_messages = sum(len(msgs) for msgs in self.active_sessions.values())

        # File statistics
        uploaded_files = len(self.uploaded_files)
        total_file_size = sum(f.get('size', 0) for f in self.uploaded_files)

        # Index statistics
        index_stats = dict()

        if self.index_builder:
            try:
                index_stats = self.index_builder.get_index_stats()

            except Exception:
                index_stats = {"error": "Could not retrieve index stats"}

        return {"performance_metrics" : performance,
                "quality_metrics" : quality_metrics,
                "system_information" : system_info,
                "health_status" : health_status,
                "chunking_statistics" : chunking_stats,
                "document_statistics" : {"total_documents" : total_docs,
                                         "total_chunks" : total_chunks,
                                         "uploaded_files" : uploaded_files,
                                         "total_file_size_bytes" : total_file_size,
                                         "total_file_size_mb" : round(total_file_size / (1024 * 1024), 2) if (total_file_size > 0) else 0,
                                         "avg_chunks_per_document" : round(total_chunks / total_docs, 2) if (total_docs > 0) else 0,
                                         },
                "session_statistics" : {"total_sessions" : total_sessions,
                                        "total_messages" : total_messages,
                                        "avg_messages_per_session" : round(total_messages / total_sessions, 2) if (total_sessions > 0) else 0
                                        },
                "index_statistics" : index_stats,
                "calculated_at" : datetime.now().isoformat(),
                "cache_info" : {"from_cache" : False,
                                "next_refresh_in" : self.analytics_cache.ttl_seconds,
                                }
                }

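
# Illustrative sketch (not called anywhere in the app): how query timings feed
# the analytics endpoints. Durations are recorded in milliseconds.
def _example_record_and_read_metrics(state: AppState) -> Dict:
    state.add_query_timing(125.0)   # one observed query, 125 ms end to end
    state.add_query_timing(340.0)
    return state.get_performance_metrics()  # avg/min/max/p95 over recorded timings
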
def _setup_signal_handlers():
    """
    Set up signal handlers for graceful shutdown
    """
    try:
        signal.signal(signal.SIGINT, CleanupManager.handle_signal)
        signal.signal(signal.SIGTERM, CleanupManager.handle_signal)
        logger.debug("Signal handlers registered")

    except Exception as e:
        logger.warning(f"Failed to setup signal handlers: {e}")


def _atexit_cleanup():
    """
    Atexit handler as last resort
    """
    logger.info("Atexit cleanup triggered")

    # Skip if a cleanup is already in progress
    with _cleanup_lock:
        if _is_cleaning:
            logger.info("Cleanup already in progress, skipping atexit cleanup")
            return

    try:
        # Check if the app exists
        if (('app' in globals()) and (hasattr(app.state, 'app'))):
            # Run cleanup in a background thread
            cleanup_thread = threading.Thread(target = lambda: asyncio.run(CleanupManager.full_cleanup(app.state.app)),
                                              name = "atexit_cleanup",
                                              daemon = True,
                                              )
            cleanup_thread.start()
            cleanup_thread.join(timeout = 5.0)

    except Exception as e:
        logger.error(f"Atexit cleanup error: {e}")
        # Don't crash during atexit


async def _brute_force_cleanup_app_state(state: AppState):
    """
    Brute-force AppState cleanup
    """
    try:
        # Clear all collections
        for attr_name in dir(state):
            if not attr_name.startswith('_'):
                attr = getattr(state, attr_name)

                if isinstance(attr, (list, dict, set)):
                    attr.clear()

        # Nullify heavy components
        for attr_name in ['index_builder', 'metadata_store', 'embedder', 'llm_client', 'ragas_evaluator']:
            if hasattr(state, attr_name):
                setattr(state, attr_name, None)

    except Exception:
        pass

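
# Note on shutdown redundancy: under normal operation the lifespan `finally`
# block below runs first; SIGINT/SIGTERM are translated by
# CleanupManager.handle_signal into that same path; _atexit_cleanup and the
# per-AppState _emergency_cleanup hooks only matter if the event loop never
# gets a chance to run the lifespan cleanup.
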
# Application lifespan manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application startup and shutdown with multiple cleanup guarantees
    """
    global _is_cleaning

    # Setup signal handlers FIRST
    _setup_signal_handlers()

    # Register atexit cleanup
    atexit.register(_atexit_cleanup)

    logger.info("Starting AI Universal Knowledge Ingestion System...")

    try:
        # Initialize application state
        app.state.app = AppState()

        # Initialize core components
        await initialize_components(app.state.app)

        logger.info("Application startup complete. System ready.")

        # Yield control to FastAPI
        yield

    except Exception as e:
        logger.error(f"Application runtime error: {e}", exc_info = True)
        raise

    finally:
        # GUARANTEED cleanup (even on crash)
        logger.info("Beginning guaranteed cleanup sequence...")

        # Set the cleaning flag
        with _cleanup_lock:
            _is_cleaning = True

        try:
            # Simple cleanup
            if (hasattr(app.state, 'app')):
                # Just clear the state, don't run full cleanup again
                await _brute_force_cleanup_app_state(app.state.app)

            # Clean up disk resources
            await CleanupManager._cleanup_disk_async()

            # Shut down the executor
            _cleanup_executor.shutdown(wait = True)

        except Exception as e:
            logger.error(f"Cleanup error in lifespan finally: {e}")


# Create FastAPI application
app = FastAPI(title = "AI Universal Knowledge Ingestion System",
              description = "Enterprise RAG Platform with Multi-Source Ingestion",
              version = "1.0.0",
              lifespan = lifespan,
              )


# Add CORS middleware
app.add_middleware(CORSMiddleware,
                   allow_origins = ["*"],
                   allow_credentials = True,
                   allow_methods = ["*"],
                   allow_headers = ["*"],
                   )

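
# Illustrative sketch (not part of the app): how this module is typically served
# locally. The module path "app:app" and the port are assumptions based on the
# Space configuration, not taken from this file.
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
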
# ============================================================================
# INITIALIZATION FUNCTIONS
# ============================================================================
async def initialize_components(state: AppState):
    """
    Initialize all application components
    """
    try:
        logger.info("Initializing components...")

        # Create necessary directories
        create_directories()

        # Initialize utilities
        state.file_handler = FileHandler()
        logger.info("FileHandler initialized")

        # Initialize document parsing
        state.parser_factory = get_parser_factory()
        logger.info(f"ParserFactory initialized with support for: {', '.join(state.parser_factory.get_supported_extensions())}")

        # Initialize chunking
        state.chunking_selector = get_adaptive_selector()
        logger.info("AdaptiveChunkingSelector initialized")

        # Initialize embeddings
        state.embedder = get_embedder()
        state.embedding_cache = get_embedding_cache()
        logger.info(f"Embedder initialized: {state.embedder.get_model_info()}")

        # Initialize ingestion
        state.ingestion_router = get_ingestion_router()
        state.progress_tracker = get_progress_tracker()
        logger.info("Ingestion components initialized")

        # Initialize vector store
        state.index_builder = get_index_builder()
        state.metadata_store = get_metadata_store()
        logger.info("Vector store components initialized")

        # Check if indexes exist and load them
        if state.index_builder.is_index_built():
            logger.info("Existing indexes found - loading...")
            index_stats = state.index_builder.get_index_stats()

            logger.info(f"Indexes loaded: {index_stats}")
            state.is_ready = True

        # Initialize retrieval
        state.hybrid_retriever = get_hybrid_retriever()
        state.context_assembler = get_context_assembler()
        logger.info("Retrieval components initialized")

        # Initialize generation components
        state.response_generator = get_response_generator(provider = LLMProvider.OLLAMA,
                                                          model_name = settings.OLLAMA_MODEL,
                                                          )

        state.llm_client = get_llm_client(provider = LLMProvider.OLLAMA)

        logger.info(f"Generation components initialized: model={settings.OLLAMA_MODEL}")

        # Check LLM health
        if state.llm_client.check_health():
            logger.info("LLM provider health check: PASSED")

        else:
            logger.warning("LLM provider health check: FAILED - Ensure Ollama is running")
            logger.warning("- Run: ollama serve (in a separate terminal)")
            logger.warning("- Run: ollama pull mistral (if model not downloaded)")

        # Initialize RAGAS evaluator
        if settings.ENABLE_RAGAS:
            state.ragas_evaluator = get_ragas_evaluator(enable_ground_truth_metrics = settings.RAGAS_ENABLE_GROUND_TRUTH)

            logger.info("RAGAS evaluator initialized")

        else:
            logger.info("RAGAS evaluation disabled in settings")

        logger.info("All components initialized successfully")

    except Exception as e:
        logger.error(f"Component initialization failed: {e}", exc_info = True)
        raise


async def cleanup_components(state: AppState):
    """
    Clean up components on shutdown
    """
    try:
        logger.info("Starting component cleanup...")

        # Use the cleanup manager
        await CleanupManager.full_cleanup(state)

        logger.info("Component cleanup complete")

    except Exception as e:
        logger.error(f"Component cleanup error: {e}", exc_info = True)

        # Last-ditch effort
        await _brute_force_cleanup_app_state(state)


def create_directories():
    """
    Create necessary directories
    """
    directories = [settings.UPLOAD_DIR,
                   settings.VECTOR_STORE_DIR,
                   settings.BACKUP_DIR,
                   Path(settings.METADATA_DB_PATH).parent,
                   settings.LOG_DIR,
                   ]

    for directory in directories:
        Path(directory).mkdir(parents = True, exist_ok = True)

    logger.info("Directories created/verified")

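
# Illustrative sketch (not called anywhere in the app): initializing components
# outside of the FastAPI lifespan, e.g. for a local smoke test.
def _example_standalone_init() -> AppState:
    state = AppState()
    asyncio.run(initialize_components(state))
    return state
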
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/", response_class = HTMLResponse)
async def serve_frontend():
    """
    Serve the main frontend HTML
    """
    frontend_path = Path("frontend/index.html")
    if frontend_path.exists():
        return FileResponse(frontend_path)

    raise HTTPException(status_code = 404,
                        detail = "Frontend not found",
                        )


@app.get("/api/health")
async def health_check():
    """
    Health check endpoint
    """
    state = app.state.app

    health_status = state.get_system_health()

    return {"status" : health_status["overall"],
            "timestamp" : datetime.now().isoformat(),
            "version" : "1.0.0",
            "components" : {"vector_store" : health_status["vector_store"],
                            "llm" : health_status["llm"],
                            "embeddings" : health_status["embeddings"],
                            "retrieval" : health_status["retrieval"],
                            "generation" : health_status["generation"],
                            "hybrid_retriever" : health_status["retrieval"],
                            },
            "details" : health_status
            }


@app.get("/api/system-info")
async def get_system_info():
    """
    Get system information and status
    """
    state = app.state.app

    # Get system information
    system_info = state.get_system_information()

    # Get LLM provider info
    llm_info = dict()

    if state.llm_client:
        llm_info = state.llm_client.get_provider_info()

    # Get current configuration
    current_config = {"inference_model" : settings.OLLAMA_MODEL,
                      "embedding_model" : settings.EMBEDDING_MODEL,
                      "vector_weight" : settings.VECTOR_WEIGHT,
                      "bm25_weight" : settings.BM25_WEIGHT,
                      "temperature" : settings.DEFAULT_TEMPERATURE,
                      "max_tokens" : settings.MAX_TOKENS,
                      "chunk_size" : settings.FIXED_CHUNK_SIZE,
                      "chunk_overlap" : settings.FIXED_CHUNK_OVERLAP,
                      "top_k_retrieve" : settings.TOP_K_RETRIEVE,
                      "enable_reranking" : settings.ENABLE_RERANKING,
                      }

    return {"system_state" : {"is_ready" : state.is_ready,
                              "processing_status" : state.processing_status,
                              "total_documents" : len(state.uploaded_files),
                              "active_sessions" : len(state.active_sessions),
                              },
            "configuration" : current_config,
            "llm_provider" : llm_info,
            "system_information" : system_info,
            "timestamp" : datetime.now().isoformat()
            }

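
# Illustrative sketch (not part of the app): querying the read-only endpoints
# from a client. Assumes the `requests` package and a local server on port 7860.
#
#   import requests
#   health = requests.get("http://localhost:7860/api/health").json()
#   info = requests.get("http://localhost:7860/api/system-info").json()
#   print(health["status"], info["configuration"]["inference_model"])
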
@app.post("/api/upload")
|
| 1224 |
+
async def upload_files(files: List[UploadFile] = File(...)):
|
| 1225 |
+
"""
|
| 1226 |
+
Upload multiple files
|
| 1227 |
+
"""
|
| 1228 |
+
state = app.state.app
|
| 1229 |
+
|
| 1230 |
+
try:
|
| 1231 |
+
logger.info(f"Received {len(files)} files for upload")
|
| 1232 |
+
uploaded_info = list()
|
| 1233 |
+
|
| 1234 |
+
for file in files:
|
| 1235 |
+
try:
|
| 1236 |
+
# Validate file type
|
| 1237 |
+
if not state.parser_factory.is_supported(Path(file.filename)):
|
| 1238 |
+
logger.warning(f"Unsupported file type: {file.filename}")
|
| 1239 |
+
continue
|
| 1240 |
+
|
| 1241 |
+
# Save file to upload directory
|
| 1242 |
+
file_path = settings.UPLOAD_DIR / FileHandler.generate_unique_filename(file.filename, settings.UPLOAD_DIR)
|
| 1243 |
+
|
| 1244 |
+
# Write file content
|
| 1245 |
+
content = await file.read()
|
| 1246 |
+
|
| 1247 |
+
with open(file_path, 'wb') as f:
|
| 1248 |
+
f.write(content)
|
| 1249 |
+
|
| 1250 |
+
# Get file metadata
|
| 1251 |
+
file_metadata = FileHandler.get_file_metadata(file_path)
|
| 1252 |
+
|
| 1253 |
+
file_info = {"filename" : file_path.name,
|
| 1254 |
+
"original_name" : file.filename,
|
| 1255 |
+
"size" : file_metadata["size_bytes"],
|
| 1256 |
+
"upload_time" : datetime.now().isoformat(),
|
| 1257 |
+
"file_path" : str(file_path),
|
| 1258 |
+
"status" : "uploaded",
|
| 1259 |
+
}
|
| 1260 |
+
|
| 1261 |
+
uploaded_info.append(file_info)
|
| 1262 |
+
state.uploaded_files.append(file_info)
|
| 1263 |
+
|
| 1264 |
+
logger.info(f"Uploaded: {file.filename} -> {file_path.name}")
|
| 1265 |
+
|
| 1266 |
+
except Exception as e:
|
| 1267 |
+
logger.error(f"Failed to upload {file.filename}: {e}")
|
| 1268 |
+
continue
|
| 1269 |
+
|
| 1270 |
+
# Clear analytics cache since we have new data
|
| 1271 |
+
state.analytics_cache.data = None
|
| 1272 |
+
|
| 1273 |
+
return {"success" : True,
|
| 1274 |
+
"message" : f"Successfully uploaded {len(uploaded_info)} files",
|
| 1275 |
+
"files" : uploaded_info,
|
| 1276 |
+
}
|
| 1277 |
+
|
| 1278 |
+
except Exception as e:
|
| 1279 |
+
logger.error(f"Upload error: {e}", exc_info = True)
|
| 1280 |
+
|
| 1281 |
+
raise HTTPException(status_code = 500,
|
| 1282 |
+
detail = str(e),
|
| 1283 |
+
)
|
| 1284 |
+
|
| 1285 |
+
|
| 1286 |
+
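
# Illustrative sketch (not part of the app): uploading files from the command
# line. The file names are hypothetical.
#
#   curl -X POST http://localhost:7860/api/upload \
#        -F "files=@report.pdf" -F "files=@notes.docx"
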
@app.post("/api/start-processing")
|
| 1287 |
+
async def start_processing():
|
| 1288 |
+
"""
|
| 1289 |
+
Start processing uploaded documents
|
| 1290 |
+
"""
|
| 1291 |
+
state = app.state.app
|
| 1292 |
+
|
| 1293 |
+
if not state.uploaded_files:
|
| 1294 |
+
raise HTTPException(status_code = 400,
|
| 1295 |
+
detail = "No files uploaded",
|
| 1296 |
+
)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
if (state.processing_status == "processing"):
|
| 1300 |
+
raise HTTPException(status_code = 400,
|
| 1301 |
+
detail = "Processing already in progress",
|
| 1302 |
+
)
|
| 1303 |
+
|
| 1304 |
+
try:
|
| 1305 |
+
state.processing_status = "processing"
|
| 1306 |
+
state.processing_progress = {"status" : "processing",
|
| 1307 |
+
"current_step" : "Starting document processing...",
|
| 1308 |
+
"progress" : 0,
|
| 1309 |
+
"processed" : 0,
|
| 1310 |
+
"total" : len(state.uploaded_files),
|
| 1311 |
+
"details" : {},
|
| 1312 |
+
}
|
| 1313 |
+
|
| 1314 |
+
logger.info("Starting document processing pipeline...")
|
| 1315 |
+
|
| 1316 |
+
all_chunks = list()
|
| 1317 |
+
chunking_stats = dict()
|
| 1318 |
+
|
| 1319 |
+
# Process each file
|
| 1320 |
+
for idx, file_info in enumerate(state.uploaded_files):
|
| 1321 |
+
try:
|
| 1322 |
+
file_path = Path(file_info["file_path"])
|
| 1323 |
+
|
| 1324 |
+
# Update progress - Parsing
|
| 1325 |
+
state.processing_progress["current_step"] = f"Parsing {file_info['original_name']}..."
|
| 1326 |
+
state.processing_progress["progress"] = int((idx / len(state.uploaded_files)) * 20)
|
| 1327 |
+
|
| 1328 |
+
# Parse document
|
| 1329 |
+
logger.info(f"Parsing document: {file_path}")
|
| 1330 |
+
text, metadata = state.parser_factory.parse(file_path,
|
| 1331 |
+
extract_metadata = True,
|
| 1332 |
+
clean_text = True,
|
| 1333 |
+
)
|
| 1334 |
+
|
| 1335 |
+
if not text:
|
| 1336 |
+
logger.warning(f"No text extracted from {file_path}")
|
| 1337 |
+
continue
|
| 1338 |
+
|
| 1339 |
+
logger.info(f"Extracted {len(text)} characters from {file_path}")
|
| 1340 |
+
|
| 1341 |
+
# Update progress - Chunking
|
| 1342 |
+
state.processing_progress["current_step"] = f"Chunking {file_info['original_name']}..."
|
| 1343 |
+
state.processing_progress["progress"] = int((idx / len(state.uploaded_files)) * 40) + 20
|
| 1344 |
+
|
| 1345 |
+
# Chunk document
|
| 1346 |
+
logger.info(f"Chunking document: {metadata.document_id}")
|
| 1347 |
+
chunks = state.chunking_selector.chunk_text(text = text,
|
| 1348 |
+
metadata = metadata,
|
| 1349 |
+
)
|
| 1350 |
+
|
| 1351 |
+
# Get strategy used from metadata or selector
|
| 1352 |
+
strategy_used = "adaptive" # Default
|
| 1353 |
+
|
| 1354 |
+
if (metadata and hasattr(metadata, 'chunking_strategy')):
|
| 1355 |
+
strategy_used = metadata.chunking_strategy.value if metadata.chunking_strategy else "adaptive"
|
| 1356 |
+
|
| 1357 |
+
elif (hasattr(state.chunking_selector, 'last_strategy_used')):
|
| 1358 |
+
strategy_used = state.chunking_selector.last_strategy_used
|
| 1359 |
+
|
| 1360 |
+
# Track chunking strategy usage
|
| 1361 |
+
if strategy_used not in chunking_stats:
|
| 1362 |
+
chunking_stats[strategy_used] = 0
|
| 1363 |
+
|
| 1364 |
+
chunking_stats[strategy_used] += len(chunks)
|
| 1365 |
+
|
| 1366 |
+
logger.info(f"Created {len(chunks)} chunks for {metadata.document_id} using {strategy_used}")
|
| 1367 |
+
|
| 1368 |
+
# Update progress - Embedding
|
| 1369 |
+
state.processing_progress["current_step"] = f"Generating embeddings for {file_info['original_name']}..."
|
| 1370 |
+
state.processing_progress["progress"] = int((idx / len(state.uploaded_files)) * 60) + 40
|
| 1371 |
+
|
| 1372 |
+
# Generate embeddings for chunks
|
| 1373 |
+
logger.info(f"Generating embeddings for {len(chunks)} chunks...")
|
| 1374 |
+
chunks_with_embeddings = state.embedder.embed_chunks(chunks = chunks,
|
| 1375 |
+
batch_size = settings.EMBEDDING_BATCH_SIZE,
|
| 1376 |
+
normalize = True,
|
| 1377 |
+
)
|
| 1378 |
+
|
| 1379 |
+
logger.info(f"Generated embeddings for {len(chunks_with_embeddings)} chunks")
|
| 1380 |
+
|
| 1381 |
+
# Store chunks
|
| 1382 |
+
all_chunks.extend(chunks_with_embeddings)
|
| 1383 |
+
|
| 1384 |
+
# Store processed document and chunks
|
| 1385 |
+
state.processed_documents[metadata.document_id] = {"metadata" : metadata,
|
| 1386 |
+
"text" : text,
|
| 1387 |
+
"file_info" : file_info,
|
| 1388 |
+
"chunks_count" : len(chunks_with_embeddings),
|
| 1389 |
+
"processed_time" : datetime.now().isoformat(),
|
| 1390 |
+
"chunking_strategy" : strategy_used,
|
| 1391 |
+
}
|
| 1392 |
+
|
| 1393 |
+
state.document_chunks[metadata.document_id] = chunks_with_embeddings
|
| 1394 |
+
|
| 1395 |
+
# Update progress
|
| 1396 |
+
state.processing_progress["processed"] = idx + 1
|
| 1397 |
+
|
| 1398 |
+
except Exception as e:
|
| 1399 |
+
logger.error(f"Failed to process {file_info['original_name']}: {e}", exc_info=True)
|
| 1400 |
+
continue
|
| 1401 |
+
|
| 1402 |
+
# Update chunking statistics
|
| 1403 |
+
state.chunking_statistics = chunking_stats
|
| 1404 |
+
|
| 1405 |
+
if not all_chunks:
|
| 1406 |
+
raise Exception("No chunks were successfully processed")
|
| 1407 |
+
|
| 1408 |
+
# Update progress - Building indexes
|
| 1409 |
+
state.processing_progress["current_step"] = "Building vector and keyword indexes..."
|
| 1410 |
+
state.processing_progress["progress"] = 80
|
| 1411 |
+
|
| 1412 |
+
# Build indexes (FAISS + BM25 + Metadata)
|
| 1413 |
+
logger.info(f"Building indexes for {len(all_chunks)} chunks...")
|
| 1414 |
+
index_stats = state.index_builder.build_indexes(chunks = all_chunks,
|
| 1415 |
+
rebuild = True,
|
| 1416 |
+
)
|
| 1417 |
+
|
| 1418 |
+
logger.info(f"Indexes built: {index_stats}")
|
| 1419 |
+
|
| 1420 |
+
# Update progress - Indexing for hybrid retrieval
|
| 1421 |
+
state.processing_progress["current_step"] = "Indexing for hybrid retrieval..."
|
| 1422 |
+
state.processing_progress["progress"] = 95
|
| 1423 |
+
|
| 1424 |
+
# Mark as ready
|
| 1425 |
+
state.processing_status = "ready"
|
| 1426 |
+
state.is_ready = True
|
| 1427 |
+
state.processing_progress["status"] = "ready"
|
| 1428 |
+
state.processing_progress["current_step"] = "Processing complete"
|
| 1429 |
+
state.processing_progress["progress"] = 100
|
| 1430 |
+
|
| 1431 |
+
# Clear analytics cache
|
| 1432 |
+
state.analytics_cache.data = None
|
| 1433 |
+
|
| 1434 |
+
logger.info(f"Processing complete. Processed {len(state.processed_documents)} documents with {len(all_chunks)} total chunks.")
|
| 1435 |
+
|
| 1436 |
+
return {"success" : True,
|
| 1437 |
+
"message" : "Processing completed successfully",
|
| 1438 |
+
"status" : "ready",
|
| 1439 |
+
"documents_processed" : len(state.processed_documents),
|
| 1440 |
+
"total_chunks" : len(all_chunks),
|
| 1441 |
+
"chunking_statistics" : chunking_stats,
|
| 1442 |
+
"index_stats" : index_stats,
|
| 1443 |
+
}
|
| 1444 |
+
|
| 1445 |
+
except Exception as e:
|
| 1446 |
+
state.processing_status = "error"
|
| 1447 |
+
state.processing_progress["status"] = "error"
|
| 1448 |
+
|
| 1449 |
+
logger.error(f"Processing error: {e}", exc_info = True)
|
| 1450 |
+
|
| 1451 |
+
raise HTTPException(status_code = 500,
|
| 1452 |
+
detail = str(e),
|
| 1453 |
+
)
|
| 1454 |
+
|
| 1455 |
+
|
| 1456 |
+
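
# Illustrative sketch (not called anywhere in the app): driving the ingestion
# pipeline from a client. Assumes the `requests` package and a local server on
# port 7860. /api/start-processing runs the full parse -> chunk -> embed -> index
# pipeline inside the request, so its JSON response already carries the summary;
# /api/processing-status (below) can be polled from a second client for progress.
def _example_start_processing(base_url: str = "http://localhost:7860") -> dict:
    import requests

    return requests.post(f"{base_url}/api/start-processing", timeout=600).json()
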
@app.get("/api/processing-status")
|
| 1457 |
+
async def get_processing_status():
|
| 1458 |
+
"""
|
| 1459 |
+
Get current processing status
|
| 1460 |
+
"""
|
| 1461 |
+
state = app.state.app
|
| 1462 |
+
|
| 1463 |
+
return {"status" : state.processing_progress["status"],
|
| 1464 |
+
"progress" : state.processing_progress["progress"],
|
| 1465 |
+
"current_step" : state.processing_progress["current_step"],
|
| 1466 |
+
"processed_documents" : state.processing_progress["processed"],
|
| 1467 |
+
"total_documents" : state.processing_progress["total"],
|
| 1468 |
+
"details" : state.processing_progress["details"],
|
| 1469 |
+
}
|
| 1470 |
+
|
| 1471 |
+
|
| 1472 |
+
@app.get("/api/ragas/history")
|
| 1473 |
+
async def get_ragas_history():
|
| 1474 |
+
"""
|
| 1475 |
+
Get RAGAS evaluation history for current session
|
| 1476 |
+
"""
|
| 1477 |
+
state = app.state.app
|
| 1478 |
+
|
| 1479 |
+
if not settings.ENABLE_RAGAS or not state.ragas_evaluator:
|
| 1480 |
+
|
| 1481 |
+
raise HTTPException(status_code = 400,
|
| 1482 |
+
detail = "RAGAS evaluation is not enabled. Set ENABLE_RAGAS=True in settings.",
|
| 1483 |
+
)
|
| 1484 |
+
|
| 1485 |
+
try:
|
| 1486 |
+
history = state.ragas_evaluator.get_evaluation_history()
|
| 1487 |
+
stats = state.ragas_evaluator.get_session_statistics()
|
| 1488 |
+
|
| 1489 |
+
return {"success" : True,
|
| 1490 |
+
"total_count" : len(history),
|
| 1491 |
+
"statistics" : stats.model_dump(),
|
| 1492 |
+
"history" : history
|
| 1493 |
+
}
|
| 1494 |
+
|
| 1495 |
+
except Exception as e:
|
| 1496 |
+
logger.error(f"RAGAS history retrieval error: {e}", exc_info = True)
|
| 1497 |
+
|
| 1498 |
+
raise HTTPException(status_code = 500,
|
| 1499 |
+
detail = str(e),
|
| 1500 |
+
)
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
@app.post("/api/ragas/clear")
|
| 1504 |
+
async def clear_ragas_history():
|
| 1505 |
+
"""
|
| 1506 |
+
Clear RAGAS evaluation history
|
| 1507 |
+
"""
|
| 1508 |
+
state = app.state.app
|
| 1509 |
+
|
| 1510 |
+
if not settings.ENABLE_RAGAS or not state.ragas_evaluator:
|
| 1511 |
+
|
| 1512 |
+
raise HTTPException(status_code = 400,
|
| 1513 |
+
detail = "RAGAS evaluation is not enabled.",
|
| 1514 |
+
)
|
| 1515 |
+
|
| 1516 |
+
try:
|
| 1517 |
+
state.ragas_evaluator.clear_history()
|
| 1518 |
+
|
| 1519 |
+
return {"success" : True,
|
| 1520 |
+
"message" : "RAGAS evaluation history cleared, new session started",
|
| 1521 |
+
}
|
| 1522 |
+
|
| 1523 |
+
except Exception as e:
|
| 1524 |
+
logger.error(f"RAGAS history clear error: {e}", exc_info = True)
|
| 1525 |
+
|
| 1526 |
+
raise HTTPException(status_code = 500,
|
| 1527 |
+
detail = str(e),
|
| 1528 |
+
)
|
| 1529 |
+
|
| 1530 |
+
|
| 1531 |
+
@app.get("/api/ragas/statistics")
|
| 1532 |
+
async def get_ragas_statistics():
|
| 1533 |
+
"""
|
| 1534 |
+
Get aggregate RAGAS statistics for current session
|
| 1535 |
+
"""
|
| 1536 |
+
state = app.state.app
|
| 1537 |
+
|
| 1538 |
+
if not settings.ENABLE_RAGAS or not state.ragas_evaluator:
|
| 1539 |
+
|
| 1540 |
+
raise HTTPException(status_code = 400,
|
| 1541 |
+
detail = "RAGAS evaluation is not enabled.",
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
try:
|
| 1545 |
+
stats = state.ragas_evaluator.get_session_statistics()
|
| 1546 |
+
|
| 1547 |
+
return {"success" : True,
|
| 1548 |
+
"statistics" : stats.model_dump(),
|
| 1549 |
+
}
|
| 1550 |
+
|
| 1551 |
+
except Exception as e:
|
| 1552 |
+
logger.error(f"RAGAS statistics error: {e}", exc_info = True)
|
| 1553 |
+
|
| 1554 |
+
raise HTTPException(status_code = 500,
|
| 1555 |
+
detail = str(e),
|
| 1556 |
+
)
|
| 1557 |
+
|
| 1558 |
+
|
| 1559 |
+
@app.get("/api/ragas/export")
|
| 1560 |
+
async def export_ragas_data():
|
| 1561 |
+
"""
|
| 1562 |
+
Export all RAGAS evaluation data
|
| 1563 |
+
"""
|
| 1564 |
+
state = app.state.app
|
| 1565 |
+
|
| 1566 |
+
if not settings.ENABLE_RAGAS or not state.ragas_evaluator:
|
| 1567 |
+
|
| 1568 |
+
raise HTTPException(status_code = 400,
|
| 1569 |
+
detail = "RAGAS evaluation is not enabled.",
|
| 1570 |
+
)
|
| 1571 |
+
|
| 1572 |
+
try:
|
| 1573 |
+
export_data = state.ragas_evaluator.export_to_dict()
|
| 1574 |
+
|
| 1575 |
+
return JSONResponse(content = json.loads(export_data.model_dump_json()))
|
| 1576 |
+
|
| 1577 |
+
except Exception as e:
|
| 1578 |
+
logger.error(f"RAGAS export error: {e}", exc_info = True)
|
| 1579 |
+
|
| 1580 |
+
raise HTTPException(status_code = 500,
|
| 1581 |
+
detail = str(e),
|
| 1582 |
+
)
|
| 1583 |
+
|
| 1584 |
+
|
| 1585 |
+
@app.get("/api/ragas/config")
|
| 1586 |
+
async def get_ragas_config():
|
| 1587 |
+
"""
|
| 1588 |
+
Get current RAGAS configuration
|
| 1589 |
+
"""
|
| 1590 |
+
return {"enabled" : settings.ENABLE_RAGAS,
|
| 1591 |
+
"ground_truth_enabled" : settings.RAGAS_ENABLE_GROUND_TRUTH,
|
| 1592 |
+
"base_metrics" : settings.RAGAS_METRICS,
|
| 1593 |
+
"ground_truth_metrics" : settings.RAGAS_GROUND_TRUTH_METRICS,
|
| 1594 |
+
"evaluation_timeout" : settings.RAGAS_EVALUATION_TIMEOUT,
|
| 1595 |
+
"batch_size" : settings.RAGAS_BATCH_SIZE,
|
| 1596 |
+
}
|
| 1597 |
+
|
| 1598 |
+
|
| 1599 |
+
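
# Illustrative sketch (not part of the app): inspecting RAGAS evaluation data
# once ENABLE_RAGAS is on, using the endpoints defined above.
#
#   curl http://localhost:7860/api/ragas/config
#   curl http://localhost:7860/api/ragas/statistics
#   curl http://localhost:7860/api/ragas/export > ragas_export.json
#   curl -X POST http://localhost:7860/api/ragas/clear
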
@app.post("/api/chat")
|
| 1600 |
+
async def chat(request: ChatRequest):
|
| 1601 |
+
"""
|
| 1602 |
+
Handle chat queries with LLM-based intelligent routing (generic vs RAG)
|
| 1603 |
+
Supports both conversational queries and document-based queries
|
| 1604 |
+
"""
|
| 1605 |
+
state = app.state.app
|
| 1606 |
+
|
| 1607 |
+
message = request.message
|
| 1608 |
+
session_id = request.session_id
|
| 1609 |
+
|
| 1610 |
+
# Check LLM health (required for both general and RAG queries)
|
| 1611 |
+
if not state.llm_client.check_health():
|
| 1612 |
+
raise HTTPException(status_code = 503,
|
| 1613 |
+
detail = "LLM service unavailable. Please ensure Ollama is running.",
|
| 1614 |
+
)
|
| 1615 |
+
|
| 1616 |
+
try:
|
| 1617 |
+
logger.info(f"Chat query received: {message}")
|
| 1618 |
+
|
| 1619 |
+
# Check if documents are available
|
| 1620 |
+
has_documents = state.is_ready and (len(state.processed_documents) > 0)
|
| 1621 |
+
|
| 1622 |
+
logger.debug(f"System state - Documents available: {has_documents}, Processed docs: {len(state.processed_documents)}, System ready: {state.is_ready}")
|
| 1623 |
+
|
| 1624 |
+
# Get conversation history for this session (for general queries)
|
| 1625 |
+
conversation_history = None
|
| 1626 |
+
|
| 1627 |
+
if (session_id and (session_id in state.active_sessions)):
|
| 1628 |
+
# Convert to format expected by general_responder
|
| 1629 |
+
conversation_history = list()
|
| 1630 |
+
|
| 1631 |
+
# Last 10 messages for context
|
| 1632 |
+
for msg in state.active_sessions[session_id][-10:]:
|
| 1633 |
+
conversation_history.append({"role" : "user",
|
| 1634 |
+
"content" : msg.get("query", ""),
|
| 1635 |
+
})
|
| 1636 |
+
|
| 1637 |
+
conversation_history.append({"role" : "assistant",
|
| 1638 |
+
"content" : msg.get("response", ""),
|
| 1639 |
+
})
|
| 1640 |
+
|
| 1641 |
+
# Create QueryRequest object
|
| 1642 |
+
query_request = QueryRequest(query = message,
|
| 1643 |
+
top_k = settings.TOP_K_RETRIEVE,
|
| 1644 |
+
enable_reranking = settings.ENABLE_RERANKING,
|
| 1645 |
+
temperature = settings.DEFAULT_TEMPERATURE,
|
| 1646 |
+
top_p = settings.TOP_P,
|
| 1647 |
+
max_tokens = settings.MAX_TOKENS,
|
| 1648 |
+
include_sources = True,
|
| 1649 |
+
include_metrics = False,
|
| 1650 |
+
stream = False,
|
| 1651 |
+
)
|
| 1652 |
+
|
| 1653 |
+
# Generate response using response generator (with LLM-based routing)
|
| 1654 |
+
start_time = time.time()
|
| 1655 |
+
|
| 1656 |
+
query_response = await state.response_generator.generate_response(request = query_request,
|
| 1657 |
+
conversation_history = conversation_history,
|
| 1658 |
+
has_documents = has_documents, # Pass document availability
|
| 1659 |
+
)
|
| 1660 |
+
|
| 1661 |
+
# Convert to ms
|
| 1662 |
+
total_time = (time.time() - start_time) * 1000
|
| 1663 |
+
|
| 1664 |
+
# Record timing for analytics
|
| 1665 |
+
state.add_query_timing(total_time)
|
| 1666 |
+
|
| 1667 |
+
# Determine query type using response metadata
|
| 1668 |
+
is_general_query = False
|
| 1669 |
+
|
| 1670 |
+
# Default to rag
|
| 1671 |
+
actual_query_type = "rag"
|
| 1672 |
+
|
| 1673 |
+
# Check if response has metadata
|
| 1674 |
+
if (hasattr(query_response, 'query_type')):
|
| 1675 |
+
actual_query_type = query_response.query_type
|
| 1676 |
+
is_general_query = (actual_query_type == "general")
|
| 1677 |
+
|
| 1678 |
+
elif (hasattr(query_response, 'is_general_query')):
|
| 1679 |
+
is_general_query = query_response.is_general_query
|
| 1680 |
+
actual_query_type = "general" if is_general_query else "rag"
|
| 1681 |
+
|
| 1682 |
+
else:
|
| 1683 |
+
# Method 2: Check sources (fallback)
|
| 1684 |
+
has_sources = query_response.sources and len(query_response.sources) > 0
|
| 1685 |
+
is_general_query = not has_sources
|
| 1686 |
+
actual_query_type = "general" if is_general_query else "rag"
|
| 1687 |
+
|
| 1688 |
+
logger.debug(f"Query classification: actual_query_type={actual_query_type}, has_sources={query_response.sources and len(query_response.sources) > 0}")
|
| 1689 |
+
|
| 1690 |
+
# Extract contexts for RAGAS evaluation (only if RAG was used)
|
| 1691 |
+
contexts = list()
|
| 1692 |
+
|
| 1693 |
+
if query_response.sources:
|
| 1694 |
+
contexts = [chunk.chunk.text for chunk in query_response.sources]
|
| 1695 |
+
|
| 1696 |
+
# Run RAGAS evaluation (only if RAGAS enabled)
|
| 1697 |
+
ragas_result = None
|
| 1698 |
+
|
| 1699 |
+
if (settings.ENABLE_RAGAS and state.ragas_evaluator):
|
| 1700 |
+
try:
|
| 1701 |
+
logger.info("Running RAGAS evaluation...")
|
| 1702 |
+
|
| 1703 |
+
ragas_result = state.ragas_evaluator.evaluate_single(query = message,
|
| 1704 |
+
answer = query_response.answer,
|
| 1705 |
+
contexts = contexts,
|
| 1706 |
+
ground_truth = None,
|
| 1707 |
+
retrieval_time_ms = int(query_response.retrieval_time_ms),
|
| 1708 |
+
generation_time_ms = int(query_response.generation_time_ms),
|
| 1709 |
+
total_time_ms = int(query_response.total_time_ms),
|
| 1710 |
+
chunks_retrieved = len(query_response.sources),
|
| 1711 |
+
query_type = actual_query_type,
|
| 1712 |
+
)
|
| 1713 |
+
|
| 1714 |
+
logger.info(f"RAGAS evaluation complete: type={actual_query_type.upper()}, relevancy={ragas_result.answer_relevancy:.3f}, faithfulness={ragas_result.faithfulness:.3f}, overall={ragas_result.overall_score:.3f}")
|
| 1715 |
+
|
| 1716 |
+
except Exception as e:
|
| 1717 |
+
logger.error(f"RAGAS evaluation failed: {e}", exc_info = True)
|
| 1718 |
+
# Continue without RAGAS metrics - don't fail the request
|
| 1719 |
+
|
| 1720 |
+
# Format sources for response
|
| 1721 |
+
sources = list()
|
| 1722 |
+
|
| 1723 |
+
for i, chunk_with_score in enumerate(query_response.sources[:5], 1):
|
| 1724 |
+
chunk = chunk_with_score.chunk
|
| 1725 |
+
|
| 1726 |
+
source = {"rank" : i,
|
| 1727 |
+
"score" : chunk_with_score.score,
|
| 1728 |
+
"document_id" : chunk.document_id,
|
| 1729 |
+
"chunk_id" : chunk.chunk_id,
|
| 1730 |
+
"text_preview" : chunk.text[:500] + "..." if len(chunk.text) > 500 else chunk.text,
|
| 1731 |
+
"page_number" : chunk.page_number,
|
| 1732 |
+
"section_title" : chunk.section_title,
|
| 1733 |
+
"retrieval_method" : chunk_with_score.retrieval_method,
|
| 1734 |
+
}
|
| 1735 |
+
|
| 1736 |
+
sources.append(source)
|
| 1737 |
+
|
| 1738 |
+
# Generate session ID if not provided
|
| 1739 |
+
if not session_id:
|
| 1740 |
+
session_id = f"session_{datetime.now().timestamp()}"
|
| 1741 |
+
|
| 1742 |
+
# Determine query type for response metadata
|
| 1743 |
+
is_general_query = (actual_query_type == "general")
|
| 1744 |
+
|
| 1745 |
+
# Prepare response
|
| 1746 |
+
response = {"session_id" : session_id,
|
| 1747 |
+
"response" : query_response.answer,
|
| 1748 |
+
"sources" : sources,
|
| 1749 |
+
"is_general_query" : is_general_query,
|
| 1750 |
+
"metrics" : {"retrieval_time" : int(query_response.retrieval_time_ms),
|
| 1751 |
+
"generation_time" : int(query_response.generation_time_ms),
|
| 1752 |
+
"total_time" : int(query_response.total_time_ms),
|
| 1753 |
+
"chunks_retrieved" : len(query_response.sources),
|
| 1754 |
+
"chunks_used" : len(sources),
|
| 1755 |
+
"tokens_used" : query_response.tokens_used.get("total", 0) if query_response.tokens_used else 0,
|
| 1756 |
+
"actual_total_time" : int(total_time),
|
| 1757 |
+
"query_type" : actual_query_type,
|
| 1758 |
+
"llm_classified" : True, # Now using LLM for classification
|
| 1759 |
+
},
|
| 1760 |
+
}
|
| 1761 |
+
|
| 1762 |
+
# Add RAGAS metrics if evaluation succeeded
|
| 1763 |
+
if ragas_result:
|
| 1764 |
+
response["ragas_metrics"] = {"answer_relevancy" : round(ragas_result.answer_relevancy, 3),
|
| 1765 |
+
"faithfulness" : round(ragas_result.faithfulness, 3),
|
| 1766 |
+
"context_precision" : round(ragas_result.context_precision, 3) if ragas_result.context_precision else None,
|
| 1767 |
+
"context_relevancy" : round(ragas_result.context_relevancy, 3),
|
| 1768 |
+
"overall_score" : round(ragas_result.overall_score, 3),
|
| 1769 |
+
"context_recall" : round(ragas_result.context_recall, 3) if ragas_result.context_recall else None,
|
| 1770 |
+
"answer_similarity" : round(ragas_result.answer_similarity, 3) if ragas_result.answer_similarity else None,
|
| 1771 |
+
"answer_correctness" : round(ragas_result.answer_correctness, 3) if ragas_result.answer_correctness else None,
|
| 1772 |
+
"query_type" : ragas_result.query_type,
|
| 1773 |
+
}
|
| 1774 |
+
else:
|
| 1775 |
+
response["ragas_metrics"] = None
|
| 1776 |
+
|
| 1777 |
+
# Store in session
|
| 1778 |
+
if session_id not in state.active_sessions:
|
| 1779 |
+
state.active_sessions[session_id] = list()
|
| 1780 |
+
|
| 1781 |
+
state.active_sessions[session_id].append({"query" : message,
|
| 1782 |
+
"response" : query_response.answer,
|
| 1783 |
+
"sources" : sources,
|
| 1784 |
+
"timestamp" : datetime.now().isoformat(),
|
| 1785 |
+
"metrics" : response["metrics"],
|
| 1786 |
+
"ragas_metrics" : response.get("ragas_metrics", {}),
|
| 1787 |
+
"is_general_query" : is_general_query,
|
| 1788 |
+
})
|
| 1789 |
+
|
| 1790 |
+
# Clear analytics cache when new data is available
|
| 1791 |
+
state.analytics_cache.data = None
|
| 1792 |
+
|
| 1793 |
+
logger.info(f"Chat response generated successfully in {int(total_time)}ms | (type: {actual_query_type.upper()})")
|
| 1794 |
+
|
| 1795 |
+
return response
|
| 1796 |
+
|
| 1797 |
+
except Exception as e:
|
| 1798 |
+
logger.error(f"Chat error: {e}", exc_info = True)
|
| 1799 |
+
|
| 1800 |
+
raise HTTPException(status_code = 500,
|
| 1801 |
+
detail = str(e),
|
| 1802 |
+
)
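# A minimal client-side sketch of calling this endpoint (illustrative only: the host
# and port are assumptions based on a default local deployment, and ChatRequest is
# posted as a JSON body with the message/session_id fields used above):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/chat",
#       json={"message": "Summarize the uploaded document", "session_id": "session_demo"},
#   )
#   data = resp.json()
#   print(data["response"])                 # generated answer
#   print(data["metrics"]["query_type"])    # "rag" or "general"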
|
| 1803 |
+
|
| 1804 |
+
|
| 1805 |
+
@app.get("/api/configuration")
|
| 1806 |
+
async def get_configuration():
|
| 1807 |
+
"""
|
| 1808 |
+
Get current configuration
|
| 1809 |
+
"""
|
| 1810 |
+
state = app.state.app
|
| 1811 |
+
|
| 1812 |
+
# Get system health
|
| 1813 |
+
health_status = state.get_system_health()
|
| 1814 |
+
|
| 1815 |
+
return {"configuration" : {"inference_model" : settings.OLLAMA_MODEL,
|
| 1816 |
+
"embedding_model" : settings.EMBEDDING_MODEL,
|
| 1817 |
+
"chunking_strategy" : "adaptive",
|
| 1818 |
+
"chunk_size" : settings.FIXED_CHUNK_SIZE,
|
| 1819 |
+
"chunk_overlap" : settings.FIXED_CHUNK_OVERLAP,
|
| 1820 |
+
"retrieval_top_k" : settings.TOP_K_RETRIEVE,
|
| 1821 |
+
"vector_weight" : settings.VECTOR_WEIGHT,
|
| 1822 |
+
"bm25_weight" : settings.BM25_WEIGHT,
|
| 1823 |
+
"temperature" : settings.DEFAULT_TEMPERATURE,
|
| 1824 |
+
"max_tokens" : settings.MAX_TOKENS,
|
| 1825 |
+
"enable_reranking" : settings.ENABLE_RERANKING,
|
| 1826 |
+
"is_ready" : state.is_ready,
|
| 1827 |
+
"llm_healthy" : health_status["llm"],
|
| 1828 |
+
},
|
| 1829 |
+
"health" : health_status,
|
| 1830 |
+
}
|
| 1831 |
+
|
| 1832 |
+
|
| 1833 |
+
@app.post("/api/configuration")
|
| 1834 |
+
async def update_configuration(temperature: float = Form(None), max_tokens: int = Form(None), retrieval_top_k: int = Form(None),
|
| 1835 |
+
vector_weight: float = Form(None), bm25_weight: float = Form(None), enable_reranking: bool = Form(None),
|
| 1836 |
+
session_id: str = Form(None)):
|
| 1837 |
+
"""
|
| 1838 |
+
Update system configuration (runtime parameters only)
|
| 1839 |
+
"""
|
| 1840 |
+
state = app.state.app
|
| 1841 |
+
|
| 1842 |
+
try:
|
| 1843 |
+
updates = dict()
|
| 1844 |
+
|
| 1845 |
+
# Runtime parameters (no rebuild required)
|
| 1846 |
+
if (temperature is not None):
|
| 1847 |
+
updates["temperature"] = temperature
|
| 1848 |
+
|
| 1849 |
+
if (max_tokens and (max_tokens != settings.MAX_TOKENS)):
|
| 1850 |
+
updates["max_tokens"] = max_tokens
|
| 1851 |
+
|
| 1852 |
+
if (retrieval_top_k and (retrieval_top_k != settings.TOP_K_RETRIEVE)):
|
| 1853 |
+
updates["retrieval_top_k"] = retrieval_top_k
|
| 1854 |
+
|
| 1855 |
+
if ((vector_weight is not None) and (vector_weight != settings.VECTOR_WEIGHT)):
|
| 1856 |
+
updates["vector_weight"] = vector_weight
|
| 1857 |
+
|
| 1858 |
+
# Update hybrid retriever weights
|
| 1859 |
+
if bm25_weight is not None:
|
| 1860 |
+
state.hybrid_retriever.update_weights(vector_weight, bm25_weight)
|
| 1861 |
+
|
| 1862 |
+
if ((bm25_weight is not None) and (bm25_weight != settings.BM25_WEIGHT)):
|
| 1863 |
+
updates["bm25_weight"] = bm25_weight
|
| 1864 |
+
|
| 1865 |
+
if (enable_reranking is not None):
|
| 1866 |
+
updates["enable_reranking"] = enable_reranking
|
| 1867 |
+
|
| 1868 |
+
# Store session-based config overrides
|
| 1869 |
+
if session_id:
|
| 1870 |
+
if session_id not in state.config_overrides:
|
| 1871 |
+
state.config_overrides[session_id] = {}
|
| 1872 |
+
|
| 1873 |
+
state.config_overrides[session_id].update(updates)
|
| 1874 |
+
|
| 1875 |
+
logger.info(f"Configuration updated: {updates}")
|
| 1876 |
+
|
| 1877 |
+
# Clear analytics cache since configuration changed
|
| 1878 |
+
state.analytics_cache.data = None
|
| 1879 |
+
|
| 1880 |
+
return {"success" : True,
|
| 1881 |
+
"message" : "Configuration updated successfully",
|
| 1882 |
+
"updates" : updates,
|
| 1883 |
+
}
|
| 1884 |
+
|
| 1885 |
+
except Exception as e:
|
| 1886 |
+
logger.error(f"Configuration update error: {e}", exc_info = True)
|
| 1887 |
+
raise HTTPException(status_code = 500,
|
| 1888 |
+
detail = str(e),
|
| 1889 |
+
)
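# A minimal sketch of updating runtime parameters from a client (illustrative only;
# the endpoint expects form-encoded fields matching the Form(...) parameters above):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/configuration",
#       data={"temperature": 0.2, "retrieval_top_k": 8, "session_id": "session_demo"},
#   )
#   print(resp.json()["updates"])   # e.g. {"temperature": 0.2, "retrieval_top_k": 8}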
|
| 1890 |
+
|
| 1891 |
+
|
| 1892 |
+
@app.get("/api/analytics")
|
| 1893 |
+
async def get_analytics():
|
| 1894 |
+
"""
|
| 1895 |
+
Get comprehensive system analytics and metrics with caching
|
| 1896 |
+
"""
|
| 1897 |
+
state = app.state.app
|
| 1898 |
+
|
| 1899 |
+
try:
|
| 1900 |
+
# Check cache first
|
| 1901 |
+
cached_data = state.analytics_cache.get()
|
| 1902 |
+
|
| 1903 |
+
if cached_data:
|
| 1904 |
+
cached_data["cache_info"]["from_cache"] = True
|
| 1905 |
+
|
| 1906 |
+
return cached_data
|
| 1907 |
+
|
| 1908 |
+
# Calculate fresh analytics
|
| 1909 |
+
analytics_data = state.calculate_comprehensive_analytics()
|
| 1910 |
+
|
| 1911 |
+
# Update cache
|
| 1912 |
+
state.analytics_cache.update(analytics_data)
|
| 1913 |
+
|
| 1914 |
+
return analytics_data
|
| 1915 |
+
|
| 1916 |
+
except Exception as e:
|
| 1917 |
+
logger.error(f"Analytics calculation error: {e}", exc_info = True)
|
| 1918 |
+
|
| 1919 |
+
# Return basic analytics even if calculation fails
|
| 1920 |
+
return {"performance_metrics" : {"avg_response_time" : 0,
|
| 1921 |
+
"total_queries" : 0,
|
| 1922 |
+
"queries_last_hour" : 0,
|
| 1923 |
+
"error" : "Could not calculate performance metrics"
|
| 1924 |
+
},
|
| 1925 |
+
"quality_metrics" : {"answer_relevancy" : 0.0,
|
| 1926 |
+
"faithfulness" : 0.0,
|
| 1927 |
+
"context_precision" : 0.0,
|
| 1928 |
+
"overall_score" : 0.0,
|
| 1929 |
+
"confidence" : "low",
|
| 1930 |
+
"metrics_available" : False,
|
| 1931 |
+
"error" : "Could not calculate quality metrics"
|
| 1932 |
+
},
|
| 1933 |
+
"system_information" : state.get_system_information() if hasattr(state, 'get_system_information') else {},
|
| 1934 |
+
"health_status" : {"overall" : "unknown"},
|
| 1935 |
+
"document_statistics" : {"total_documents" : len(state.processed_documents),
|
| 1936 |
+
"total_chunks" : sum(len(chunks) for chunks in state.document_chunks.values()),
|
| 1937 |
+
"uploaded_files" : len(state.uploaded_files)
|
| 1938 |
+
},
|
| 1939 |
+
"session_statistics" : {"total_sessions" : len(state.active_sessions),
|
| 1940 |
+
"total_messages" : sum(len(msgs) for msgs in state.active_sessions.values())
|
| 1941 |
+
},
|
| 1942 |
+
"calculated_at" : datetime.now().isoformat(),
|
| 1943 |
+
"error" : str(e)
|
| 1944 |
+
}
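# The AnalyticsCache implementation is not part of this excerpt; a minimal TTL cache
# along these lines would satisfy the .get() / .update() / .data pattern used above
# (a sketch under that assumption, not the actual implementation):
#
#   import time
#
#   class SimpleTTLCache:
#       def __init__(self, ttl_seconds: int = 60):
#           self.ttl = ttl_seconds
#           self.data = None          # callers invalidate by setting data = None
#           self._stored_at = 0.0
#
#       def get(self):
#           if self.data is not None and (time.time() - self._stored_at) < self.ttl:
#               return self.data
#           return None
#
#       def update(self, value):
#           self.data = value
#           self._stored_at = time.time()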
|
| 1945 |
+
|
| 1946 |
+
|
| 1947 |
+
@app.get("/api/analytics/refresh")
|
| 1948 |
+
async def refresh_analytics():
|
| 1949 |
+
"""
|
| 1950 |
+
Force refresh analytics cache
|
| 1951 |
+
"""
|
| 1952 |
+
state = app.state.app
|
| 1953 |
+
|
| 1954 |
+
try:
|
| 1955 |
+
# Clear cache
|
| 1956 |
+
state.analytics_cache.data = None
|
| 1957 |
+
|
| 1958 |
+
# Calculate fresh analytics
|
| 1959 |
+
analytics_data = state.calculate_comprehensive_analytics()
|
| 1960 |
+
|
| 1961 |
+
return {"success" : True,
|
| 1962 |
+
"message" : "Analytics cache refreshed successfully",
|
| 1963 |
+
"data" : analytics_data,
|
| 1964 |
+
}
|
| 1965 |
+
|
| 1966 |
+
except Exception as e:
|
| 1967 |
+
logger.error(f"Analytics refresh error: {e}", exc_info = True)
|
| 1968 |
+
raise HTTPException(status_code = 500,
|
| 1969 |
+
detail = str(e),
|
| 1970 |
+
)
|
| 1971 |
+
|
| 1972 |
+
|
| 1973 |
+
@app.get("/api/analytics/detailed")
|
| 1974 |
+
async def get_detailed_analytics():
|
| 1975 |
+
"""
|
| 1976 |
+
Get detailed analytics including query history and component performance
|
| 1977 |
+
"""
|
| 1978 |
+
state = app.state.app
|
| 1979 |
+
|
| 1980 |
+
try:
|
| 1981 |
+
# Get basic analytics
|
| 1982 |
+
analytics = await get_analytics()
|
| 1983 |
+
|
| 1984 |
+
# Add detailed session information
|
| 1985 |
+
detailed_sessions = list()
|
| 1986 |
+
|
| 1987 |
+
for session_id, messages in state.active_sessions.items():
|
| 1988 |
+
session_info = {"session_id" : session_id,
|
| 1989 |
+
"message_count" : len(messages),
|
| 1990 |
+
"first_message" : messages[0]["timestamp"] if messages else None,
|
| 1991 |
+
"last_message" : messages[-1]["timestamp"] if messages else None,
|
| 1992 |
+
"total_response_time" : sum(msg.get("metrics", {}).get("total_time", 0) for msg in messages),
|
| 1993 |
+
"avg_sources_per_query" : sum(len(msg.get("sources", [])) for msg in messages) / len(messages) if messages else 0,
|
| 1994 |
+
}
|
| 1995 |
+
|
| 1996 |
+
detailed_sessions.append(session_info)
|
| 1997 |
+
|
| 1998 |
+
# Add component performance if available
|
| 1999 |
+
component_performance = dict()
|
| 2000 |
+
|
| 2001 |
+
if state.hybrid_retriever:
|
| 2002 |
+
try:
|
| 2003 |
+
retrieval_stats = state.hybrid_retriever.get_retrieval_stats()
|
| 2004 |
+
component_performance["retrieval"] = retrieval_stats
|
| 2005 |
+
|
| 2006 |
+
except Exception:
|
| 2007 |
+
component_performance["retrieval"] = {"error": "Could not retrieve stats"}
|
| 2008 |
+
|
| 2009 |
+
if state.embedder:
|
| 2010 |
+
try:
|
| 2011 |
+
embedder_info = state.embedder.get_model_info()
|
| 2012 |
+
component_performance["embeddings"] = {"model" : embedder_info.get("model_name", "unknown"),
|
| 2013 |
+
"dimension" : embedder_info.get("embedding_dim", 0),
|
| 2014 |
+
"device" : embedder_info.get("device", "cpu"),
|
| 2015 |
+
}
|
| 2016 |
+
except Exception:
|
| 2017 |
+
component_performance["embeddings"] = {"error": "Could not retrieve stats"}
|
| 2018 |
+
|
| 2019 |
+
analytics["detailed_sessions"] = detailed_sessions
|
| 2020 |
+
analytics["component_performance"] = component_performance
|
| 2021 |
+
|
| 2022 |
+
return analytics
|
| 2023 |
+
|
| 2024 |
+
except Exception as e:
|
| 2025 |
+
logger.error(f"Detailed analytics error: {e}", exc_info = True)
|
| 2026 |
+
raise HTTPException(status_code = 500,
|
| 2027 |
+
detail = str(e),
|
| 2028 |
+
)
|
| 2029 |
+
|
| 2030 |
+
|
| 2031 |
+
@app.get("/api/export-chat/{session_id}")
|
| 2032 |
+
async def export_chat(session_id: str, format: str = "json"):
|
| 2033 |
+
"""
|
| 2034 |
+
Export chat history
|
| 2035 |
+
"""
|
| 2036 |
+
state = app.state.app
|
| 2037 |
+
if session_id not in state.active_sessions:
|
| 2038 |
+
raise HTTPException(status_code = 404,
|
| 2039 |
+
detail = "Session not found",
|
| 2040 |
+
)
|
| 2041 |
+
|
| 2042 |
+
try:
|
| 2043 |
+
chat_history = state.active_sessions[session_id]
|
| 2044 |
+
|
| 2045 |
+
if (format == "json"):
|
| 2046 |
+
return JSONResponse(content = {"session_id" : session_id,
|
| 2047 |
+
"export_time" : datetime.now().isoformat(),
|
| 2048 |
+
"total_messages" : len(chat_history),
|
| 2049 |
+
"history" : chat_history,
|
| 2050 |
+
}
|
| 2051 |
+
)
|
| 2052 |
+
|
| 2053 |
+
elif (format == "csv"):
|
| 2054 |
+
output = io.StringIO()
|
| 2055 |
+
|
| 2056 |
+
if chat_history:
|
| 2057 |
+
fieldnames = ["timestamp", "query", "response", "sources_count", "response_time_ms"]
|
| 2058 |
+
writer = csv.DictWriter(output, fieldnames = fieldnames)
|
| 2059 |
+
writer.writeheader()
|
| 2060 |
+
|
| 2061 |
+
for entry in chat_history:
|
| 2062 |
+
writer.writerow({"timestamp" : entry.get("timestamp", ""),
|
| 2063 |
+
"query" : entry.get("query", ""),
|
| 2064 |
+
"response" : entry.get("response", ""),
|
| 2065 |
+
"sources_count" : len(entry.get("sources", [])),
|
| 2066 |
+
"response_time_ms" : entry.get("metrics", {}).get("total_time", 0),
|
| 2067 |
+
})
|
| 2068 |
+
|
| 2069 |
+
return JSONResponse(content = {"csv" : output.getvalue(),
|
| 2070 |
+
"session_id" : session_id,
|
| 2071 |
+
"format" : "csv",
|
| 2072 |
+
}
|
| 2073 |
+
)
|
| 2074 |
+
|
| 2075 |
+
else:
|
| 2076 |
+
raise HTTPException(status_code = 400,
|
| 2077 |
+
detail = "Unsupported format. Use 'json' or 'csv'",
|
| 2078 |
+
)
|
| 2079 |
+
|
| 2080 |
+
except Exception as e:
|
| 2081 |
+
logger.error(f"Export error: {e}", exc_info = True)
|
| 2082 |
+
raise HTTPException(status_code = 500,
|
| 2083 |
+
detail = str(e),
|
| 2084 |
+
)
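# A minimal sketch of exporting a session from a client (illustrative only; note that
# the CSV variant is returned inside a JSON envelope under the "csv" key):
#
#   import requests
#
#   resp = requests.get(
#       "http://localhost:7860/api/export-chat/session_demo",
#       params={"format": "csv"},
#   )
#   with open("chat_export.csv", "w", encoding="utf-8") as f:
#       f.write(resp.json()["csv"])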
|
| 2085 |
+
|
| 2086 |
+
|
| 2087 |
+
@app.post("/api/cleanup/session/{session_id}")
|
| 2088 |
+
async def cleanup_session(session_id: str):
|
| 2089 |
+
"""
|
| 2090 |
+
Clean up specific session
|
| 2091 |
+
"""
|
| 2092 |
+
state = app.state.app
|
| 2093 |
+
|
| 2094 |
+
if session_id in state.active_sessions:
|
| 2095 |
+
del state.active_sessions[session_id]
|
| 2096 |
+
|
| 2097 |
+
if session_id in state.config_overrides:
|
| 2098 |
+
del state.config_overrides[session_id]
|
| 2099 |
+
|
| 2100 |
+
# Check if no sessions left
|
| 2101 |
+
if not state.active_sessions:
|
| 2102 |
+
logger.info("No active sessions, suggesting vector store cleanup")
|
| 2103 |
+
|
| 2104 |
+
return {"success" : True,
|
| 2105 |
+
"message" : f"Session {session_id} cleaned up",
|
| 2106 |
+
"suggestion" : "No active sessions remaining. Consider cleaning vector store.",
|
| 2107 |
+
}
|
| 2108 |
+
|
| 2109 |
+
return {"success" : True,
|
| 2110 |
+
"message" : f"Session {session_id} cleaned up",
|
| 2111 |
+
}
|
| 2112 |
+
|
| 2113 |
+
return {"success" : False,
|
| 2114 |
+
"message" : "Session not found",
|
| 2115 |
+
}
|
| 2116 |
+
|
| 2117 |
+
|
| 2118 |
+
@app.post("/api/cleanup/vector-store")
|
| 2119 |
+
async def cleanup_vector_store():
|
| 2120 |
+
"""
|
| 2121 |
+
Manual vector store cleanup
|
| 2122 |
+
"""
|
| 2123 |
+
state = app.state.app
|
| 2124 |
+
|
| 2125 |
+
try:
|
| 2126 |
+
# Use cleanup manager
|
| 2127 |
+
success = await CleanupManager.full_cleanup(state)
|
| 2128 |
+
|
| 2129 |
+
if success:
|
| 2130 |
+
return {"success" : True,
|
| 2131 |
+
"message" : "Vector store and all data cleaned up",
|
| 2132 |
+
}
|
| 2133 |
+
|
| 2134 |
+
else:
|
| 2135 |
+
return {"success" : False,
|
| 2136 |
+
"message" : "Cleanup completed with errors",
|
| 2137 |
+
}
|
| 2138 |
+
|
| 2139 |
+
except Exception as e:
|
| 2140 |
+
logger.error(f"Manual cleanup error: {e}")
|
| 2141 |
+
raise HTTPException(status_code = 500,
|
| 2142 |
+
detail = str(e),
|
| 2143 |
+
)
|
| 2144 |
+
|
| 2145 |
+
|
| 2146 |
+
@app.post("/api/cleanup/full")
|
| 2147 |
+
async def full_cleanup_endpoint():
|
| 2148 |
+
"""
|
| 2149 |
+
Full system cleanup endpoint
|
| 2150 |
+
"""
|
| 2151 |
+
state = app.state.app
|
| 2152 |
+
|
| 2153 |
+
try:
|
| 2154 |
+
# Also clean up frontend sessions
|
| 2155 |
+
state.active_sessions.clear()
|
| 2156 |
+
state.config_overrides.clear()
|
| 2157 |
+
|
| 2158 |
+
# Full cleanup
|
| 2159 |
+
success = await CleanupManager.full_cleanup(state)
|
| 2160 |
+
|
| 2161 |
+
return {"success" : success,
|
| 2162 |
+
"message" : "Full system cleanup completed",
|
| 2163 |
+
"details" : {"sessions_cleaned" : 0, # Already cleared above
|
| 2164 |
+
"memory_freed" : "All application state",
|
| 2165 |
+
"disk_space_freed" : "All vector store and uploaded files",
|
| 2166 |
+
}
|
| 2167 |
+
}
|
| 2168 |
+
|
| 2169 |
+
except Exception as e:
|
| 2170 |
+
logger.error(f"Full cleanup endpoint error: {e}")
|
| 2171 |
+
raise HTTPException(status_code = 500,
|
| 2172 |
+
detail = str(e),
|
| 2173 |
+
)
|
| 2174 |
+
|
| 2175 |
+
|
| 2176 |
+
@app.get("/api/cleanup/status")
|
| 2177 |
+
async def get_cleanup_status():
|
| 2178 |
+
"""
|
| 2179 |
+
Get cleanup status and statistics
|
| 2180 |
+
"""
|
| 2181 |
+
state = app.state.app
|
| 2182 |
+
|
| 2183 |
+
return {"sessions_active" : len(state.active_sessions),
|
| 2184 |
+
"documents_processed" : len(state.processed_documents),
|
| 2185 |
+
"total_chunks" : sum(len(chunks) for chunks in state.document_chunks.values()),
|
| 2186 |
+
"vector_store_ready" : state.is_ready,
|
| 2187 |
+
"cleanup_registry_size" : len(_cleanup_registry),
|
| 2188 |
+
"suggested_action" : "cleanup_vector_store" if state.is_ready else "upload_documents",
|
| 2189 |
+
}
|
| 2190 |
+
|
| 2191 |
+
|
| 2192 |
+
# ============================================================================
|
| 2193 |
+
# MAIN ENTRY POINT
|
| 2194 |
+
# ============================================================================
|
| 2195 |
+
if __name__ == "__main__":
|
| 2196 |
+
try:
|
| 2197 |
+
# Run the app
|
| 2198 |
+
uvicorn.run("app:app",
|
| 2199 |
+
host = settings.HOST,
|
| 2200 |
+
port = settings.PORT,
|
| 2201 |
+
reload = settings.DEBUG,
|
| 2202 |
+
log_level = "info",
|
| 2203 |
+
timeout_graceful_shutdown = 10.0,
|
| 2204 |
+
access_log = False,
|
| 2205 |
+
)
|
| 2206 |
+
|
| 2207 |
+
except KeyboardInterrupt:
|
| 2208 |
+
logger.info("Keyboard interrupt received - normal shutdown")
|
| 2209 |
+
|
| 2210 |
+
except Exception as e:
|
| 2211 |
+
logger.error(f"Application crashed: {e}", exc_info = True)
|
| 2212 |
+
|
| 2213 |
+
finally:
|
| 2214 |
+
# Simple final cleanup
|
| 2215 |
+
logger.info("Application stopping, final cleanup...")
|
| 2216 |
+
try:
|
| 2217 |
+
# Shutdown executor if it exists
|
| 2218 |
+
if '_cleanup_executor' in globals():
|
| 2219 |
+
_cleanup_executor.shutdown(wait = True)
|
| 2220 |
+
|
| 2221 |
+
except Exception:
|
| 2222 |
+
pass
|
| 2223 |
+
|
| 2224 |
+
logger.info("Application stopped")
|
chunking/__init__.py
ADDED
|
File without changes
|
chunking/adaptive_selector.py
ADDED
|
@@ -0,0 +1,465 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import re
|
| 3 |
+
from typing import List
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from config.models import DocumentChunk
|
| 6 |
+
from config.settings import get_settings
|
| 7 |
+
from config.models import DocumentMetadata
|
| 8 |
+
from config.models import ChunkingStrategy
|
| 9 |
+
from config.logging_config import get_logger
|
| 10 |
+
from chunking.base_chunker import BaseChunker
|
| 11 |
+
from chunking.base_chunker import ChunkerConfig
|
| 12 |
+
from chunking.token_counter import TokenCounter
|
| 13 |
+
from chunking.fixed_chunker import FixedChunker
|
| 14 |
+
from chunking.semantic_chunker import SemanticChunker
|
| 15 |
+
from chunking.llamaindex_chunker import LlamaIndexChunker
|
| 16 |
+
from chunking.hierarchical_chunker import HierarchicalChunker
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Setup Settings and Logging
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
settings = get_settings()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AdaptiveChunkingSelector:
|
| 26 |
+
"""
|
| 27 |
+
Intelligent chunking strategy selector with structure detection:
|
| 28 |
+
- Analyzes document characteristics (size, structure, content type)
|
| 29 |
+
- Detects structured documents (projects, sections, hierarchies)
|
| 30 |
+
- Automatically selects optimal chunking strategy
|
| 31 |
+
- Prioritizes section-aware chunking for structured content
|
| 32 |
+
|
| 33 |
+
Strategy Selection Logic (UPDATED):
|
| 34 |
+
- Small docs (< 1K tokens) → Fixed chunking
|
| 35 |
+
- Medium structured docs → Semantic (section-aware)
|
| 36 |
+
- Medium unstructured docs → LlamaIndex or basic semantic
|
| 37 |
+
- Large docs (>500K tokens) → Hierarchical chunking
|
| 38 |
+
"""
|
| 39 |
+
def __init__(self, prefer_llamaindex: bool = True):
|
| 40 |
+
"""
|
| 41 |
+
Initialize adaptive selector with all chunking strategies
|
| 42 |
+
|
| 43 |
+
Arguments:
|
| 44 |
+
----------
|
| 45 |
+
prefer_llamaindex { bool } : Prefer LlamaIndex over custom semantic chunking when available
|
| 46 |
+
"""
|
| 47 |
+
self.logger = logger
|
| 48 |
+
self.token_counter = TokenCounter()
|
| 49 |
+
self.prefer_llamaindex = prefer_llamaindex
|
| 50 |
+
|
| 51 |
+
# Initialize all chunking strategies
|
| 52 |
+
self.fixed_chunker = FixedChunker()
|
| 53 |
+
self.semantic_chunker = SemanticChunker(respect_section_boundaries = True)
|
| 54 |
+
self.hierarchical_chunker = HierarchicalChunker()
|
| 55 |
+
self.llamaindex_chunker = LlamaIndexChunker()
|
| 56 |
+
|
| 57 |
+
# Strategy thresholds (from settings)
|
| 58 |
+
self.small_doc_threshold = settings.SMALL_DOC_THRESHOLD
|
| 59 |
+
self.large_doc_threshold = settings.LARGE_DOC_THRESHOLD
|
| 60 |
+
|
| 61 |
+
# Check LlamaIndex availability
|
| 62 |
+
self.llamaindex_available = self.llamaindex_chunker._initialized
|
| 63 |
+
|
| 64 |
+
self.logger.info(f"Initialized AdaptiveChunkingSelector: LlamaIndex available={self.llamaindex_available}, prefer_llamaindex={self.prefer_llamaindex}, section_aware_semantic=True")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def select_chunking_strategy(self, text: str, metadata: Optional[DocumentMetadata] = None) -> tuple[ChunkingStrategy, dict]:
|
| 68 |
+
"""
|
| 69 |
+
Analyze document and select optimal chunking strategy: Detects structured documents and prioritizes section-aware chunking
|
| 70 |
+
|
| 71 |
+
Arguments:
|
| 72 |
+
----------
|
| 73 |
+
text { str } : Document text
|
| 74 |
+
|
| 75 |
+
metadata { DocumentMetadata } : Document metadata
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
--------
|
| 79 |
+
{ tuple } : Tuple of (selected_strategy, analysis_results)
|
| 80 |
+
"""
|
| 81 |
+
analysis = self._analyze_document(text = text,
|
| 82 |
+
metadata = metadata,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Check if document has clear structure (projects, sections)
|
| 86 |
+
has_structure = analysis.get("has_structure", False)
|
| 87 |
+
structure_score = analysis.get("structure_score", 0)
|
| 88 |
+
|
| 89 |
+
# Strategy selection logic
|
| 90 |
+
if (analysis["total_tokens"] <= self.small_doc_threshold):
|
| 91 |
+
strategy = ChunkingStrategy.FIXED
|
| 92 |
+
reason = f"Small document ({analysis['total_tokens']} tokens) - fixed chunking for simplicity"
|
| 93 |
+
|
| 94 |
+
elif (analysis["total_tokens"] <= self.large_doc_threshold):
|
| 95 |
+
# Medium documents: check for structure
|
| 96 |
+
if (has_structure and (structure_score > 0.3)):
|
| 97 |
+
# Structured document detected - use section-aware semantic chunking
|
| 98 |
+
strategy = ChunkingStrategy.SEMANTIC
|
| 99 |
+
reason = (f"Medium structured document ({analysis['total_tokens']} tokens, structure_score={structure_score:.2f}) - section-aware semantic chunking")
|
| 100 |
+
|
| 101 |
+
elif self.llamaindex_available and self.prefer_llamaindex:
|
| 102 |
+
strategy = ChunkingStrategy.SEMANTIC
|
| 103 |
+
reason = f"Medium document ({analysis['total_tokens']} tokens) - LlamaIndex semantic chunking"
|
| 104 |
+
|
| 105 |
+
else:
|
| 106 |
+
strategy = ChunkingStrategy.SEMANTIC
|
| 107 |
+
reason = f"Medium document ({analysis['total_tokens']} tokens) - semantic chunking"
|
| 108 |
+
|
| 109 |
+
else:
|
| 110 |
+
strategy = ChunkingStrategy.HIERARCHICAL
|
| 111 |
+
reason = f"Large document ({analysis['total_tokens']} tokens) - hierarchical chunking"
|
| 112 |
+
|
| 113 |
+
# Override based on document structure if available
|
| 114 |
+
if (metadata and self._has_clear_structure(metadata)):
|
| 115 |
+
if (strategy == ChunkingStrategy.FIXED):
|
| 116 |
+
# Upgrade to semantic for structured documents
|
| 117 |
+
strategy = ChunkingStrategy.SEMANTIC
|
| 118 |
+
reason = "Document has clear structure - section-aware semantic chunking preferred"
|
| 119 |
+
|
| 120 |
+
analysis["selected_strategy"] = strategy
|
| 121 |
+
analysis["selection_reason"] = reason
|
| 122 |
+
analysis["llamaindex_used"] = ((strategy == ChunkingStrategy.SEMANTIC) and self.llamaindex_available and self.prefer_llamaindex and not has_structure)
|
| 123 |
+
|
| 124 |
+
self.logger.info(f"Selected {strategy.value}: {reason}")
|
| 125 |
+
|
| 126 |
+
return strategy, analysis
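# Illustrative walk-through of the selection logic above (token figures follow the
# thresholds described in the class docstring; actual values come from settings, and
# document_text is a placeholder):
#
#   selector = AdaptiveChunkingSelector()
#   # ~600-token memo               -> ChunkingStrategy.FIXED        (small doc)
#   # ~40K-token doc with sections  -> ChunkingStrategy.SEMANTIC     (structure_score > 0.3)
#   # ~700K-token manual            -> ChunkingStrategy.HIERARCHICAL (large doc)
#   strategy, analysis = selector.select_chunking_strategy(text=document_text)
#   print(strategy.value, "-", analysis["selection_reason"])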
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def chunk_text(self, text: str, metadata: Optional[DocumentMetadata] = None, force_strategy: Optional[ChunkingStrategy] = None) -> List[DocumentChunk]:
|
| 130 |
+
"""
|
| 131 |
+
Automatically select strategy and chunk text
|
| 132 |
+
|
| 133 |
+
Arguments:
|
| 134 |
+
----------
|
| 135 |
+
text { str } : Document text
|
| 136 |
+
|
| 137 |
+
metadata { DocumentMetadata } : Document metadata
|
| 138 |
+
|
| 139 |
+
force_strategy { ChunkingStrategy } : Force specific strategy (optional)
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
--------
|
| 143 |
+
{ list } : List of DocumentChunk objects
|
| 144 |
+
"""
|
| 145 |
+
if not text or not text.strip():
|
| 146 |
+
return []
|
| 147 |
+
|
| 148 |
+
# Select strategy (or use forced strategy)
|
| 149 |
+
if force_strategy:
|
| 150 |
+
strategy = force_strategy
|
| 151 |
+
analysis = self._analyze_document(text = text,
|
| 152 |
+
metadata = metadata,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
reason = f"Forced strategy: {force_strategy.value}"
|
| 156 |
+
llamaindex_used = False
|
| 157 |
+
else:
|
| 158 |
+
strategy, analysis = self.select_chunking_strategy(text = text,
|
| 159 |
+
metadata = metadata,
|
| 160 |
+
)
|
| 161 |
+
reason = analysis["selection_reason"]
|
| 162 |
+
llamaindex_used = analysis["llamaindex_used"]
|
| 163 |
+
|
| 164 |
+
# Get appropriate chunker
|
| 165 |
+
if ((strategy == ChunkingStrategy.SEMANTIC) and llamaindex_used):
|
| 166 |
+
chunker = self.llamaindex_chunker
|
| 167 |
+
chunker_name = "LlamaIndex Semantic"
|
| 168 |
+
|
| 169 |
+
else:
|
| 170 |
+
chunker = self._get_chunker_for_strategy(strategy = strategy)
|
| 171 |
+
chunker_name = strategy.value
|
| 172 |
+
|
| 173 |
+
# Update metadata with strategy information
|
| 174 |
+
if metadata:
|
| 175 |
+
metadata.chunking_strategy = strategy
|
| 176 |
+
metadata.extra["chunking_analysis"] = {"strategy" : strategy.value,
|
| 177 |
+
"chunker_used" : chunker_name,
|
| 178 |
+
"reason" : reason,
|
| 179 |
+
"total_tokens" : analysis["total_tokens"],
|
| 180 |
+
"estimated_chunks" : analysis[f"estimated_{strategy.value.lower()}_chunks"],
|
| 181 |
+
"llamaindex_used" : llamaindex_used,
|
| 182 |
+
"has_structure" : analysis.get("has_structure", False),
|
| 183 |
+
"structure_score" : analysis.get("structure_score", 0),
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
self.logger.info(f"Using {chunker_name} chunker for document")
|
| 187 |
+
|
| 188 |
+
# Perform chunking
|
| 189 |
+
try:
|
| 190 |
+
chunks = chunker.chunk_text(text = text,
|
| 191 |
+
metadata = metadata,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Add strategy metadata to chunks
|
| 195 |
+
for chunk in chunks:
|
| 196 |
+
chunk.metadata["chunking_strategy"] = strategy.value
|
| 197 |
+
chunk.metadata["chunker_used"] = chunker_name
|
| 198 |
+
|
| 199 |
+
if llamaindex_used:
|
| 200 |
+
chunk.metadata["llamaindex_splitter"] = self.llamaindex_chunker.splitter_type
|
| 201 |
+
|
| 202 |
+
self.logger.info(f"Successfully created {len(chunks)} chunks using {chunker_name}")
|
| 203 |
+
|
| 204 |
+
# Log section coverage statistics
|
| 205 |
+
chunks_with_sections = sum(1 for c in chunks if c.section_title)
|
| 206 |
+
if (chunks_with_sections > 0):
|
| 207 |
+
self.logger.info(f"Section coverage: {chunks_with_sections}/{len(chunks)} chunks ({chunks_with_sections/len(chunks)*100:.1f}%) have section titles")
|
| 208 |
+
|
| 209 |
+
return chunks
|
| 210 |
+
|
| 211 |
+
except Exception as e:
|
| 212 |
+
self.logger.error(f"{chunker_name} chunking failed: {repr(e)}, falling back to fixed chunking")
|
| 213 |
+
|
| 214 |
+
# Fallback to fixed chunking
|
| 215 |
+
return self.fixed_chunker.chunk_text(text = text,
|
| 216 |
+
metadata = metadata,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _analyze_document(self, text: str, metadata: Optional[DocumentMetadata] = None) -> dict:
|
| 221 |
+
"""
|
| 222 |
+
Analyze document characteristics for strategy selection: Includes structure detection
|
| 223 |
+
|
| 224 |
+
Arguments:
|
| 225 |
+
----------
|
| 226 |
+
text { str } : Document text
|
| 227 |
+
|
| 228 |
+
metadata { DocumentMetadata } : Document metadata
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
--------
|
| 232 |
+
{ dict } : Analysis results
|
| 233 |
+
"""
|
| 234 |
+
# Basic token analysis
|
| 235 |
+
total_tokens = self.token_counter.count_tokens(text = text)
|
| 236 |
+
total_chars = len(text)
|
| 237 |
+
total_words = len(text.split())
|
| 238 |
+
|
| 239 |
+
# Estimate chunks for each strategy
|
| 240 |
+
estimated_fixed_chunks = max(1, total_tokens // settings.FIXED_CHUNK_SIZE)
|
| 241 |
+
estimated_semantic_chunks = max(1, total_tokens // (settings.FIXED_CHUNK_SIZE * 2))
|
| 242 |
+
estimated_hierarchical_chunks = max(1, total_tokens // settings.CHILD_CHUNK_SIZE)
|
| 243 |
+
estimated_llamaindex_chunks = max(1, int(total_tokens // (settings.FIXED_CHUNK_SIZE * 1.5)))
|
| 244 |
+
|
| 245 |
+
# Structure analysis (simple heuristics)
|
| 246 |
+
sentence_count = len(self.token_counter._split_into_sentences(text = text))
|
| 247 |
+
avg_sentence_length = total_words / sentence_count if (sentence_count > 0) else 0
|
| 248 |
+
|
| 249 |
+
# Paragraph detection (rough)
|
| 250 |
+
paragraphs = [p for p in text.split('\n\n') if p.strip()]
|
| 251 |
+
paragraph_count = len(paragraphs)
|
| 252 |
+
|
| 253 |
+
# NEW: Detect document structure
|
| 254 |
+
has_structure, structure_score = self._detect_document_structure(text)
|
| 255 |
+
|
| 256 |
+
analysis = {"total_tokens" : total_tokens,
|
| 257 |
+
"total_chars" : total_chars,
|
| 258 |
+
"total_words" : total_words,
|
| 259 |
+
"sentence_count" : sentence_count,
|
| 260 |
+
"paragraph_count" : paragraph_count,
|
| 261 |
+
"avg_sentence_length" : avg_sentence_length,
|
| 262 |
+
"estimated_fixed_chunks" : estimated_fixed_chunks,
|
| 263 |
+
"estimated_semantic_chunks" : estimated_semantic_chunks,
|
| 264 |
+
"estimated_llamaindex_chunks" : estimated_llamaindex_chunks,
|
| 265 |
+
"estimated_hierarchical_chunks" : estimated_hierarchical_chunks,
|
| 266 |
+
"document_size_category" : self._get_size_category(total_tokens),
|
| 267 |
+
"llamaindex_available" : self.llamaindex_available,
|
| 268 |
+
"has_structure" : has_structure,
|
| 269 |
+
"structure_score" : structure_score,
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
# Add metadata-based insights if available
|
| 273 |
+
if metadata:
|
| 274 |
+
analysis.update({"document_type" : metadata.document_type.value,
|
| 275 |
+
"file_size_mb" : metadata.file_size_mb,
|
| 276 |
+
"num_pages" : metadata.num_pages,
|
| 277 |
+
"has_clear_structure" : self._has_clear_structure(metadata),
|
| 278 |
+
})
|
| 279 |
+
|
| 280 |
+
return analysis
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _detect_document_structure(self, text: str) -> tuple[bool, float]:
|
| 284 |
+
"""
|
| 285 |
+
Analyze text for structural patterns and detect whether the document has clear structural elements (projects, sections, etc.)
|
| 286 |
+
& returns: (has_structure, structure_score)
|
| 287 |
+
"""
|
| 288 |
+
structure_indicators = 0
|
| 289 |
+
max_indicators = 5
|
| 290 |
+
|
| 291 |
+
# Check for project-style headers: "a) Project Name", "b) Project Name"
|
| 292 |
+
project_headers = len(re.findall(r'^[a-z]\)\s+[A-Z]', text, re.MULTILINE))
|
| 293 |
+
|
| 294 |
+
if (project_headers > 2):
|
| 295 |
+
structure_indicators += 1
|
| 296 |
+
|
| 297 |
+
# Check for bullet point lists: "●" or "❖"
|
| 298 |
+
bullet_points = text.count('●') + text.count('❖')
|
| 299 |
+
|
| 300 |
+
if (bullet_points > 5):
|
| 301 |
+
structure_indicators += 1
|
| 302 |
+
|
| 303 |
+
# Check for numbered sections: "1.", "2.", etc.
|
| 304 |
+
numbered_sections = len(re.findall(r'^\d+\.\s+[A-Z]', text, re.MULTILINE))
|
| 305 |
+
|
| 306 |
+
if (numbered_sections > 2):
|
| 307 |
+
structure_indicators += 1
|
| 308 |
+
|
| 309 |
+
# Check for subsection markers ending with ":"
|
| 310 |
+
subsection_markers = len(re.findall(r'^●\s+\w+.*:', text, re.MULTILINE))
|
| 311 |
+
|
| 312 |
+
if (subsection_markers > 3):
|
| 313 |
+
structure_indicators += 1
|
| 314 |
+
|
| 315 |
+
# Check for consistent indentation patterns
|
| 316 |
+
lines = text.split('\n')
|
| 317 |
+
indented_lines = sum(1 for line in lines if line.startswith(' ') or line.startswith('\t'))
|
| 318 |
+
|
| 319 |
+
# >20% indented
|
| 320 |
+
if (indented_lines > len(lines) * 0.2):
|
| 321 |
+
structure_indicators += 1
|
| 322 |
+
|
| 323 |
+
has_structure = (structure_indicators >= 2)
|
| 324 |
+
structure_score = structure_indicators / max_indicators
|
| 325 |
+
|
| 326 |
+
if has_structure:
|
| 327 |
+
self.logger.info(f"Document structure detected: score={structure_score:.2f} (project_headers={project_headers}, bullets={bullet_points}, "
|
| 328 |
+
f"numbered_sections={numbered_sections}, subsections={subsection_markers})")
|
| 329 |
+
|
| 330 |
+
return has_structure, structure_score
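# Illustrative input that trips these heuristics: three of the five indicators fire
# (project headers > 2, bullet points > 5, subsection markers > 3), so the method
# returns has_structure=True with structure_score=0.6:
#
#   sample = (
#       "a) Project Alpha\n"
#       "● Scope: data ingestion pipeline\n"
#       "● Stack: Python, FastAPI\n"
#       "b) Project Beta\n"
#       "● Scope: retrieval evaluation\n"
#       "● Stack: RAGAS, BM25\n"
#       "c) Project Gamma\n"
#       "● Scope: deployment\n"
#       "● Stack: Docker\n"
#   )
#   has_structure, score = selector._detect_document_structure(sample)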
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _get_chunker_for_strategy(self, strategy: ChunkingStrategy) -> BaseChunker:
|
| 334 |
+
"""
|
| 335 |
+
Get chunker instance for specified strategy
|
| 336 |
+
|
| 337 |
+
Arguments:
|
| 338 |
+
----------
|
| 339 |
+
strategy { ChunkingStrategy } : Chunking strategy
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
--------
|
| 343 |
+
{ BaseChunker } : Chunker instance
|
| 344 |
+
"""
|
| 345 |
+
chunkers = {ChunkingStrategy.FIXED : self.fixed_chunker,
|
| 346 |
+
ChunkingStrategy.SEMANTIC : self.semantic_chunker,
|
| 347 |
+
ChunkingStrategy.HIERARCHICAL : self.hierarchical_chunker,
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
return chunkers.get(strategy, self.fixed_chunker)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def _get_size_category(self, total_tokens: int) -> str:
|
| 354 |
+
"""
|
| 355 |
+
Categorize document by size
|
| 356 |
+
"""
|
| 357 |
+
if (total_tokens <= self.small_doc_threshold):
|
| 358 |
+
return "small"
|
| 359 |
+
|
| 360 |
+
elif (total_tokens <= self.large_doc_threshold):
|
| 361 |
+
return "medium"
|
| 362 |
+
|
| 363 |
+
else:
|
| 364 |
+
return "large"
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _has_clear_structure(self, metadata: DocumentMetadata) -> bool:
|
| 368 |
+
"""
|
| 369 |
+
Check if document has clear structural elements
|
| 370 |
+
"""
|
| 371 |
+
if metadata.extra:
|
| 372 |
+
# DOCX with multiple sections/headings
|
| 373 |
+
if (metadata.document_type.value == "docx"):
|
| 374 |
+
if (metadata.extra.get("num_sections", 0) > 1):
|
| 375 |
+
return True
|
| 376 |
+
|
| 377 |
+
if (metadata.extra.get("num_paragraphs", 0) > 50):
|
| 378 |
+
return True
|
| 379 |
+
|
| 380 |
+
# PDF with multiple pages and likely structure
|
| 381 |
+
if (metadata.document_type.value == "pdf"):
|
| 382 |
+
if metadata.num_pages and metadata.num_pages > 10:
|
| 383 |
+
return True
|
| 384 |
+
|
| 385 |
+
return False
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def get_strategy_recommendations(self, text: str, metadata: Optional[DocumentMetadata] = None) -> dict:
|
| 389 |
+
"""
|
| 390 |
+
Get detailed strategy recommendations with pros/cons
|
| 391 |
+
"""
|
| 392 |
+
analysis = self._analyze_document(text, metadata)
|
| 393 |
+
|
| 394 |
+
# LlamaIndex recommendation
|
| 395 |
+
llamaindex_recommendation = {"recommended_for" : ["Medium documents", "Structured content", "Superior semantic analysis"],
|
| 396 |
+
"pros" : ["Best semantic boundary detection", "LlamaIndex ecosystem integration", "Advanced embedding-based splitting"],
|
| 397 |
+
"cons" : ["Additional dependency", "Slower initialization", "More complex setup"],
|
| 398 |
+
"estimated_chunks" : analysis["estimated_llamaindex_chunks"],
|
| 399 |
+
"available" : self.llamaindex_available,
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
recommendations = {"fixed" : {"recommended_for" : ["Small documents", "Homogeneous content", "Simple processing"],
|
| 403 |
+
"pros" : ["Fast", "Reliable", "Predictable chunk sizes"],
|
| 404 |
+
"cons" : ["May break semantic boundaries", "Ignores document structure"],
|
| 405 |
+
"estimated_chunks" : analysis["estimated_fixed_chunks"],
|
| 406 |
+
},
|
| 407 |
+
"semantic" : {"recommended_for" : ["Medium documents", "Structured content", "When coherence matters"],
|
| 408 |
+
"pros" : ["Preserves topic boundaries", "Respects section structure", "Better context coherence"],
|
| 409 |
+
"cons" : ["Slower (requires embeddings)", "Less predictable chunk sizes"],
|
| 410 |
+
"estimated_chunks" : analysis["estimated_semantic_chunks"],
|
| 411 |
+
"section_aware" : True,
|
| 412 |
+
},
|
| 413 |
+
"llamaindex" : llamaindex_recommendation,
|
| 414 |
+
"hierarchical" : {"recommended_for" : ["Large documents", "Complex structure", "Granular search needs"],
|
| 415 |
+
"pros" : ["Best for large docs", "Granular + context search", "Scalable"],
|
| 416 |
+
"cons" : ["Complex implementation", "More chunks to manage", "Higher storage"],
|
| 417 |
+
"estimated_chunks" : analysis["estimated_hierarchical_chunks"],
|
| 418 |
+
}
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
# Add selected strategy
|
| 422 |
+
selected_strategy, analysis_result = self.select_chunking_strategy(text = text,
|
| 423 |
+
metadata = metadata,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
recommendations["selected_strategy"] = selected_strategy.value
|
| 427 |
+
recommendations["selection_reason"] = analysis_result["selection_reason"]
|
| 428 |
+
recommendations["llamaindex_used"] = analysis_result["llamaindex_used"]
|
| 429 |
+
recommendations["structure_detected"] = analysis_result.get("has_structure", False)
|
| 430 |
+
|
| 431 |
+
return recommendations
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# Global adaptive selector instance
|
| 435 |
+
_adaptive_selector = None
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def get_adaptive_selector() -> AdaptiveChunkingSelector:
|
| 439 |
+
"""
|
| 440 |
+
Get global adaptive selector instance (singleton)
|
| 441 |
+
"""
|
| 442 |
+
global _adaptive_selector
|
| 443 |
+
|
| 444 |
+
if _adaptive_selector is None:
|
| 445 |
+
_adaptive_selector = AdaptiveChunkingSelector()
|
| 446 |
+
|
| 447 |
+
return _adaptive_selector
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def adaptive_chunk_text(text: str, metadata: Optional[DocumentMetadata] = None, force_strategy: Optional[ChunkingStrategy] = None) -> List[DocumentChunk]:
|
| 451 |
+
"""
|
| 452 |
+
Convenience function for adaptive chunking
|
| 453 |
+
"""
|
| 454 |
+
selector = get_adaptive_selector()
|
| 455 |
+
|
| 456 |
+
return selector.chunk_text(text, metadata, force_strategy)
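# Typical usage of the convenience function (illustrative; document_text is a placeholder):
#
#   from chunking.adaptive_selector import adaptive_chunk_text
#   from config.models import ChunkingStrategy
#
#   chunks = adaptive_chunk_text(document_text)                                        # auto-selected strategy
#   fixed = adaptive_chunk_text(document_text, force_strategy=ChunkingStrategy.FIXED)  # explicit override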
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def analyze_document(text: str, metadata: Optional[DocumentMetadata] = None) -> dict:
|
| 460 |
+
"""
|
| 461 |
+
Analyze document without chunking
|
| 462 |
+
"""
|
| 463 |
+
selector = get_adaptive_selector()
|
| 464 |
+
|
| 465 |
+
return selector._analyze_document(text, metadata)
|
chunking/base_chunker.py
ADDED
|
@@ -0,0 +1,405 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import re
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from typing import List
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from abc import abstractmethod
|
| 8 |
+
from config.models import DocumentChunk
|
| 9 |
+
from config.models import DocumentMetadata
|
| 10 |
+
from config.models import ChunkingStrategy
|
| 11 |
+
from config.logging_config import get_logger
|
| 12 |
+
from chunking.token_counter import count_tokens
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Setup Logging
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BaseChunker(ABC):
|
| 20 |
+
"""
|
| 21 |
+
Abstract base class for all chunking strategies: Implements Template Method pattern for consistent chunking pipeline
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, strategy_name: ChunkingStrategy):
|
| 24 |
+
"""
|
| 25 |
+
Initialize base chunker
|
| 26 |
+
|
| 27 |
+
Arguments:
|
| 28 |
+
----------
|
| 29 |
+
strategy_name { ChunkingStrategy } : Chunking strategy enum
|
| 30 |
+
"""
|
| 31 |
+
self.strategy_name = strategy_name
|
| 32 |
+
self.logger = logger
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@abstractmethod
|
| 36 |
+
def chunk_text(self, text: str, metadata: Optional[DocumentMetadata] = None) -> List[DocumentChunk]:
|
| 37 |
+
"""
|
| 38 |
+
Chunk text into smaller pieces - must be implemented by subclasses
|
| 39 |
+
|
| 40 |
+
Arguments:
|
| 41 |
+
----------
|
| 42 |
+
text { str } : Input text to chunk
|
| 43 |
+
|
| 44 |
+
metadata { DocumentMetadata } : Document metadata
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
--------
|
| 48 |
+
{ list } : List of DocumentChunk objects
|
| 49 |
+
"""
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def chunk_document(self, text: str, metadata: DocumentMetadata) -> List[DocumentChunk]:
|
| 54 |
+
"""
|
| 55 |
+
Chunk document with full metadata: Template method that calls chunk_text and adds metadata
|
| 56 |
+
|
| 57 |
+
Arguments:
|
| 58 |
+
----------
|
| 59 |
+
text { str } : Document text
|
| 60 |
+
|
| 61 |
+
metadata { DocumentMetadata } : Document metadata
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
--------
|
| 65 |
+
{ list } : List of DocumentChunk objects with metadata
|
| 66 |
+
"""
|
| 67 |
+
try:
|
| 68 |
+
self.logger.info(f"Chunking document {metadata.document_id} using {self.strategy_name.value}")
|
| 69 |
+
|
| 70 |
+
# Validate input
|
| 71 |
+
if not text or not text.strip():
|
| 72 |
+
self.logger.warning(f"Empty text for document {metadata.document_id}")
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
# Perform chunking
|
| 76 |
+
chunks = self.chunk_text(text = text,
|
| 77 |
+
metadata = metadata,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Update metadata
|
| 81 |
+
metadata.num_chunks = len(chunks)
|
| 82 |
+
metadata.chunking_strategy = self.strategy_name
|
| 83 |
+
|
| 84 |
+
# Validate chunks
|
| 85 |
+
if not self.validate_chunks(chunks):
|
| 86 |
+
self.logger.warning(f"Chunk validation failed for {metadata.document_id}")
|
| 87 |
+
|
| 88 |
+
self.logger.info(f"Created {len(chunks)} chunks for {metadata.document_id}")
|
| 89 |
+
|
| 90 |
+
return chunks
|
| 91 |
+
|
| 92 |
+
except Exception as e:
|
| 93 |
+
self.logger.error(f"Chunking failed for {metadata.document_id}: {repr(e)}")
|
| 94 |
+
raise
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _create_chunk(self, text: str, chunk_index: int, document_id: str, start_char: int, end_char: int, page_number: Optional[int] = None,
|
| 98 |
+
section_title: Optional[str] = None, metadata: Optional[dict] = None) -> DocumentChunk:
|
| 99 |
+
"""
|
| 100 |
+
Create a DocumentChunk object with proper formatting
|
| 101 |
+
|
| 102 |
+
Arguments:
|
| 103 |
+
----------
|
| 104 |
+
text { str } : Chunk text
|
| 105 |
+
|
| 106 |
+
chunk_index { int } : Index of chunk in document
|
| 107 |
+
|
| 108 |
+
document_id { str } : Parent document ID
|
| 109 |
+
|
| 110 |
+
start_char { int } : Start character position
|
| 111 |
+
|
| 112 |
+
end_char { int } : End character position
|
| 113 |
+
|
| 114 |
+
page_number { int } : Page number (if applicable)
|
| 115 |
+
|
| 116 |
+
section_title { str } : Section heading (CRITICAL for retrieval)
|
| 117 |
+
|
| 118 |
+
metadata { dict } : Additional metadata
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
--------
|
| 122 |
+
{ DocumentChunk } : DocumentChunk object
|
| 123 |
+
"""
|
| 124 |
+
# Generate unique chunk ID
|
| 125 |
+
chunk_id = f"chunk_{document_id}_{chunk_index}"
|
| 126 |
+
|
| 127 |
+
# Count tokens
|
| 128 |
+
token_count = count_tokens(text)
|
| 129 |
+
|
| 130 |
+
# Create chunk with section context
|
        chunk = DocumentChunk(chunk_id      = chunk_id,
                              document_id   = document_id,
                              text          = text,
                              chunk_index   = chunk_index,
                              start_char    = start_char,
                              end_char      = end_char,
                              page_number   = page_number,
                              section_title = section_title,
                              token_count   = token_count,
                              metadata      = metadata or {},
                              )

        return chunk


    def _extract_page_number(self, text: str, full_text: str) -> Optional[int]:
        """
        Try to extract page number from text: Looks for [PAGE N] markers inserted during parsing
        """
        # Look for page markers in current chunk
        page_match = re.search(r'\[PAGE (\d+)\]', text)

        if page_match:
            return int(page_match.group(1))

        # Alternative: try to determine from position in full text
        if full_text:
            chunk_start = full_text.find(text[:min(200, len(text))])

            if (chunk_start >= 0):
                text_before  = full_text[:chunk_start]
                page_matches = re.findall(r'\[PAGE (\d+)\]', text_before)

                if page_matches:
                    return int(page_matches[-1])

        return None


    def _clean_chunk_text(self, text: str) -> str:
        """
        Clean chunk text by removing markers and extra whitespace

        Arguments:
        ----------
        text { str } : Raw chunk text

        Returns:
        --------
        { str } : Cleaned text
        """
        # Remove page markers
        text = re.sub(r'\[PAGE \d+\]', '', text)

        # Remove other common markers
        text = re.sub(r'\[HEADER\]|\[FOOTER\]|\[TABLE \d+\]', '', text)

        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        return text


    def validate_chunks(self, chunks: List[DocumentChunk]) -> bool:
        """
        Validate chunk list for consistency

        Arguments:
        ----------
        chunks { list } : List of chunks to validate

        Returns:
        --------
        { bool } : True if valid
        """
        if not chunks:
            return True

        # Check all chunks have the same document_id
        doc_ids = {chunk.document_id for chunk in chunks}

        if (len(doc_ids) > 1):
            self.logger.error(f"Chunks have multiple document IDs: {doc_ids}")
            return False

        # Check chunk indices are sequential
        indices          = [chunk.chunk_index for chunk in chunks]
        expected_indices = list(range(len(chunks)))

        if (indices != expected_indices):
            self.logger.warning(f"Non-sequential chunk indices: {indices}")

        # Check for empty chunks
        empty_chunks = [c.chunk_index for c in chunks if not c.text.strip()]

        if empty_chunks:
            self.logger.warning(f"Empty chunks at indices: {empty_chunks}")

        # Check token counts
        zero_token_chunks = [c.chunk_index for c in chunks if (c.token_count == 0)]

        if zero_token_chunks:
            self.logger.warning(f"Zero-token chunks at indices: {zero_token_chunks}")

        # NEW: Check section_title preservation (important for structured documents)
        chunks_with_sections = [c for c in chunks if c.section_title]

        if chunks_with_sections:
            self.logger.info(f"{len(chunks_with_sections)}/{len(chunks)} chunks have section titles preserved")

        return True


    def get_chunk_statistics(self, chunks: List[DocumentChunk]) -> dict:
        """
        Calculate statistics for chunk list

        Arguments:
        ----------
        chunks { list } : List of chunks

        Returns:
        --------
        { dict } : Dictionary with statistics
        """
        if not chunks:
            return {"num_chunks"           : 0,
                    "total_tokens"         : 0,
                    "avg_tokens_per_chunk" : 0,
                    "min_tokens"           : 0,
                    "max_tokens"           : 0,
                    "total_chars"          : 0,
                    "avg_chars_per_chunk"  : 0,
                    "chunks_with_sections" : 0,
                    }

        token_counts         = [c.token_count for c in chunks]
        char_counts          = [len(c.text) for c in chunks]
        chunks_with_sections = sum(1 for c in chunks if c.section_title)

        stats = {"num_chunks"           : len(chunks),
                 "total_tokens"         : sum(token_counts),
                 "avg_tokens_per_chunk" : sum(token_counts) / len(chunks),
                 "min_tokens"           : min(token_counts),
                 "max_tokens"           : max(token_counts),
                 "total_chars"          : sum(char_counts),
                 "avg_chars_per_chunk"  : sum(char_counts) / len(chunks),
                 "strategy"             : self.strategy_name.value,
                 "chunks_with_sections" : chunks_with_sections,
                 "section_coverage_pct" : (chunks_with_sections / len(chunks)) * 100,
                 }

        return stats


    def merge_chunks(self, chunks: List[DocumentChunk], max_tokens: int) -> List[DocumentChunk]:
        """
        Merge small chunks up to max_tokens: Useful for optimizing chunk sizes

        Arguments:
        ----------
        chunks { list } : List of chunks to merge

        max_tokens { int } : Maximum tokens per merged chunk

        Returns:
        --------
        { list } : List of merged chunks
        """
        if not chunks:
            return []

        merged         = list()
        current_chunks = list()
        current_tokens = 0
        document_id    = chunks[0].document_id

        for chunk in chunks:
            if ((current_tokens + chunk.token_count) <= max_tokens):
                current_chunks.append(chunk)
                current_tokens += chunk.token_count

            else:
                # Save current merged chunk
                if current_chunks:
                    merged_text  = " ".join(c.text for c in current_chunks)
                    merged_chunk = self._create_chunk(text          = merged_text,
                                                      chunk_index   = len(merged),
                                                      document_id   = document_id,
                                                      start_char    = current_chunks[0].start_char,
                                                      end_char      = current_chunks[-1].end_char,
                                                      page_number   = current_chunks[0].page_number,
                                                      section_title = current_chunks[0].section_title,
                                                      )
                    merged.append(merged_chunk)

                # Start new chunk
                current_chunks = [chunk]
                current_tokens = chunk.token_count

        # Add final merged chunk
        if current_chunks:
            merged_text  = " ".join(c.text for c in current_chunks)
            merged_chunk = self._create_chunk(text          = merged_text,
                                              chunk_index   = len(merged),
                                              document_id   = document_id,
                                              start_char    = current_chunks[0].start_char,
                                              end_char      = current_chunks[-1].end_char,
                                              page_number   = current_chunks[0].page_number,
                                              section_title = current_chunks[0].section_title,
                                              )
            merged.append(merged_chunk)

        self.logger.info(f"Merged {len(chunks)} chunks into {len(merged)}")

        return merged


    def __str__(self) -> str:
        """
        String representation
        """
        return f"{self.__class__.__name__}(strategy={self.strategy_name.value})"


    def __repr__(self) -> str:
        """
        Detailed representation
        """
        return self.__str__()


class ChunkerConfig:
    """
    Configuration for chunking strategies: Provides a way to pass parameters to chunkers
    """
    def __init__(self, chunk_size: int = 512, overlap: int = 50, respect_boundaries: bool = True, min_chunk_size: int = 100, **kwargs):
        """
        Initialize chunker configuration

        Arguments:
        ----------
        chunk_size { int } : Target chunk size in tokens

        overlap { int } : Overlap between chunks in tokens

        respect_boundaries { bool } : Respect sentence/paragraph/section boundaries

        min_chunk_size { int } : Minimum chunk size in tokens

        **kwargs : Additional strategy-specific parameters
        """
        self.chunk_size         = chunk_size
        self.overlap            = overlap
        self.respect_boundaries = respect_boundaries
        self.min_chunk_size     = min_chunk_size
        self.extra              = kwargs


    def to_dict(self) -> dict:
        """
        Convert to dictionary
        """
        return {"chunk_size"         : self.chunk_size,
                "overlap"            : self.overlap,
                "respect_boundaries" : self.respect_boundaries,
                "min_chunk_size"     : self.min_chunk_size,
                **self.extra
                }


    def __repr__(self) -> str:
        return f"ChunkerConfig({self.to_dict()})"
chunking/fixed_chunker.py
ADDED
@@ -0,0 +1,376 @@
# DEPENDENCIES
import re
from typing import List
from typing import Optional
from config.models import DocumentChunk
from config.settings import get_settings
from config.models import DocumentMetadata
from config.models import ChunkingStrategy
from config.logging_config import get_logger
from chunking.base_chunker import BaseChunker
from chunking.base_chunker import ChunkerConfig
from chunking.token_counter import TokenCounter
from chunking.overlap_manager import OverlapManager


# Setup Settings and Logging
logger   = get_logger(__name__)
settings = get_settings()


class FixedChunker(BaseChunker):
    """
    Fixed-size chunking strategy : Splits text into chunks of approximately equal token count with overlap

    Best for:
    - Small to medium documents (<50K tokens)
    - Homogeneous content
    - When simplicity is preferred
    """
    def __init__(self, chunk_size: int = None, overlap: int = None, respect_sentence_boundaries: bool = True, min_chunk_size: int = 100):
        """
        Initialize fixed chunker

        Arguments:
        ----------
        chunk_size { int } : Target tokens per chunk (default from settings)

        overlap { int } : Overlap tokens between chunks (default from settings)

        respect_sentence_boundaries { bool } : Try to break at sentence boundaries

        min_chunk_size { int } : Minimum chunk size in tokens
        """
        super().__init__(ChunkingStrategy.FIXED)

        self.chunk_size                  = chunk_size or settings.FIXED_CHUNK_SIZE
        self.overlap                     = overlap or settings.FIXED_CHUNK_OVERLAP
        self.respect_sentence_boundaries = respect_sentence_boundaries
        self.min_chunk_size              = min_chunk_size

        # Initialize token counter and overlap manager
        self.token_counter   = TokenCounter()
        self.overlap_manager = OverlapManager(overlap_tokens = self.overlap)

        # Validate parameters
        if (self.overlap >= self.chunk_size):
            raise ValueError(f"Overlap ({self.overlap}) must be less than chunk_size ({self.chunk_size})")

        self.logger.info(f"Initialized FixedChunker: chunk_size={self.chunk_size}, overlap={self.overlap}, respect_boundaries={self.respect_sentence_boundaries}")


    def chunk_text(self, text: str, metadata: Optional[DocumentMetadata] = None) -> List[DocumentChunk]:
        """
        Chunk text into fixed-size pieces

        Arguments:
        ----------
        text { str } : Input text

        metadata { DocumentMetaData } : Document metadata

        Returns:
        --------
        { list } : List of DocumentChunk objects
        """
        if not text or not text.strip():
            return []

        document_id = metadata.document_id if metadata else "unknown"

        # Split into sentences if respecting boundaries
        if self.respect_sentence_boundaries:
            chunks = self._chunk_with_sentence_boundaries(text        = text,
                                                          document_id = document_id,
                                                          )

        else:
            chunks = self._chunk_without_boundaries(text        = text,
                                                    document_id = document_id,
                                                    )

        # Clean and validate
        chunks = [c for c in chunks if (c.token_count >= self.min_chunk_size)]

        # Use OverlapManager to add proper overlap
        if ((len(chunks) > 1) and (self.overlap > 0)):
            chunks = self.overlap_manager.add_overlap(chunks         = chunks,
                                                      overlap_tokens = self.overlap,
                                                      )

        self.logger.debug(f"Created {len(chunks)} fixed-size chunks")

        return chunks


    def _chunk_with_sentence_boundaries(self, text: str, document_id: str) -> List[DocumentChunk]:
        """
        Chunk text respecting sentence boundaries

        Arguments:
        ----------
        text { str } : Input text

        document_id { str } : Document ID

        Returns:
        --------
        { list } : List of chunks without overlap (overlap added later)
        """
        # Split into sentences
        sentences = self._split_sentences(text = text)

        chunks            = list()
        current_sentences = list()
        current_tokens    = 0
        start_char        = 0

        for sentence in sentences:
            sentence_tokens = self.token_counter.count_tokens(text = sentence)

            # If single sentence exceeds chunk_size, split it
            if (sentence_tokens > self.chunk_size):
                # Save current chunk if any
                if current_sentences:
                    chunk_text = " ".join(current_sentences)
                    chunk      = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                                    chunk_index = len(chunks),
                                                    document_id = document_id,
                                                    start_char  = start_char,
                                                    end_char    = start_char + len(chunk_text),
                                                    )
                    chunks.append(chunk)

                    current_sentences = list()
                    current_tokens    = 0
                    start_char       += len(chunk_text)

                # Split long sentence and add as separate chunks
                long_sentence_chunks = self._split_long_sentence(sentence    = sentence,
                                                                 document_id = document_id,
                                                                 start_index = len(chunks),
                                                                 start_char  = start_char,
                                                                 )
                chunks.extend(long_sentence_chunks)
                start_char += len(sentence)

                continue

            # Check if adding this sentence exceeds chunk_size
            if (((current_tokens + sentence_tokens) > self.chunk_size) and current_sentences):
                # Save current chunk WITHOUT overlap (overlap added later)
                chunk_text = " ".join(current_sentences)
                chunk      = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                                chunk_index = len(chunks),
                                                document_id = document_id,
                                                start_char  = start_char,
                                                end_char    = start_char + len(chunk_text),
                                                )
                chunks.append(chunk)

                # OverlapManager will handle the overlap here
                current_sentences = [sentence]
                current_tokens    = sentence_tokens
                start_char       += len(chunk_text)

            else:
                # Add sentence to current chunk
                current_sentences.append(sentence)
                current_tokens += sentence_tokens

        # Add final chunk
        if current_sentences:
            chunk_text = " ".join(current_sentences)
            chunk      = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                            chunk_index = len(chunks),
                                            document_id = document_id,
                                            start_char  = start_char,
                                            end_char    = start_char + len(chunk_text),
                                            )
            chunks.append(chunk)

        return chunks


    def _chunk_without_boundaries(self, text: str, document_id: str) -> List[DocumentChunk]:
        """
        Chunk text without respecting boundaries (pure token-based)

        Arguments:
        ----------
        text { str } : Input text

        document_id { str } : Document ID

        Returns:
        --------
        { list } : List of chunks WITHOUT overlap
        """
        # Use token counter's split method
        chunk_texts = self.token_counter.split_into_token_chunks(text,
                                                                 chunk_size = self.chunk_size,
                                                                 overlap    = 0,
                                                                 )

        chunks      = list()
        current_pos = 0

        for i, chunk_text in enumerate(chunk_texts):
            chunk = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                       chunk_index = i,
                                       document_id = document_id,
                                       start_char  = current_pos,
                                       end_char    = current_pos + len(chunk_text),
                                       )

            chunks.append(chunk)
            current_pos += len(chunk_text)

        return chunks


    def _split_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { list } : List of sentences
        """
        # Handle common abbreviations: Protect them temporarily
        protected     = text
        abbreviations = ['Dr.', 'Mr.', 'Mrs.', 'Ms.', 'Jr.', 'Sr.', 'Prof.', 'Inc.', 'Ltd.', 'Corp.', 'Co.', 'vs.', 'etc.', 'e.g.', 'i.e.', 'Ph.D.', 'M.D.', 'B.A.', 'M.A.', 'U.S.', 'U.K.']

        for abbr in abbreviations:
            protected = protected.replace(abbr, abbr.replace('.', '<DOT>'))

        # Split on sentence boundaries
        # - Pattern: period/question/exclamation followed by space and capital letter
        sentence_pattern = r'(?<=[.!?])\s+(?=[A-Z])'
        sentences        = re.split(sentence_pattern, protected)

        # Restore abbreviations
        sentences = [s.replace('<DOT>', '.').strip() for s in sentences]

        # Filter empty
        sentences = [s for s in sentences if s]

        return sentences


    def _split_long_sentence(self, sentence: str, document_id: str, start_index: int, start_char: int) -> List[DocumentChunk]:
        """
        Split a sentence that's longer than chunk_size

        Arguments:
        ----------
        sentence { str } : Long sentence

        document_id { str } : Document ID

        start_index { str } : Starting chunk index

        start_char { int } : Starting character position

        Returns:
        --------
        { list } : List of chunks
        """
        # Split by commas, semicolons, or just by tokens
        parts = re.split(r'[,;]', sentence)

        chunks         = list()
        current_text   = list()
        current_tokens = 0

        for part in parts:
            part = part.strip()
            if not part:
                continue

            part_tokens = self.token_counter.count_tokens(part)

            if (((current_tokens + part_tokens) > self.chunk_size) and current_text):
                # Save current chunk
                chunk_text = " ".join(current_text)
                chunk      = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                                chunk_index = start_index + len(chunks),
                                                document_id = document_id,
                                                start_char  = start_char,
                                                end_char    = start_char + len(chunk_text),
                                                )
                chunks.append(chunk)
                start_char    += len(chunk_text)
                current_text   = []
                current_tokens = 0

            current_text.append(part)
            current_tokens += part_tokens

        # Add final part
        if current_text:
            chunk_text = " ".join(current_text)
            chunk      = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                            chunk_index = start_index + len(chunks),
                                            document_id = document_id,
                                            start_char  = start_char,
                                            end_char    = start_char + len(chunk_text),
                                            )
            chunks.append(chunk)

        return chunks


    def _get_overlap_sentences(self, sentences: List[str], overlap_tokens: int) -> List[str]:
        """
        Get last few sentences that fit in overlap window

        Arguments:
        ----------
        sentences { list } : List of sentences

        overlap_tokens { int } : Target overlap tokens

        Returns:
        --------
        { list } : List of overlap sentences
        """
        overlap = list()
        tokens  = 0

        # Add sentences from the end until we reach overlap size
        for sentence in reversed(sentences):
            sentence_tokens = self.token_counter.count_tokens(sentence)

            if ((tokens + sentence_tokens) <= overlap_tokens):
                overlap.insert(0, sentence)
                tokens += sentence_tokens

            else:
                break

        return overlap


    @classmethod
    def from_config(cls, config: ChunkerConfig) -> 'FixedChunker':
        """
        Create FixedChunker from configuration

        Arguments:
        ----------
        config { ChunkerConfig } : ChunkerConfig object

        Returns:
        --------
        FixedChunker instance
        """
        return cls(chunk_size                  = config.chunk_size,
                   overlap                     = config.overlap,
                   respect_sentence_boundaries = config.respect_boundaries,
                   min_chunk_size              = config.min_chunk_size,
                   )
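Note: a quick usage sketch for FixedChunker (assuming settings are loaded from the environment; passing metadata=None makes the chunks fall back to document_id "unknown"):

# Illustrative only
chunker = FixedChunker(chunk_size = 384, overlap = 20)
chunks  = chunker.chunk_text("First sentence. Second sentence. Third sentence.", metadata = None)

for c in chunks:
    print(c.chunk_index, c.token_count, c.text[:60])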
chunking/hierarchical_chunker.py
ADDED
@@ -0,0 +1,280 @@
# DEPENDENCIES
from typing import List
from typing import Optional
from config.models import DocumentChunk
from config.settings import get_settings
from config.models import DocumentMetadata
from config.models import ChunkingStrategy
from config.logging_config import get_logger
from chunking.base_chunker import BaseChunker
from chunking.base_chunker import ChunkerConfig
from chunking.token_counter import TokenCounter
from chunking.overlap_manager import OverlapManager
from chunking.fixed_chunker import FixedChunker


# Setup Settings and Logging
logger   = get_logger(__name__)
settings = get_settings()


class HierarchicalChunker(BaseChunker):
    """
    Hierarchical chunking strategy:
    - Creates parent chunks (large) and child chunks (small)
    - Child chunks for granular search, parent chunks for context
    - Maintains parent-child relationships for context expansion

    Best for:
    - Large documents (>500K tokens)
    - Complex documents with nested structure
    - When both granular search and context preservation are needed
    """
    def __init__(self, parent_chunk_size: int = None, child_chunk_size: int = None, overlap: int = None, min_chunk_size: int = 100):
        """
        Initialize hierarchical chunker

        Arguments:
        ----------
        parent_chunk_size { int } : Size of parent chunks in tokens

        child_chunk_size { int } : Size of child chunks in tokens

        overlap { int } : Overlap between child chunks

        min_chunk_size { int } : Minimum chunk size in tokens
        """
        super().__init__(ChunkingStrategy.HIERARCHICAL)

        self.parent_chunk_size = parent_chunk_size or settings.PARENT_CHUNK_SIZE
        self.child_chunk_size  = child_chunk_size or settings.CHILD_CHUNK_SIZE
        self.overlap           = overlap or settings.FIXED_CHUNK_OVERLAP
        self.min_chunk_size    = min_chunk_size

        # Validate parameters
        if (self.child_chunk_size >= self.parent_chunk_size):
            raise ValueError(f"Child chunk size ({self.child_chunk_size}) must be smaller than parent chunk size ({self.parent_chunk_size})")

        # Initialize dependencies
        self.token_counter   = TokenCounter()
        self.overlap_manager = OverlapManager(overlap_tokens = self.overlap)
        self.child_chunker   = FixedChunker(chunk_size                  = self.child_chunk_size,
                                            overlap                     = self.overlap,
                                            respect_sentence_boundaries = True,
                                            )

        self.logger.info(f"Initialized HierarchicalChunker: parent_size={self.parent_chunk_size}, child_size={self.child_chunk_size}, overlap={self.overlap}")


    def chunk_text(self, text: str, metadata: Optional[DocumentMetadata] = None) -> List[DocumentChunk]:
        """
        Create hierarchical chunks with parent-child relationships

        Arguments:
        ----------
        text { str } : Input text

        metadata { DocumentMetaData } : Document metadata

        Returns:
        --------
        { list } : List of DocumentChunk objects (children with parent references)
        """
        if not text or not text.strip():
            return []

        document_id = metadata.document_id if metadata else "unknown"

        # Create parent chunks (large context windows)
        parent_chunks = self._create_parent_chunks(text, document_id)

        # For each parent chunk, create child chunks (granular search)
        all_child_chunks = list()

        for parent_chunk in parent_chunks:
            child_chunks = self._create_child_chunks(parent_chunk = parent_chunk,
                                                     parent_text  = text,
                                                     document_id  = document_id,
                                                     )
            all_child_chunks.extend(child_chunks)

        # Step 3: Filter small chunks
        all_child_chunks = [c for c in all_child_chunks if (c.token_count >= self.min_chunk_size)]

        self.logger.info(f"Created {len(all_child_chunks)} child chunks from {len(parent_chunks)} parent chunks")

        return all_child_chunks


    def _create_parent_chunks(self, text: str, document_id: str) -> List[DocumentChunk]:
        """
        Create large parent chunks for context preservation

        Arguments:
        ----------
        text { str } : Input text

        document_id { str } : Document ID

        Returns:
        --------
        { list } : List of parent chunks WITHOUT overlap
        """
        # Use fixed chunking for parents (no overlap between parents)
        parent_chunker = FixedChunker(chunk_size                  = self.parent_chunk_size,
                                      overlap                     = 0,  # No overlap between parents
                                      respect_sentence_boundaries = True,
                                      )

        # Create parent chunks
        parent_chunks = parent_chunker._chunk_with_sentence_boundaries(text        = text,
                                                                       document_id = document_id,
                                                                       )

        # Add parent metadata
        for i, chunk in enumerate(parent_chunks):
            chunk.metadata["chunk_type"]      = "parent"
            chunk.metadata["parent_chunk_id"] = chunk.chunk_id

        return parent_chunks


    def _create_child_chunks(self, parent_chunk: DocumentChunk, parent_text: str, document_id: str) -> List[DocumentChunk]:
        """
        Create child chunks within a parent chunk

        Arguments:
        ----------
        parent_chunk { DocumentChunk } : Parent chunk object

        parent_text { str } : Full parent text (for position reference)

        document_id { str } : Document ID

        Returns:
        --------
        { list } : List of child chunks with parent references
        """
        # Extract the actual text segment from parent_text using parent chunk positions
        parent_segment = parent_text[parent_chunk.start_char:parent_chunk.end_char]

        # Create child chunks within this parent segment
        child_chunker = FixedChunker(chunk_size                  = self.child_chunk_size,
                                     overlap                     = self.overlap,
                                     respect_sentence_boundaries = True,
                                     )

        # Create child chunks with proper positioning
        child_chunks = child_chunker._chunk_with_sentence_boundaries(text        = parent_segment,
                                                                     document_id = document_id,
                                                                     )

        # Update child chunks with parent relationship and correct positions
        for i, child_chunk in enumerate(child_chunks):
            # Adjust positions to be relative to full document
            child_chunk.start_char += parent_chunk.start_char
            child_chunk.end_char   += parent_chunk.start_char

            # Add parent relationship metadata
            child_chunk.metadata["chunk_type"]      = "child"
            child_chunk.metadata["parent_chunk_id"] = parent_chunk.chunk_id
            child_chunk.metadata["parent_index"]    = i

            # Update chunk ID to reflect hierarchy
            child_chunk.chunk_id = f"{parent_chunk.chunk_id}_child_{i}"

        return child_chunks


    def expand_to_parent_context(self, child_chunk: DocumentChunk, all_chunks: List[DocumentChunk]) -> DocumentChunk:
        """
        Expand a child chunk to include full parent context for generation

        Arguments:
        ----------
        child_chunk { DocumentChunk } : Child chunk to expand

        all_chunks { list } : All chunks from the document

        Returns:
        --------
        { DocumentChunk } : Expanded chunk with parent context
        """
        # Find the parent chunk
        parent_chunk_id = child_chunk.metadata.get("parent_chunk_id")

        if not parent_chunk_id:
            return child_chunk

        parent_chunk = next((c for c in all_chunks if c.chunk_id == parent_chunk_id), None)

        if not parent_chunk:
            return child_chunk

        # Create expanded chunk with parent context
        expanded_text = f"[PARENT_CONTEXT]\n{parent_chunk.text}\n\n[CHILD_CONTEXT]\n{child_chunk.text}"

        expanded_chunk = DocumentChunk(chunk_id        = f"{child_chunk.chunk_id}_expanded",
                                       document_id     = child_chunk.document_id,
                                       text            = expanded_text,
                                       chunk_index     = child_chunk.chunk_index,
                                       start_char      = child_chunk.start_char,
                                       end_char        = child_chunk.end_char,
                                       page_number     = child_chunk.page_number,
                                       section_title   = child_chunk.section_title,
                                       token_count     = self.token_counter.count_tokens(expanded_text),
                                       parent_chunk_id = parent_chunk_id,
                                       child_chunk_ids = [child_chunk.chunk_id],
                                       metadata        = {**child_chunk.metadata, "expanded": True},
                                       )

        return expanded_chunk


    def get_parent_child_relationships(self, chunks: List[DocumentChunk]) -> dict:
        """
        Extract parent-child relationships from chunks

        Arguments:
        ----------
        chunks { list } : List of chunks

        Returns:
        --------
        { dict } : Dictionary mapping parent IDs to child chunks
        """
        relationships = dict()

        for chunk in chunks:
            if (chunk.metadata.get("chunk_type") == "parent"):
                relationships[chunk.chunk_id] = {"parent"   : chunk,
                                                 "children" : [],
                                                 }

        for chunk in chunks:
            parent_id = chunk.metadata.get("parent_chunk_id")

            if parent_id and parent_id in relationships:
                relationships[parent_id]["children"].append(chunk)

        return relationships


    @classmethod
    def from_config(cls, config: ChunkerConfig) -> 'HierarchicalChunker':
        """
        Create HierarchicalChunker from configuration

        Arguments:
        ----------
        config { ChunkerConfig } : ChunkerConfig object

        Returns:
        --------
        { HierarchicalChunker } : HierarchicalChunker instance
        """
        return cls(parent_chunk_size = config.extra.get('parent_size', settings.PARENT_CHUNK_SIZE),
                   child_chunk_size  = config.extra.get('child_size', settings.CHILD_CHUNK_SIZE),
                   overlap           = config.overlap,
                   min_chunk_size    = config.min_chunk_size,
                   )
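Note: HierarchicalChunker returns only the child chunks; each child carries a parent_chunk_id in its metadata, so callers can group children back under their parent or expand them later. A minimal sketch (sizes and the long_text variable are illustrative):

# Illustrative only
chunker  = HierarchicalChunker(parent_chunk_size = 2048, child_chunk_size = 256)
children = chunker.chunk_text(long_text, metadata = None)

by_parent = {}
for c in children:
    by_parent.setdefault(c.metadata["parent_chunk_id"], []).append(c)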
chunking/llamaindex_chunker.py
ADDED
@@ -0,0 +1,240 @@
# DEPENDENCIES
from typing import List
from typing import Optional
from config.models import DocumentChunk
from config.settings import get_settings
from config.models import DocumentMetadata
from config.models import ChunkingStrategy
from config.logging_config import get_logger
from chunking.base_chunker import BaseChunker
from chunking.base_chunker import ChunkerConfig
from chunking.token_counter import TokenCounter
from chunking.semantic_chunker import SemanticChunker
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.schema import Document as LlamaDocument
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SemanticSplitterNodeParser


# Setup Settings and Logging
logger   = get_logger(__name__)
settings = get_settings()


class LlamaIndexChunker(BaseChunker):
    """
    LlamaIndex-based semantic chunking strategy:
    - Uses LlamaIndex's advanced semantic splitting algorithms
    - Provides superior boundary detection using embeddings
    - Supports multiple LlamaIndex splitter types

    Best for:
    - Documents requiring sophisticated semantic analysis
    - When LlamaIndex ecosystem integration is needed
    - Advanced chunking with embedding-based boundaries
    """
    def __init__(self, chunk_size: int = None, overlap: int = None, splitter_type: str = "semantic", min_chunk_size: int = 100):
        """
        Initialize LlamaIndex chunker

        Arguments:
        ----------
        chunk_size { int } : Target tokens per chunk

        overlap { int } : Overlap tokens between chunks

        splitter_type { str } : Type of LlamaIndex splitter ("semantic", "sentence", "token")

        min_chunk_size { int } : Minimum chunk size in tokens
        """
        # Use SEMANTIC since it's semantic-based
        super().__init__(ChunkingStrategy.SEMANTIC)

        self.chunk_size     = chunk_size or settings.FIXED_CHUNK_SIZE
        self.overlap        = overlap or settings.FIXED_CHUNK_OVERLAP
        self.splitter_type  = splitter_type
        self.min_chunk_size = min_chunk_size

        # Initialize token counter
        self.token_counter = TokenCounter()

        # Initialize LlamaIndex components
        self._splitter    = None
        self._initialized = False

        self._initialize_llamaindex()

        self.logger.info(f"Initialized LlamaIndexChunker: chunk_size={self.chunk_size}, overlap={self.overlap}, splitter_type={self.splitter_type}")


    def _initialize_llamaindex(self):
        """
        Initialize LlamaIndex splitter with proper error handling
        """
        try:
            # Initialize embedding model
            embed_model = HuggingFaceEmbedding(model_name = settings.EMBEDDING_MODEL)

            # Initialize appropriate splitter based on type
            if (self.splitter_type == "semantic"):
                self._splitter = SemanticSplitterNodeParser(buffer_size                     = 1,
                                                            breakpoint_percentile_threshold = 95,
                                                            embed_model                     = embed_model,
                                                            )

            elif (self.splitter_type == "sentence"):
                self._splitter = SentenceSplitter(chunk_size    = self.chunk_size,
                                                  chunk_overlap = self.overlap,
                                                  )

            elif (self.splitter_type == "token"):
                self._splitter = TokenTextSplitter(chunk_size    = self.chunk_size,
                                                   chunk_overlap = self.overlap,
                                                   )

            else:
                self.logger.warning(f"Unknown splitter type: {self.splitter_type}, using semantic")
                self._splitter = SemanticSplitterNodeParser(buffer_size                     = 1,
                                                            breakpoint_percentile_threshold = 95,
                                                            embed_model                     = embed_model,
                                                            )

            self._initialized = True
            self.logger.info(f"Successfully initialized LlamaIndex {self.splitter_type} splitter")

        except ImportError as e:
            self.logger.error(f"LlamaIndex not available: {repr(e)}")
            self._initialized = False

        except Exception as e:
            self.logger.error(f"Failed to initialize LlamaIndex: {repr(e)}")
            self._initialized = False


    def chunk_text(self, text: str, metadata: Optional[DocumentMetadata] = None) -> List[DocumentChunk]:
        """
        Chunk text using LlamaIndex semantic splitting

        Arguments:
        ----------
        text { str } : Input text

        metadata { DocumentMetaData } : Document metadata

        Returns:
        --------
        { list } : List of DocumentChunk objects
        """
        if not text or not text.strip():
            return []

        # Fallback if LlamaIndex not available
        if not self._initialized:
            self.logger.warning("LlamaIndex not available, falling back to simple semantic chunking")
            return self._fallback_chunking(text     = text,
                                           metadata = metadata,
                                           )

        document_id = metadata.document_id if metadata else "unknown"

        try:
            # Create LlamaIndex document
            llama_doc = LlamaDocument(text = text)

            # Get nodes from splitter
            nodes = self._splitter.get_nodes_from_documents([llama_doc])

            # Convert nodes to our DocumentChunk format
            chunks    = list()
            start_pos = 0

            for i, node in enumerate(nodes):
                chunk_text = node.text

                # Create chunk
                chunk = self._create_chunk(text        = self._clean_chunk_text(chunk_text),
                                           chunk_index = i,
                                           document_id = document_id,
                                           start_char  = start_pos,
                                           end_char    = start_pos + len(chunk_text),
                                           metadata    = {"llamaindex_splitter" : self.splitter_type,
                                                          "node_id"             : node.node_id,
                                                          "chunk_type"          : "llamaindex_semantic",
                                                          }
                                           )

                chunks.append(chunk)
                start_pos += len(chunk_text)

            # Filter out chunks that are too small
            chunks = [c for c in chunks if (c.token_count >= self.min_chunk_size)]

            self.logger.debug(f"Created {len(chunks)} chunks using LlamaIndex {self.splitter_type} splitter")

            return chunks

        except Exception as e:
            self.logger.error(f"LlamaIndex chunking failed: {repr(e)}")
            return self._fallback_chunking(text     = text,
                                           metadata = metadata,
                                           )


    def _fallback_chunking(self, text: str, metadata: Optional[DocumentMetadata] = None) -> List[DocumentChunk]:
        """
        Fallback to basic semantic chunking when LlamaIndex fails

        Arguments:
        ----------
        text { str } : Input text

        metadata { DocumentMetaData } : Document metadata

        Returns:
        --------
        { list } : List of chunks
        """
        fallback_chunker = SemanticChunker(chunk_size           = self.chunk_size,
                                           overlap              = self.overlap,
                                           similarity_threshold = 0.95,
                                           min_chunk_size       = self.min_chunk_size,
                                           )

        return fallback_chunker.chunk_text(text, metadata)


    def get_splitter_info(self) -> dict:
        """
        Get information about the LlamaIndex splitter configuration

        Returns:
        --------
        { dict } : Splitter information
        """
        return {"splitter_type"  : self.splitter_type,
                "chunk_size"     : self.chunk_size,
                "overlap"        : self.overlap,
                "initialized"    : self._initialized,
                "min_chunk_size" : self.min_chunk_size,
                }


    @classmethod
    def from_config(cls, config: ChunkerConfig) -> 'LlamaIndexChunker':
        """
        Create LlamaIndexChunker from configuration

        Arguments:
        ----------
        config { ChunkerConfig } : ChunkerConfig object

        Returns:
        --------
        { LlamaIndexChunker } : LlamaIndexChunker instance
        """
        return cls(chunk_size     = config.chunk_size,
                   overlap        = config.overlap,
                   splitter_type  = config.extra.get('llamaindex_splitter', 'semantic'),
                   min_chunk_size = config.min_chunk_size,
                   )
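Note: the splitter type can be selected through ChunkerConfig's extra kwargs; if the llama-index packages are not importable, the chunker falls back to SemanticChunker at runtime. A sketch (values illustrative):

# Illustrative only
config  = ChunkerConfig(chunk_size = 384, overlap = 20, llamaindex_splitter = "sentence")
chunker = LlamaIndexChunker.from_config(config)
print(chunker.get_splitter_info())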
chunking/overlap_manager.py
ADDED
@@ -0,0 +1,421 @@
# DEPENDENCIES
from typing import List
from typing import Tuple
from typing import Optional
from config.models import DocumentChunk
from config.logging_config import get_logger
from chunking.token_counter import TokenCounter


# Setup Logging
logger = get_logger(__name__)


class OverlapManager:
    """
    Manages overlapping regions between chunks : ensures smooth context transitions and optimal retrieval
    """
    def __init__(self, overlap_tokens: int = 50):
        """
        Initialize overlap manager

        Arguments:
        ----------
        overlap_tokens { int } : Target overlap in tokens
        """
        self.overlap_tokens = overlap_tokens
        self.token_counter  = TokenCounter()
        self.logger         = logger


    def add_overlap(self, chunks: List[DocumentChunk], overlap_tokens: Optional[int] = None) -> List[DocumentChunk]:
        """
        Add overlap to existing chunks

        Arguments:
        ----------
        chunks { list } : List of chunks without overlap

        overlap_tokens { int } : Override default overlap

        Returns:
        --------
        { list } : List of chunks with overlap
        """
        if (not chunks or (len(chunks) < 2)):
            return chunks

        overlap           = overlap_tokens or self.overlap_tokens
        overlapped_chunks = list()

        for i, chunk in enumerate(chunks):
            if (i == 0):
                # First chunk: no prefix, add suffix from next
                new_text = chunk.text
                if (i + 1 < len(chunks)):
                    suffix = self._get_overlap_text(text           = chunks[i + 1].text,
                                                    overlap_tokens = overlap,
                                                    from_start     = True,
                                                    )

                    new_text = new_text + " " + suffix

            elif (i == len(chunks) - 1):
                # Last chunk: add prefix from previous, no suffix
                prefix = self._get_overlap_text(text           = chunks[i - 1].text,
                                                overlap_tokens = overlap,
                                                from_start     = False,
                                                )

                new_text = prefix + " " + chunk.text

            else:
                # Middle chunk: add both prefix and suffix
                prefix = self._get_overlap_text(text           = chunks[i - 1].text,
                                                overlap_tokens = overlap,
                                                from_start     = False,
                                                )

                suffix = self._get_overlap_text(text           = chunks[i + 1].text,
                                                overlap_tokens = overlap,
                                                from_start     = True,
                                                )

                new_text = prefix + " " + chunk.text + " " + suffix

            # Create new chunk with overlapped text
            overlapped_chunk = DocumentChunk(chunk_id      = chunk.chunk_id,
                                             document_id   = chunk.document_id,
                                             text          = new_text,
                                             chunk_index   = chunk.chunk_index,
                                             start_char    = chunk.start_char,
                                             end_char      = chunk.end_char,
                                             page_number   = chunk.page_number,
                                             section_title = chunk.section_title,
                                             token_count   = self.token_counter.count_tokens(new_text),
                                             metadata      = chunk.metadata,
                                             )

            overlapped_chunks.append(overlapped_chunk)

        self.logger.debug(f"Added overlap to {len(chunks)} chunks")
        return overlapped_chunks


    def _get_overlap_text(self, text: str, overlap_tokens: int, from_start: bool) -> str:
        """
        Extract overlap text from beginning or end

        Arguments:
        ----------
        text { str } : Source text

        overlap_tokens { int } : Number of tokens to extract

        from_start { bool } : True for start, False for end

        Returns:
        --------
        { str } : Overlap text
        """
        total_tokens = self.token_counter.count_tokens(text)

        if (total_tokens <= overlap_tokens):
            return text

        if from_start:
            # Get first N tokens
            return self.token_counter.truncate_to_tokens(text       = text,
                                                         max_tokens = overlap_tokens,
                                                         suffix     = "",
                                                         )

        else:
            # Get last N tokens using token counter's boundary finding
            char_pos, overlap_text = self.token_counter.find_token_boundaries(text          = text,
                                                                              target_tokens = overlap_tokens,
                                                                              )

            # Take from the end instead of beginning
            if (char_pos < len(text)):
                return text[-char_pos:] if (char_pos > 0) else text

            return overlap_text


    def remove_overlap(self, chunks: List[DocumentChunk]) -> List[DocumentChunk]:
        """
        Remove overlap from chunks (get core content only)

        Arguments:
        ----------
        chunks { list } : List of chunks with overlap

        Returns:
        --------
        { list } : List of chunks without overlap
        """
        if (not chunks or (len(chunks) < 2)):
            return chunks

        core_chunks = list()

        for i, chunk in enumerate(chunks):
            if (i == 0):
                # First chunk: remove suffix
                core_text = self._remove_suffix_overlap(text      = chunk.text,
                                                        next_text = chunks[i + 1].text if i + 1 < len(chunks) else "",
                                                        )
            elif (i == len(chunks) - 1):
                # Last chunk: remove prefix
                core_text = self._remove_prefix_overlap(text          = chunk.text,
                                                        previous_text = chunks[i - 1].text,
                                                        )
            else:
                # Middle chunk: remove both
                temp_text = self._remove_prefix_overlap(text          = chunk.text,
                                                        previous_text = chunks[i - 1].text,
                                                        )

                core_text = self._remove_suffix_overlap(text      = temp_text,
                                                        next_text = chunks[i + 1].text,
                                                        )

            core_chunk = DocumentChunk(chunk_id      = chunk.chunk_id,
                                       document_id   = chunk.document_id,
                                       text          = core_text,
                                       chunk_index   = chunk.chunk_index,
                                       start_char    = chunk.start_char,
                                       end_char      = chunk.end_char,
                                       page_number   = chunk.page_number,
                                       section_title = chunk.section_title,
                                       token_count   = self.token_counter.count_tokens(core_text),
                                       metadata      = chunk.metadata,
                                       )

            core_chunks.append(core_chunk)

        return core_chunks


    def _remove_prefix_overlap(self, text: str, previous_text: str) -> str:
        """
        Remove overlap with previous chunk
        """
        if not text or not previous_text:
            return text

        words      = text.split()
        prev_words = previous_text.split()

        # Find longest common suffix-prefix match
        max_overlap = 0

        for overlap_size in range(1, min(len(words), len(prev_words)) + 1):
            if (words[:overlap_size] == prev_words[-overlap_size:]):
                max_overlap = overlap_size

        if (max_overlap > 0):
            return " ".join(words[max_overlap:])

        return text


    def _remove_suffix_overlap(self, text: str, next_text: str) -> str:
        """
        Remove overlap with next chunk
        """
        # Find common suffix
        words      = text.split()
        next_words = next_text.split()

        common_length = 0

        for i in range(1, min(len(words), len(next_words)) + 1):
            if (words[-i] == next_words[i - 1]):
                common_length += 1

            else:
                break

        if (common_length > 0):
            return " ".join(words[:-common_length])

        return text


    def calculate_overlap_percentage(self, chunks: List[DocumentChunk]) -> float:
        """
        Calculate average overlap percentage

        Arguments:
        ----------
        chunks { list } : List of chunks

        Returns:
        --------
        { float } : Average overlap percentage
        """
        if (len(chunks) < 2):
            return 0.0

        overlaps = list()

        for i in range(len(chunks) - 1):
            overlap = self._measure_overlap(chunks[i].text, chunks[i + 1].text)

            overlaps.append(overlap)

        return sum(overlaps) / len(overlaps) if overlaps else 0.0


    def _measure_overlap(self, text1: str, text2: str) -> float:
        """
        Measure overlap between two texts

        Arguments:
        ----------
        text1 { str } : First text

        text2 { str } : Second text

        Returns:
        --------
        { float } : Overlap percentage (0-100)
        """
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if (not words1 or not words2):
            return 0.0

        common      = words1 & words2
        overlap_pct = (len(common) / min(len(words1), len(words2))) * 100

        return overlap_pct


    def optimize_overlaps(self, chunks: List[DocumentChunk], target_overlap: int, tolerance: int = 10) -> List[DocumentChunk]:
        """
        Optimize overlap sizes to target

        Arguments:
        ----------
        chunks { list } : List of chunks

        target_overlap { int } : Target overlap in tokens

        tolerance { int } : Acceptable deviation in tokens

        Returns:
        --------
        { list } : Optimized chunks
        """
        if (len(chunks) < 2):
            return chunks

        # Validate target_overlap is reasonable
        if (target_overlap <= 0):
            self.logger.warning("Target overlap must be positive, using default")
            target_overlap = self.overlap_tokens

        optimized = list()

        for i in range(len(chunks)):
            chunk = chunks[i]

            # Check current overlap with next chunk
            if (i < len(chunks) - 1):
                current_overlap = self._count_overlap_tokens(text1 = chunk.text,
                                                             text2 = chunks[i + 1].text,
                                                             )

                # Adjust if outside tolerance
                if (abs(current_overlap - target_overlap) > tolerance):
                    # Add or remove text to reach target
                    if (current_overlap < target_overlap):
                        # Need more overlap
additional = self._get_overlap_text(text = chunks[i + 1].text,
|
| 339 |
+
overlap_tokens = target_overlap - current_overlap,
|
| 340 |
+
from_start = True,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
new_text = chunk.text + " " + additional
|
| 344 |
+
|
| 345 |
+
else:
|
| 346 |
+
# Need less overlap
|
| 347 |
+
new_text = self.token_counter.truncate_to_tokens(text = chunk.text,
|
| 348 |
+
max_tokens = self.token_counter.count_tokens(chunk.text) - (current_overlap - target_overlap),
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
chunk = DocumentChunk(chunk_id = chunk.chunk_id,
|
| 352 |
+
document_id = chunk.document_id,
|
| 353 |
+
text = new_text,
|
| 354 |
+
chunk_index = chunk.chunk_index,
|
| 355 |
+
start_char = chunk.start_char,
|
| 356 |
+
end_char = chunk.end_char,
|
| 357 |
+
page_number = chunk.page_number,
|
| 358 |
+
section_title = chunk.section_title,
|
| 359 |
+
token_count = self.token_counter.count_tokens(new_text),
|
| 360 |
+
metadata = chunk.metadata,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
optimized.append(chunk)
|
| 364 |
+
|
| 365 |
+
return optimized
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _count_overlap_tokens(self, text1: str, text2: str) -> int:
|
| 369 |
+
"""
|
| 370 |
+
Count overlapping tokens between two texts
|
| 371 |
+
"""
|
| 372 |
+
# Find longest common substring at the boundary
|
| 373 |
+
words1 = text1.split()
|
| 374 |
+
words2 = text2.split()
|
| 375 |
+
|
| 376 |
+
max_overlap = 0
|
| 377 |
+
|
| 378 |
+
for i in range(1, min(len(words1), len(words2)) + 1):
|
| 379 |
+
if (words1[-i:] == words2[:i]):
|
| 380 |
+
overlap_text = " ".join(words1[-i:])
|
| 381 |
+
max_overlap = self.token_counter.count_tokens(overlap_text)
|
| 382 |
+
|
| 383 |
+
return max_overlap
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def get_overlap_statistics(self, chunks: List[DocumentChunk]) -> dict:
|
| 387 |
+
"""
|
| 388 |
+
Get statistics about overlaps
|
| 389 |
+
|
| 390 |
+
Arguments:
|
| 391 |
+
----------
|
| 392 |
+
chunks { list } : List of chunks
|
| 393 |
+
|
| 394 |
+
Returns:
|
| 395 |
+
--------
|
| 396 |
+
{ dict } : Statistics dictionary
|
| 397 |
+
"""
|
| 398 |
+
if (len(chunks) < 2):
|
| 399 |
+
return {"num_chunks" : len(chunks),
|
| 400 |
+
"num_overlaps" : 0,
|
| 401 |
+
"avg_overlap_tokens" : 0,
|
| 402 |
+
"avg_overlap_percentage" : 0,
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
overlap_tokens = list()
|
| 406 |
+
overlap_percentages = list()
|
| 407 |
+
|
| 408 |
+
for i in range(len(chunks) - 1):
|
| 409 |
+
tokens = self._count_overlap_tokens(chunks[i].text, chunks[i + 1].text)
|
| 410 |
+
pct = self._measure_overlap(chunks[i].text, chunks[i + 1].text)
|
| 411 |
+
|
| 412 |
+
overlap_tokens.append(tokens)
|
| 413 |
+
overlap_percentages.append(pct)
|
| 414 |
+
|
| 415 |
+
return {"num_chunks" : len(chunks),
|
| 416 |
+
"num_overlaps" : len(overlap_tokens),
|
| 417 |
+
"avg_overlap_tokens" : sum(overlap_tokens) / len(overlap_tokens) if overlap_tokens else 0,
|
| 418 |
+
"min_overlap_tokens" : min(overlap_tokens) if overlap_tokens else 0,
|
| 419 |
+
"max_overlap_tokens" : max(overlap_tokens) if overlap_tokens else 0,
|
| 420 |
+
"avg_overlap_percentage" : sum(overlap_percentages) / len(overlap_percentages) if overlap_percentages else 0,
|
| 421 |
+
}
|
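These overlap utilities are normally driven by a chunker rather than called directly; the sketch below is illustrative only, assuming the `OverlapManager(overlap_tokens=...)` constructor and `add_overlap(chunks=..., overlap_tokens=...)` call used in `semantic_chunker.py` further down, and a pre-built list of `DocumentChunk` objects named `chunks` (a placeholder, not part of the commit).

# Illustrative usage sketch (not part of the committed file)
from chunking.overlap_manager import OverlapManager

manager = OverlapManager(overlap_tokens = 20)

# `chunks` is a placeholder: DocumentChunk objects produced by any chunker
overlapped = manager.add_overlap(chunks = chunks, overlap_tokens = 20)

# Inspect how much neighbouring chunks now share
stats = manager.get_overlap_statistics(overlapped)
print(stats["avg_overlap_tokens"], stats["avg_overlap_percentage"])

# Strip the shared text again to recover core content only
core_only = manager.remove_overlap(overlapped)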
chunking/semantic_chunker.py
ADDED
@@ -0,0 +1,607 @@
# DEPENDENCIES
import re
import numpy as np
from typing import List
from typing import Tuple
from typing import Optional
from config.models import DocumentChunk
from config.settings import get_settings
from config.models import DocumentMetadata
from config.models import ChunkingStrategy
from config.logging_config import get_logger
from chunking.base_chunker import BaseChunker
from chunking.base_chunker import ChunkerConfig
from chunking.token_counter import TokenCounter
from chunking.fixed_chunker import FixedChunker
from chunking.overlap_manager import OverlapManager
from sentence_transformers import SentenceTransformer


# Setup Settings and Logging
logger = get_logger(__name__)
settings = get_settings()


class SemanticChunker(BaseChunker):
    """
    Semantic chunking strategy with section-aware splitting:
    - Detects section boundaries and NEVER crosses them
    - Creates chunks based on semantic similarity within sections
    - Preserves hierarchical structure (sections → subsections → content)

    Best for:
    - Medium documents (50K-500K tokens)
    - Documents with clear topics/sections
    - When context coherence is critical
    """
    def __init__(self, chunk_size: int = None, overlap: int = None, similarity_threshold: float = None, min_chunk_size: int = 100,
                 embedding_model: Optional[SentenceTransformer] = None, respect_section_boundaries: bool = True):
        """
        Initialize semantic chunker

        Arguments:
        ----------
        chunk_size { int } : Target tokens per chunk (soft limit)

        overlap { int } : Overlap tokens between chunks

        similarity_threshold { float } : Threshold for semantic breakpoints (0-1)

        min_chunk_size { int } : Minimum chunk size in tokens

        embedding_model { SentenceTransformer } : Pre-loaded embedding model (optional)

        respect_section_boundaries { bool } : Detect and respect section headers
        """
        super().__init__(ChunkingStrategy.SEMANTIC)

        self.chunk_size = chunk_size or settings.FIXED_CHUNK_SIZE
        self.overlap = overlap or settings.FIXED_CHUNK_OVERLAP
        self.similarity_threshold = similarity_threshold or settings.SEMANTIC_BREAKPOINT_THRESHOLD
        self.min_chunk_size = min_chunk_size
        self.respect_section_boundaries = respect_section_boundaries

        # Initialize token counter and overlap manager
        self.token_counter = TokenCounter()
        self.overlap_manager = OverlapManager(overlap_tokens = self.overlap)

        # Initialize or use provided embedding model
        if embedding_model is not None:
            self.embedding_model = embedding_model

        else:
            try:
                self.logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
                self.embedding_model = SentenceTransformer(settings.EMBEDDING_MODEL)
                self.logger.info("Embedding model loaded successfully")

            except Exception as e:
                self.logger.error(f"Failed to load embedding model: {repr(e)}")
                self.embedding_model = None

        self.logger.info(f"Initialized SemanticChunker: chunk_size={self.chunk_size}, threshold={self.similarity_threshold}, "
                         f"model_loaded={self.embedding_model is not None}, section_aware={self.respect_section_boundaries}")


    def chunk_text(self, text: str, metadata: Optional[DocumentMetadata] = None) -> List[DocumentChunk]:
        """
        Chunk text based on semantic similarity AND section structure

        Arguments:
        ----------
        text { str } : Input text

        metadata { DocumentMetadata } : Document metadata

        Returns:
        --------
        { list } : List of DocumentChunk objects
        """
        if not text or not text.strip():
            return []

        document_id = metadata.document_id if metadata else "unknown"

        # If embedding model not available, fall back to fixed chunking
        if self.embedding_model is None:
            self.logger.warning("Embedding model not available, using sentence-based chunking")
            return self._fallback_chunking(text=text, document_id=document_id)

        # Detect section headers if enabled
        if self.respect_section_boundaries:
            headers = self._detect_section_headers(text)

            if headers:
                self.logger.info(f"Detected {len(headers)} section headers - using section-aware chunking")
                chunks = self._chunk_by_sections(text = text,
                                                 headers = headers,
                                                 document_id = document_id,
                                                 )

            else:
                self.logger.info("No section headers detected - using standard semantic chunking")
                chunks = self._chunk_semantic(text = text,
                                              document_id = document_id,
                                              )

        else:
            chunks = self._chunk_semantic(text = text,
                                          document_id = document_id,
                                          )

        # Filter out chunks that are too small
        chunks = [c for c in chunks if (c.token_count >= self.min_chunk_size)]

        # Use OverlapManager to add proper overlap between semantic chunks
        if ((len(chunks) > 1) and (self.overlap > 0)):
            chunks = self.overlap_manager.add_overlap(chunks = chunks,
                                                      overlap_tokens = self.overlap,
                                                      )

        self.logger.debug(f"Created {len(chunks)} semantic chunks")

        return chunks


    def _detect_section_headers(self, text: str) -> List[Tuple[int, str, str, int]]:
        """
        Detect section headers in text to preserve document structure and returns a list of (line_index, header_type, header_text, char_position)

        Detects:
        - Project headers
        - Subsection headers
        - Major section headers
        """
        headers = list()
        lines = text.split('\n')
        char_position = 0

        for i, line in enumerate(lines):
            line_stripped = line.strip()

            # Pattern 1: Headers - "a) Name" or "b) Name"
            if (re.match(r'^[a-z]\)\s+[A-Z]', line_stripped)):
                headers.append((i, 'section', line_stripped, char_position))
                self.logger.debug(f"Detected section header at line {i}: {line_stripped[:60]}")

            # Pattern 2: Subsection headers - "● Subsection:" (bullet with colon)
            elif ((line_stripped.startswith('●')) and (':' in line_stripped)):
                headers.append((i, 'subsection', line_stripped, char_position))
                self.logger.debug(f"Detected subsection header at line {i}: {line_stripped[:60]}")

            # Pattern 3: Major section headers - "1. SECTION NAME" or all caps with numbers
            elif (re.match(r'^\d+\.\s+[A-Z\s&]+:', line_stripped)):
                headers.append((i, 'section', line_stripped, char_position))
                self.logger.debug(f"Detected major section at line {i}: {line_stripped[:60]}")

            # Pattern 4: All caps headers (must be substantial)
            elif (line_stripped.isupper() and (len(line_stripped) > 15) and (not line_stripped.startswith('●'))):
                headers.append((i, 'category', line_stripped, char_position))
                self.logger.debug(f"Detected category header at line {i}: {line_stripped[:60]}")

            # +1 for newline
            char_position += len(line) + 1

        return headers


    def _chunk_by_sections(self, text: str, headers: List[Tuple], document_id: str) -> List[DocumentChunk]:
        """
        Create chunks that never cross section boundaries: Each chunk preserves its parent section in metadata
        """
        lines = text.split('\n')
        chunks = list()

        # Group lines by their parent section
        current_section_lines = list()
        current_section_header = None
        current_subsection_header = None
        start_char = 0

        for line_idx, line in enumerate(lines):
            # Check if this line is a header
            matching_headers = [h for h in headers if (h[0] == line_idx)]

            if matching_headers:
                header_info = matching_headers[0]
                header_type = header_info[1]
                header_text = header_info[2]

                # If we hit a Header, save previous section
                if (header_type == 'section'):
                    if current_section_lines:
                        # Create chunks from previous section
                        section_text = '\n'.join(current_section_lines)
                        section_chunks = self._split_section_if_large(text = section_text,
                                                                      document_id = document_id,
                                                                      start_index = len(chunks),
                                                                      start_char = start_char,
                                                                      section_header = current_section_header,
                                                                      subsection_header = current_subsection_header,
                                                                      )
                        chunks.extend(section_chunks)
                        start_char += len(section_text) + 1

                    # Start new section
                    current_section_header = header_text
                    current_subsection_header = None
                    current_section_lines = [line]

                # If we hit a SUBSECTION header within a section
                elif (header_type == 'subsection'):
                    if (current_section_lines and current_subsection_header):
                        # Save previous subsection
                        section_text = '\n'.join(current_section_lines)
                        section_chunks = self._split_section_if_large(text = section_text,
                                                                      document_id = document_id,
                                                                      start_index = len(chunks),
                                                                      start_char = start_char,
                                                                      section_header = current_section_header,
                                                                      subsection_header = current_subsection_header,
                                                                      )
                        chunks.extend(section_chunks)
                        start_char += len(section_text) + 1
                        current_section_lines = list()

                    # Update subsection
                    current_subsection_header = header_text
                    current_section_lines.append(line)

                else:
                    current_section_lines.append(line)

            else:
                current_section_lines.append(line)

        # Process final section
        if current_section_lines:
            section_text = '\n'.join(current_section_lines)

            section_chunks = self._split_section_if_large(text = section_text,
                                                          document_id = document_id,
                                                          start_index = len(chunks),
                                                          start_char = start_char,
                                                          section_header = current_section_header,
                                                          subsection_header = current_subsection_header,
                                                          )
            chunks.extend(section_chunks)

        return chunks


    def _split_section_if_large(self, text: str, document_id: str, start_index: int, start_char: int, section_header: Optional[str],
                                subsection_header: Optional[str]) -> List[DocumentChunk]:
        """
        Split a section if it's too large, while preserving section context: Always stores section info in metadata
        """
        token_count = self.token_counter.count_tokens(text)

        # Build section title for metadata
        section_parts = list()

        if section_header:
            section_parts.append(section_header)

        if subsection_header:
            section_parts.append(subsection_header)

        section_title = " | ".join(section_parts) if section_parts else None

        # If section fits in one chunk, keep it whole
        if (token_count <= self.chunk_size * 1.5):
            chunk = self._create_chunk(text = self._clean_chunk_text(text),
                                       chunk_index = start_index,
                                       document_id = document_id,
                                       start_char = start_char,
                                       end_char = start_char + len(text),
                                       section_title = section_title,
                                       metadata = {"section_header" : section_header,
                                                   "subsection_header" : subsection_header,
                                                   "semantic_chunk" : True,
                                                   "section_aware" : True,
                                                   }
                                       )
            return [chunk]

        # Section too large - split by bullet points or sentences: But always keep section context in metadata
        if '❖' in text or '●' in text:
            # Split by bullet points (Interactive Demo Features style)
            parts = re.split(r'(❖[^\n]+)', text)
            parts = [p for p in parts if p.strip()]

        else:
            # Split by sentences within this section
            parts = self._split_sentences(text)

        sub_chunks = []
        current_pos = start_char

        for part in parts:
            if not part.strip():
                continue

            part_tokens = self.token_counter.count_tokens(part)

            # Create chunk with preserved section context
            chunk = self._create_chunk(text = self._clean_chunk_text(part),
                                       chunk_index = start_index + len(sub_chunks),
                                       document_id = document_id,
                                       start_char = current_pos,
                                       end_char = current_pos + len(part),
                                       section_title = section_title,
                                       metadata = {"section_header" : section_header,
                                                   "subsection_header" : subsection_header,
                                                   "parent_section" : section_title,
                                                   "semantic_chunk" : True,
                                                   "section_aware" : True,
                                                   "is_subsection_part" : True,
                                                   }
                                       )
            sub_chunks.append(chunk)
            current_pos += len(part)

        if sub_chunks:
            return sub_chunks

        else:
            chunks_list = [self._create_chunk(text = self._clean_chunk_text(text),
                                              chunk_index = start_index,
                                              document_id = document_id,
                                              start_char = start_char,
                                              end_char = start_char + len(text),
                                              section_title = section_title,
                                              metadata = {"section_header" : section_header,
                                                          "subsection_header" : subsection_header,
                                                          "semantic_chunk" : True,
                                                          }
                                              )
                           ]

            return chunks_list


    def _chunk_semantic(self, text: str, document_id: str) -> List[DocumentChunk]:
        """
        Standard semantic chunking (when no headers detected)
        """
        # Split into sentences
        sentences = self._split_sentences(text = text)

        if (len(sentences) < 2):
            return self._create_single_chunk(text=text, document_id=document_id)

        # Calculate semantic similarities
        similarities = self._calculate_similarities(sentences=sentences)

        # Find breakpoints
        breakpoints = self._find_breakpoints(similarities=similarities)

        # Create chunks WITHOUT overlap
        chunks = self._create_chunks_from_breakpoints(sentences = sentences,
                                                      breakpoints = breakpoints,
                                                      document_id = document_id,
                                                      )

        return chunks


    def _split_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences
        """
        # Protect abbreviations
        protected = text
        abbreviations = ['Dr.', 'Mr.', 'Mrs.', 'Ms.', 'Jr.', 'Sr.', 'Prof.', 'Inc.', 'Ltd.', 'Corp.', 'Co.', 'vs.', 'etc.', 'e.g.', 'i.e.', 'Ph.D.', 'M.D.', 'B.A.', 'M.A.', 'U.S.', 'U.K.']

        for abbr in abbreviations:
            protected = protected.replace(abbr, abbr.replace('.', '<DOT>'))

        # Split on sentence boundaries
        sentence_pattern = r'(?<=[.!?])\s+(?=[A-Z])'
        sentences = re.split(sentence_pattern, protected)

        # Restore abbreviations
        sentences = [s.replace('<DOT>', '.').strip() for s in sentences]

        # Filter empty
        sentences = [s for s in sentences if s]

        return sentences


    def _calculate_similarities(self, sentences: List[str]) -> List[float]:
        """
        Calculate cosine similarity between adjacent sentences
        """
        if (len(sentences) < 2):
            return []

        self.logger.debug(f"Generating embeddings for {len(sentences)} sentences")

        embeddings = self.embedding_model.encode(sentences,
                                                 show_progress_bar = False,
                                                 convert_to_numpy = True,
                                                 )

        similarities = list()

        for i in range(len(embeddings) - 1):
            similarity = self._cosine_similarity(vec1 = embeddings[i],
                                                 vec2 = embeddings[i + 1],
                                                 )
            similarities.append(similarity)

        return similarities


    @staticmethod
    def _cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Calculate cosine similarity between two vectors
        """
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if ((norm1 == 0) or (norm2 == 0)):
            return 0.0

        return dot_product / (norm1 * norm2)


    def _find_breakpoints(self, similarities: List[float]) -> List[int]:
        """
        Find breakpoints where semantic similarity drops significantly
        """
        if not similarities:
            return []

        similarities_array = np.array(similarities)
        threshold = np.percentile(similarities_array, (1 - self.similarity_threshold) * 100)

        breakpoints = [0]

        for i, sim in enumerate(similarities):
            if (sim < threshold):
                breakpoints.append(i + 1)

        self.logger.debug(f"Found {len(breakpoints)} breakpoints with threshold {threshold:.3f}")

        return breakpoints


    def _create_chunks_from_breakpoints(self, sentences: List[str], breakpoints: List[int], document_id: str) -> List[DocumentChunk]:
        """
        Create chunks from sentences and breakpoints WITHOUT overlap
        """
        chunks = list()
        breakpoints = sorted(set(breakpoints))

        if (breakpoints[-1] != len(sentences)):
            breakpoints.append(len(sentences))

        current_pos = 0

        for i in range(len(breakpoints) - 1):
            start_idx = breakpoints[i]
            end_idx = breakpoints[i + 1]

            chunk_sentences = sentences[start_idx:end_idx]

            if not chunk_sentences:
                continue

            chunk_text = " ".join(chunk_sentences)
            token_count = self.token_counter.count_tokens(chunk_text)

            if (token_count > self.chunk_size * 1.5):
                sub_chunks = self._split_large_chunk_simple(chunk_sentences = chunk_sentences,
                                                            document_id = document_id,
                                                            start_index = len(chunks),
                                                            start_char = current_pos,
                                                            )
                chunks.extend(sub_chunks)

            else:
                chunk = self._create_chunk(text = self._clean_chunk_text(chunk_text),
                                           chunk_index = len(chunks),
                                           document_id = document_id,
                                           start_char = current_pos,
                                           end_char = current_pos + len(chunk_text),
                                           metadata = {"sentences" : len(chunk_sentences),
                                                       "semantic_chunk" : True,
                                                       }
                                           )

                chunks.append(chunk)

            current_pos += len(chunk_text)

        return chunks


    def _split_large_chunk_simple(self, chunk_sentences: List[str], document_id: str, start_index: int, start_char: int) -> List[DocumentChunk]:
        """
        Split a large chunk into smaller pieces without overlap
        """
        sub_chunks = list()
        current_sentences = list()
        current_tokens = 0
        current_pos = start_char

        for sentence in chunk_sentences:
            sentence_tokens = self.token_counter.count_tokens(sentence)

            if (((current_tokens + sentence_tokens) > self.chunk_size) and current_sentences):
                chunk_text = " ".join(current_sentences)
                chunk = self._create_chunk(text = self._clean_chunk_text(chunk_text),
                                           chunk_index = start_index + len(sub_chunks),
                                           document_id = document_id,
                                           start_char = current_pos,
                                           end_char = current_pos + len(chunk_text),
                                           )
                sub_chunks.append(chunk)

                current_sentences = [sentence]
                current_tokens = sentence_tokens
                current_pos += len(chunk_text)

            else:
                current_sentences.append(sentence)
                current_tokens += sentence_tokens

        if current_sentences:
            chunk_text = " ".join(current_sentences)
            chunk = self._create_chunk(text = self._clean_chunk_text(chunk_text),
                                       chunk_index = start_index + len(sub_chunks),
                                       document_id = document_id,
                                       start_char = current_pos,
                                       end_char = current_pos + len(chunk_text),
                                       )
            sub_chunks.append(chunk)

        return sub_chunks


    def _create_single_chunk(self, text: str, document_id: str) -> List[DocumentChunk]:
        """
        Create a single chunk for short text
        """
        chunk = self._create_chunk(text = self._clean_chunk_text(text),
                                   chunk_index = 0,
                                   document_id = document_id,
                                   start_char = 0,
                                   end_char = len(text),
                                   )
        return [chunk]


    def _fallback_chunking(self, text: str, document_id: str) -> List[DocumentChunk]:
        """
        Fallback to sentence-based chunking when embeddings unavailable
        """
        fallback_chunker = FixedChunker(chunk_size = self.chunk_size,
                                        overlap = self.overlap,
                                        respect_sentence_boundaries = True,
                                        )

        metadata = DocumentMetadata(document_id = document_id,
                                    filename = "fallback",
                                    document_type = "txt",
                                    file_size_bytes = len(text),
                                    )

        return fallback_chunker.chunk_text(text, metadata)


    @classmethod
    def from_config(cls, config: ChunkerConfig) -> 'SemanticChunker':
        """
        Create SemanticChunker from configuration
        """
        return cls(chunk_size = config.chunk_size,
                   overlap = config.overlap,
                   similarity_threshold = config.extra.get('semantic_threshold', settings.SEMANTIC_BREAKPOINT_THRESHOLD),
                   min_chunk_size = config.min_chunk_size,
                   respect_section_boundaries = config.extra.get('respect_section_boundaries', True),
                   )
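A rough usage sketch for the chunker above (illustrative only; `raw_text` is a placeholder string, and the `DocumentMetadata` fields mirror the ones used in `_fallback_chunking`):

# Illustrative usage sketch (not part of the committed file)
from chunking.semantic_chunker import SemanticChunker
from config.models import DocumentMetadata

chunker = SemanticChunker(chunk_size = 384, overlap = 20)

metadata = DocumentMetadata(document_id = "doc-001",
                            filename = "report.txt",
                            document_type = "txt",
                            file_size_bytes = len(raw_text))

chunks = chunker.chunk_text(raw_text, metadata)

for chunk in chunks[:3]:
    print(chunk.chunk_index, chunk.token_count, chunk.section_title)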
chunking/token_counter.py
ADDED
@@ -0,0 +1,520 @@
# DEPENDENCIES
import re
import tiktoken
from typing import List
from typing import Optional
from config.models import TokenizerType
from config.settings import get_settings
from config.logging_config import get_logger


# Setup Logger and settings
logger = get_logger(__name__)
settings = get_settings()


class TokenCounter:
    """
    Token counting utility with support for multiple tokenizers: Provides accurate token counts for chunking and context management
    """
    def __init__(self, tokenizer_type: str = "cl100k_base"):
        """
        Initialize token counter

        Arguments:
        ----------
        tokenizer_type { str } : Type of tokenizer to use
        """
        self.tokenizer_type = tokenizer_type
        self.logger = logger

        # Validate tokenizer type
        valid_tokenizers = [t.value for t in TokenizerType]

        if tokenizer_type not in valid_tokenizers:
            self.logger.warning(f"Invalid tokenizer type: {tokenizer_type}, using approximate")
            self.tokenizer_type = TokenizerType.APPROXIMATE
            self.tokenizer = None

            return

        # Initialize tokenizer
        if (tokenizer_type != TokenizerType.APPROXIMATE):
            try:
                self.tokenizer = tiktoken.get_encoding(tokenizer_type)
                self.logger.debug(f"Initialized tiktoken tokenizer: {tokenizer_type}")

            except Exception as e:
                self.logger.warning(f"Failed to load tiktoken: {repr(e)}, using approximation")

                self.tokenizer = None
                self.tokenizer_type = TokenizerType.APPROXIMATE

        else:
            self.tokenizer = None


    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { int } : Number of tokens
        """
        if not text:
            return 0

        if self.tokenizer is not None:
            # Use tiktoken for accurate counting
            try:
                tokens = self.tokenizer.encode(text)
                return len(tokens)

            except Exception as e:
                self.logger.warning(f"Tokenizer error: {e}, falling back to approximation")
                return self._approximate_token_count(text)

        else:
            # Use approximation
            return self._approximate_token_count(text = text)


    def _approximate_token_count(self, text: str) -> int:
        """
        Approximate token count using multiple heuristics
        """
        if not text:
            return 0

        # Method 1: Word-based estimation (accounts for subword tokenization)
        words = text.split()
        word_count = len(words)

        # Method 2: Character-based estimation
        char_count = len(text)

        # Method 3: Hybrid approach with weighting
        # - Short texts: more word-based (better for code/short docs)
        # - Long texts: more character-based (better for prose)
        if (char_count < 1000):
            # Prefer word-based for short texts : Slightly higher for short texts
            estimate = word_count * 1.33

        else:
            # Balanced approach for longer texts
            word_estimate = word_count * 1.3
            char_estimate = char_count / 4.0
            estimate = (word_estimate + char_estimate) / 2

        # Ensure reasonable bounds
        min_tokens = max(1, word_count)   # At least 1 token per word
        max_tokens = char_count // 2      # At most 1 token per 2 chars

        return max(min_tokens, min(int(estimate), max_tokens))


    def encode(self, text: str) -> List[int]:
        """
        Encode text to token IDs

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { list } : List of token IDs
        """
        if self.tokenizer is None:
            raise ValueError("Cannot encode with approximate tokenizer")

        return self.tokenizer.encode(text)


    def decode(self, tokens: List[int]) -> str:
        """
        Decode token IDs to text

        Arguments:
        ----------
        tokens { list } : List of token IDs

        Returns:
        --------
        { str } : Decoded text
        """
        if self.tokenizer is None:
            raise ValueError("Cannot decode with approximate tokenizer")

        return self.tokenizer.decode(tokens)


    def truncate_to_tokens(self, text: str, max_tokens: int, suffix: str = "") -> str:
        """
        Truncate text to maximum token count

        Arguments:
        ----------
        text { str } : Input text

        max_tokens { int } : Maximum number of tokens

        suffix { str } : Suffix to add (e.g., "...")

        Returns:
        --------
        { str } : Truncated text
        """
        if self.tokenizer is not None:
            # Use precise token-based truncation
            tokens = self.encode(text)

            if (len(tokens) <= max_tokens):
                return text

            # Account for suffix tokens
            suffix_tokens = len(self.encode(suffix)) if suffix else 0
            truncate_at = max_tokens - suffix_tokens

            truncated_tokens = tokens[:truncate_at]
            truncated_text = self.decode(truncated_tokens)

            return truncated_text + suffix

        else:
            # Use character-based approximation
            current_tokens = self.count_tokens(text = text)

            if (current_tokens <= max_tokens):
                return text

            # Estimate character position
            ratio = max_tokens / current_tokens
            char_position = int(len(text) * ratio)

            # Find nearest word boundary
            truncated = text[:char_position]
            last_space = truncated.rfind(' ')

            if (last_space > 0):
                truncated = truncated[:last_space]

            return truncated + suffix


    def split_into_token_chunks(self, text: str, chunk_size: int, overlap: int = 0) -> List[str]:
        """
        Split text into chunks of approximately equal token count

        Arguments:
        ----------
        text { str } : Input text

        chunk_size { int } : Target tokens per chunk

        overlap { int } : Number of overlapping tokens between chunks

        Returns:
        --------
        { list } : List of text chunks
        """
        if (overlap >= chunk_size):
            raise ValueError("Overlap must be less than chunk_size")

        if self.tokenizer is not None:
            precise_chunks = self._split_precise(text = text,
                                                 chunk_size = chunk_size,
                                                 overlap = overlap,
                                                 )

            return precise_chunks

        else:
            approximate_chunks = self._split_approximate(text = text,
                                                         chunk_size = chunk_size,
                                                         overlap = overlap,
                                                         )

            return approximate_chunks


    def _split_precise(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """
        Split using precise token counts
        """
        tokens = self.encode(text)
        chunks = list()

        start = 0

        while (start < len(tokens)):
            # Get chunk tokens
            end = min(start + chunk_size, len(tokens))
            chunk_tokens = tokens[start:end]

            # Decode to text
            chunk_text = self.decode(chunk_tokens)

            chunks.append(chunk_text)

            # Move to next chunk with overlap
            start = end - overlap

            # Avoid infinite loop
            if ((start >= len(tokens)) or ((end == len(tokens)))):
                break

        return chunks


    def _split_approximate(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """
        Split using approximate token counts
        """
        # Estimate characters per chunk : Rule = ~4 chars per token
        chars_per_chunk = chunk_size * 4
        overlap_chars = overlap * 4

        chunks = list()
        sentences = self._split_into_sentences(text = text)

        current_chunk = list()
        current_tokens = 0

        for sentence in sentences:
            sentence_tokens = self.count_tokens(text = sentence)

            if (((current_tokens + sentence_tokens) > chunk_size) and current_chunk):
                # Save current chunk
                chunk_text = " ".join(current_chunk)

                chunks.append(chunk_text)

                # Start new chunk with overlap
                if (overlap > 0):
                    # Keep last few sentences for overlap
                    overlap_text = chunk_text[-overlap_chars:] if len(chunk_text) > overlap_chars else chunk_text
                    current_chunk = [overlap_text, sentence]
                    current_tokens = self.count_tokens(text = " ".join(current_chunk))

                else:
                    current_chunk = [sentence]
                    current_tokens = sentence_tokens

            else:
                current_chunk.append(sentence)

                current_tokens += sentence_tokens

        # Add final chunk
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks


    @staticmethod
    def _split_into_sentences(text: str) -> List[str]:
        """
        Simple sentence splitter with better edge case handling
        """
        if not text.strip():
            return []

        # Split on sentence boundaries
        sentences = re.split(r'(?<=[.!?])\s+', text)

        # Filter and clean
        final_sentences = list()

        for sentence in sentences:
            sentence = sentence.strip()

            if sentence:
                # Handle abbreviations (basic)
                if not any(sentence.endswith(abbr) for abbr in ['Dr.', 'Mr.', 'Mrs.', 'Ms.', 'etc.']):
                    final_sentences.append(sentence)

                else:
                    # For abbreviations, keep with next sentence if possible
                    if final_sentences:
                        final_sentences[-1] += " " + sentence

                    else:
                        final_sentences.append(sentence)

        return final_sentences


    def get_token_stats(self, text: str) -> dict:
        """
        Get comprehensive token statistics

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { dict } : Dictionary with statistics
        """
        token_count = self.count_tokens(text = text)
        char_count = len(text)
        word_count = len(text.split())

        stats = {"tokens" : token_count,
                 "characters" : char_count,
                 "words" : word_count,
                 "chars_per_token" : char_count / token_count if (token_count > 0) else 0,
                 "tokens_per_word" : token_count / word_count if (word_count > 0) else 0,
                 "tokenizer" : self.tokenizer_type,
                 }

        return stats


    def estimate_cost(self, text: str, cost_per_1k_tokens: float = 0.002) -> float:
        """
        Estimate API cost for text.

        Arguments:
        ----------
        text { str } : Input text

        cost_per_1k_tokens { float } : Cost per 1000 tokens (default: GPT-4 input)

        Returns:
        --------
        { float } : Estimated cost in dollars
        """
        tokens = self.count_tokens(text = text)
        cost = (tokens / 1000) * cost_per_1k_tokens

        return round(cost, 6)


    def batch_count_tokens(self, texts: List[str]) -> List[int]:
        """
        Count tokens for multiple texts efficiently

        Arguments:
        ----------
        texts { list } : List of texts

        Returns:
        --------
        { list } : List of token counts
        """
        token_counts = [self.count_tokens(text = text) for text in texts]

        return token_counts


    def find_token_boundaries(self, text: str, target_tokens: int) -> tuple[int, str]:
        """
        Find character position that gives approximately target tokens

        Arguments:
        ----------
        text { str } : Input text

        target_tokens { int } : Target number of tokens

        Returns:
        --------
        { tuple } : Tuple of (character_position, text_up_to_position)
        """
        if self.tokenizer is not None:
            tokens = self.encode(text)

            if (len(tokens) <= target_tokens):
                return len(text), text

            target_tokens_subset = tokens[:target_tokens]
            result_text = self.decode(target_tokens_subset)

            return len(result_text), result_text

        else:
            # Approximate
            total_tokens = self.count_tokens(text = text)

            if (total_tokens <= target_tokens):
                return len(text), text

            ratio = target_tokens / total_tokens
            char_pos = int(len(text) * ratio)

            return char_pos, text[:char_pos]


# Global counter instance
_counter = None


def get_token_counter(tokenizer_type: str = "cl100k_base") -> TokenCounter:
    """
    Get global token counter instance

    Arguments:
    ----------
    tokenizer_type { str } : Tokenizer type

    Returns:
    --------
    { TokenCounter } : TokenCounter instance
    """
    global _counter

    if _counter is None or _counter.tokenizer_type != tokenizer_type:
        _counter = TokenCounter(tokenizer_type)

    return _counter


# Convenience functions
def count_tokens(text: str, tokenizer_type: str = "cl100k_base") -> int:
    """
    Quick token count

    Arguments:
    ----------
    text { str } : Input text

    tokenizer_type { str } : Tokenizer type

    Returns:
    --------
    { int } : Token count
    """
    counter = get_token_counter(tokenizer_type)

    return counter.count_tokens(text)


def truncate_to_tokens(text: str, max_tokens: int, suffix: str = "...", tokenizer_type: str = "cl100k_base") -> str:
    """
    Truncate text to max tokens

    Arguments:
    ----------
    text { str } : Input text

    max_tokens { int } : Maximum tokens

    suffix { str } : Suffix to add

    tokenizer_type { str } : Tokenizer type

    Returns:
    ---------
    { str } : Truncated text
    """
    counter = get_token_counter(tokenizer_type)

    return counter.truncate_to_tokens(text, max_tokens, suffix)
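A small sketch of how the counter and its module-level helpers might be used (illustrative only; the sample strings are arbitrary):

# Illustrative usage sketch (not part of the committed file)
from chunking.token_counter import TokenCounter, count_tokens, truncate_to_tokens

# The module-level helpers share a cached TokenCounter via get_token_counter()
n = count_tokens("The quick brown fox jumps over the lazy dog.")
short = truncate_to_tokens("A very long passage that should be cut down.", max_tokens = 5)

# Direct use of the class for statistics and chunk splitting
counter = TokenCounter()
stats = counter.get_token_stats("Some sample text for statistics.")
pieces = counter.split_into_token_chunks("One sentence. Another sentence. A third one.", chunk_size = 10, overlap = 2)
print(n, short, stats["tokens"], len(pieces))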
config/__init__.py
ADDED
File without changes
config/logging_config.py
ADDED
@@ -0,0 +1,301 @@
# DEPENDENCIES
import sys
import logging
from pathlib import Path
from typing import Optional
from datetime import datetime
from logging.handlers import RotatingFileHandler


class ColoredFormatter(logging.Formatter):
    """
    Custom formatter with color support for console output
    """
    # ANSI color codes
    COLORS = {'DEBUG'    : '\033[36m',  # Cyan
              'INFO'     : '\033[32m',  # Green
              'WARNING'  : '\033[33m',  # Yellow
              'ERROR'    : '\033[31m',  # Red
              'CRITICAL' : '\033[35m',  # Magenta
              'RESET'    : '\033[0m',   # Reset
              }


    def format(self, record: logging.LogRecord) -> str:
        """
        Format log record with color
        """
        # Add color to level name
        levelname = record.levelname

        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{levelname}{self.COLORS['RESET']}"

        # Add color to logger name
        record.name = f"\033[34m{record.name}\033[0m"  # Blue

        return super().format(record)


class StructuredFormatter(logging.Formatter):
    """
    Structured JSON-like formatter for file logging
    """
    def format(self, record: logging.LogRecord) -> str:
        """
        Format log record as structured data
        """
        log_data = {"timestamp" : datetime.fromtimestamp(record.created).isoformat(),
                    "level"     : record.levelname,
                    "logger"    : record.name,
                    "message"   : record.getMessage(),
                    "module"    : record.module,
                    "function"  : record.funcName,
                    "line"      : record.lineno,
                    }

        # Add exception info if present
        if record.exc_info:
            log_data["exception"] = self.formatException(record.exc_info)

        # Add extra fields
        if hasattr(record, "extra"):
            log_data.update(record.extra)

        # Format as key=value pairs (easier to read than JSON)
        parts = [f"{k}={v}" for k, v in log_data.items()]

        return " | ".join(parts)


def setup_logging(log_level: str = "INFO", log_dir: Optional[Path] = None, log_format: Optional[str] = None, enable_console: bool = True,
                  enable_file: bool = True, max_bytes: int = 10 * 1024 * 1024, backup_count: int = 5) -> logging.Logger:
    """
    Setup comprehensive logging configuration

    Arguments:
    ----------
    log_level { str }       : Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)

    log_dir { Path }        : Directory for log files

    log_format { str }      : Custom log format string

    enable_console { bool } : Enable console output

    enable_file { bool }    : Enable file output

    max_bytes { int }       : Max file size before rotation

    backup_count { int }    : Number of backup files to keep

    Returns:
    --------
    { logging.Logger } : Configured root logger
    """
    # Get root logger
    logger = logging.getLogger()
    logger.setLevel(getattr(logging, log_level.upper()))

    # Clear existing handlers
    logger.handlers.clear()

    # Default format
    if log_format is None:
        log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Console handler with colors
    if enable_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_formatter = ColoredFormatter(log_format, datefmt = "%Y-%m-%d %H:%M:%S")
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

    # File handler with rotation
    if enable_file and log_dir:
        log_dir = Path(log_dir)
        log_dir.mkdir(parents = True, exist_ok = True)

        # Main log file
        log_file = log_dir / f"app_{datetime.now().strftime('%Y%m%d')}.log"
        file_handler = RotatingFileHandler(log_file, maxBytes = max_bytes, backupCount = backup_count, encoding = "utf-8")
        file_handler.setLevel(logging.DEBUG)

        # Use structured formatter for files
        file_formatter = StructuredFormatter()
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

        # Separate error log
        error_file = log_dir / f"error_{datetime.now().strftime('%Y%m%d')}.log"
        error_handler = RotatingFileHandler(error_file, maxBytes = max_bytes, backupCount = backup_count, encoding = "utf-8")

        error_handler.setLevel(logging.ERROR)
        error_handler.setFormatter(file_formatter)
        logger.addHandler(error_handler)

    # Suppress noisy third-party loggers
    logging.getLogger("urllib3").setLevel(logging.WARNING)
    logging.getLogger("requests").setLevel(logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)
    logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
    logging.getLogger("transformers").setLevel(logging.WARNING)
    logging.getLogger("torch").setLevel(logging.WARNING)
    logging.getLogger("playwright").setLevel(logging.WARNING)
    logging.getLogger("faiss").setLevel(logging.WARNING)
    logging.getLogger("llama_index").setLevel(logging.WARNING)
    logging.getLogger("langchain").setLevel(logging.WARNING)

    logger.info(f"Logging configured: level={log_level}, console={enable_console}, file={enable_file}")

    return logger


def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance for a specific module

    Arguments:
    ----------
    name { str } : Logger name (typically __name__)

    Returns:
    --------
    { logging.Logger } : Logger instance
    """
    return logging.getLogger(name)


class LoggerAdapter(logging.LoggerAdapter):
    """
    Custom logger adapter that adds contextual information: Useful for tracking request IDs, user IDs, etc.
    """
    def process(self, msg: str, kwargs: dict) -> tuple[str, dict]:
        """
        Add extra context to log messages
        """
        extra = self.extra.copy()

        # Add to structured logging
        if 'extra' not in kwargs:
            kwargs['extra'] = {}

        kwargs['extra'].update(extra)

        # Add to message
        context_parts = [f"{k}={v}" for k, v in extra.items()]

        if context_parts:
            msg = f"[{', '.join(context_parts)}] {msg}"

        return msg, kwargs


def get_context_logger(name: str, **context) -> LoggerAdapter:
    """
    Get a logger with contextual information

    Arguments:
    ----------
    name { str } : Logger name

    **context    : Context key-value pairs

    Returns:
    --------
    { LoggerAdapter } : Logger adapter with context
    """
    base_logger = get_logger(name)

    return LoggerAdapter(base_logger, context)


# Performance logging utilities
class TimedLogger:
    """
    Context manager for timing operations and logging
    """
    def __init__(self, logger: logging.Logger, operation: str, level: int = logging.INFO, log_start: bool = False):
        self.logger = logger
        self.operation = operation
        self.level = level
        self.log_start = log_start
        self.start_time = None


    def __enter__(self):
        if self.log_start:
            self.logger.log(self.level, f"{self.operation} started")

        self.start_time = datetime.now()

        return self


    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = (datetime.now() - self.start_time).total_seconds()

        if exc_type is None:
            self.logger.log(self.level, f"{self.operation} completed in {duration:.2f}s")

        else:
            self.logger.error(f"{self.operation} failed after {duration:.2f}s: {exc_val}")


# Logging decorators
def log_execution(logger: Optional[logging.Logger] = None, level: int = logging.INFO):
    """
    Decorator to log function execution time
    """
    def decorator(func):
        nonlocal logger

        if logger is None:
            logger = get_logger(func.__module__)

        def wrapper(*args, **kwargs):
            func_name = f"{func.__module__}.{func.__name__}"

            with TimedLogger(logger, func_name, level=level):
                return func(*args, **kwargs)

        return wrapper
    return decorator


def log_exceptions(logger: Optional[logging.Logger] = None, reraise: bool = True):
    """
    Decorator to log exceptions with full traceback
    """
    def decorator(func):
        nonlocal logger

        if logger is None:
            logger = get_logger(func.__module__)

        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)

            except Exception as e:
                func_name = f"{func.__module__}.{func.__name__}"

                logger.exception(f"Exception in {func_name}: {str(e)}")

                if reraise:
                    raise

                return None

        return wrapper
    return decorator


# Initialize logging on module import (with defaults) : This will be overridden by app.py with actual settings
_default_logger = setup_logging(log_level = "DEBUG",
                                enable_console = True,
                                enable_file = True,
                                )
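A hedged usage sketch for this logging module (illustrative only, not part of the commit); it assumes the file is importable as config.logging_config and that writing to a local logs/ directory is acceptable:

# Hypothetical usage; import path assumed
from pathlib import Path
from config.logging_config import setup_logging, get_context_logger, TimedLogger

setup_logging(log_level = "INFO", log_dir = Path("logs"))        # console + rotating file handlers

logger = get_context_logger(__name__, request_id = "req-123")    # context prepended to every message

with TimedLogger(logger, "document ingestion", log_start = True):
    logger.info("parsing upload")    # emitted as "[request_id=req-123] parsing upload"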
config/models.py
ADDED
@@ -0,0 +1,697 @@
# DEPENDENCIES
import re
from enum import Enum
from typing import Any
from typing import List
from typing import Dict
from pathlib import Path
from typing import Literal
from pydantic import Field
from typing import Optional
from datetime import datetime
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_validator


# Enums
class DocumentType(str, Enum):
    """
    Supported document types
    """
    PDF = "pdf"
    DOCX = "docx"
    TXT = "txt"
    URL = "url"
    IMAGE = "image"
    ARCHIVE = "archive"


class IngestionInputType(str, Enum):
    """
    Supported input types for ingestion
    """
    FILE = "file"
    URL = "url"
    ARCHIVE = "archive"
    TEXT = "text"


class ProcessingStatus(str, Enum):
    """
    Document processing status
    """
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class TokenizerType(str, Enum):
    """
    Supported tokenizer types
    """
    CL100K = "cl100k_base"        # GPT-4, GPT-3.5-turbo
    P50K = "p50k_base"            # Codex, text-davinci-002/003
    R50K = "r50k_base"            # GPT-3, text-davinci-001
    GPT2 = "gpt2"                 # GPT-2
    APPROXIMATE = "approximate"   # Fast approximation


class ChunkingStrategy(str, Enum):
    """
    Available chunking strategies
    """
    FIXED = "fixed"
    SEMANTIC = "semantic"
    HIERARCHICAL = "hierarchical"


class LLMProvider(str, Enum):
    """
    Supported LLM providers
    """
    OLLAMA = "ollama"
    OPENAI = "openai"


class TemperatureStrategy(str, Enum):
    """
    Temperature control strategies
    """
    FIXED = "fixed"
    ADAPTIVE = "adaptive"
    CONFIDENCE = "confidence"
    PROGRESSIVE = "progressive"


class CitationStyle(str, Enum):
    """
    Supported citation styles
    """
    NUMERIC = "numeric"
    VERBOSE = "verbose"
    MINIMAL = "minimal"
    ACADEMIC = "academic"
    LEGAL = "legal"


class PromptType(str, Enum):
    """
    Supported prompt types
    """
    QA = "qa"
    SUMMARY = "summary"
    ANALYTICAL = "analytical"
    COMPARISON = "comparison"
    EXTRACTION = "extraction"
    CREATIVE = "creative"
    CONVERSATIONAL = "conversational"


# Document Models
class DocumentMetadata(BaseModel):
    """
    Metadata extracted from documents
    """
    model_config = ConfigDict(arbitrary_types_allowed = True)

    document_id : str = Field(..., description = "Unique document identifier")
    filename : str = Field(..., description = "Original filename")
    file_path : Optional[Path] = Field(None, description = "Path to uploaded file")
    document_type : DocumentType = Field(..., description = "Type of document")

    # Content metadata
    title : Optional[str] = Field(None, description = "Document title")
    author : Optional[str] = Field(None, description = "Document author")
    created_date : Optional[datetime] = Field(None, description = "Document creation date")
    modified_date : Optional[datetime] = Field(None, description = "Last modification date")

    # Processing metadata
    upload_date : datetime = Field(default_factory = datetime.now)
    processed_date : Optional[datetime] = Field(None)
    status : ProcessingStatus = Field(default = ProcessingStatus.PENDING)

    # Size metrics
    file_size_bytes : int = Field(..., gt = 0, description = "File size in bytes")
    num_pages : Optional[int] = Field(None, ge = 1, description = "Number of pages (PDFs)")
    num_tokens : Optional[int] = Field(None, ge = 0, description = "Total tokens")
    num_chunks : Optional[int] = Field(None, ge = 0, description = "Number of chunks")

    # Processing info
    chunking_strategy : Optional[ChunkingStrategy] = Field(None)
    processing_time_seconds : Optional[float] = Field(None, ge = 0.0)
    error_message : Optional[str] = Field(None)

    # Additional metadata
    extra : Dict[str, Any] = Field(default_factory = dict)


    @field_validator("file_size_bytes")
    @classmethod
    def validate_file_size(cls, v: int) -> int:
        """
        Ensure file size is reasonable
        """
        max_size = 2 * 1024 * 1024 * 1024  # 2GB

        if (v > max_size):
            raise ValueError(f"File size {v} exceeds maximum {max_size}")

        return v

    @property
    def file_size_mb(self) -> float:
        """
        File size in megabytes
        """
        return self.file_size_bytes / (1024 * 1024)


class DocumentChunk(BaseModel):
    """
    A single chunk of text from a document
    """
    chunk_id : str = Field(..., description = "Unique chunk identifier")
    document_id : str = Field(..., description = "Parent document ID")

    # Content
    text : str = Field(..., min_length = 1, description = "Chunk text content")
    embedding : Optional[List[float]] = Field(None, description = "Vector embedding")

    # Position metadata
    chunk_index : int = Field(..., ge = 0, description = "Chunk position in document")
    start_char : int = Field(..., ge = 0, description = "Start character position")
    end_char : int = Field(..., ge = 0, description = "End character position")

    # Page/section info
    page_number : Optional[int] = Field(None, ge = 1, description = "Page number (if applicable)")
    section_title : Optional[str] = Field(None, description = "Section heading")

    # Hierarchical info (for hierarchical chunking)
    parent_chunk_id : Optional[str] = Field(None)
    child_chunk_ids : List[str] = Field(default_factory = list)

    # Token info
    token_count : int = Field(..., gt = 0, description = "Number of tokens")

    # Metadata
    metadata : Dict[str, Any] = Field(default_factory = dict)


    @property
    def char_count(self) -> int:
        """
        Number of characters in chunk
        """
        return self.end_char - self.start_char


class ChunkWithScore(BaseModel):
    """
    Chunk with retrieval score
    """
    chunk : DocumentChunk
    score : float = Field(..., description = "Relevance score (can be any real number)")
    rank : int = Field(..., ge = 1, description = "Rank in results")
    retrieval_method : str = Field('vector', description = "Retrieval method used")


    @property
    def citation(self) -> str:
        parts = [self.chunk.document_id]

        # Add source filename if available
        if ((hasattr(self.chunk, 'metadata')) and ('filename' in self.chunk.metadata)):
            parts.append(f"file: {self.chunk.metadata['filename']}")

        if self.chunk.page_number:
            parts.append(f"page {self.chunk.page_number}")

        if self.chunk.section_title:
            parts.append(f"section: {self.chunk.section_title}")

        return ", ".join(parts)


# Embedding Request
# NOTE: this class is redefined under "Embedding Models" below; the later, fuller definition is the one in effect at import time
class EmbeddingRequest(BaseModel):
    texts : List[str]
    normalize : bool = True
    device : Optional[str] = None
    batch_size : Optional[int] = None


# Query Models
class QueryRequest(BaseModel):
    """
    User query request
    """
    model_config = ConfigDict(protected_namespaces = ())

    query : str = Field(..., min_length = 1, max_length = 1000, description = "User question")

    # Retrieval parameters
    top_k : Optional[int] = Field(5, ge = 1, le = 20, description = "Number of chunks to retrieve")
    enable_reranking : Optional[bool] = Field(False)

    # Generation parameters
    temperature : Optional[float] = Field(0.1, ge = 0.0, le = 1.0)
    top_p : Optional[float] = Field(0.9, ge = 0.0, le = 1.0)
    max_tokens : Optional[int] = Field(1000, ge = 50, le = 4000)

    # Filters
    document_ids : Optional[List[str]] = Field(None, description = "Filter by specific documents")
    date_from : Optional[datetime] = Field(None)
    date_to : Optional[datetime] = Field(None)

    # Response preferences
    include_sources : bool = Field(True, description = "Include source citations")
    include_metrics : bool = Field(False, description = "Include quality metrics")
    stream : bool = Field(False, description = "Stream response tokens")


class QueryResponse(BaseModel):
    """
    Response to user query
    """
    query : str = Field(..., description = "Original query")
    answer : str = Field(..., description = "Generated answer")

    # Retrieved context
    sources : List[ChunkWithScore] = Field(default_factory = list)

    # Metrics
    retrieval_time_ms : float = Field(..., ge = 0.0)
    generation_time_ms : float = Field(..., ge = 0.0)
    total_time_ms : float = Field(..., ge = 0.0)

    tokens_used : Optional[Dict[str, int]] = Field(None)  # {input: X, output: Y}

    # Quality metrics (if enabled)
    metrics : Optional[Dict[str, float]] = Field(None)

    # Metadata
    timestamp : datetime = Field(default_factory = datetime.now)
    model_used : str = Field(...)

    model_config = ConfigDict(protected_namespaces = ())

    @property
    def citation_text(self) -> str:
        """
        Format citations as text
        """
        if not self.sources:
            return ""

        citations = list()

        for i, source in enumerate(self.sources, 1):
            citations.append(f"[{i}] {source.citation}")

        return "\n".join(citations)


# Upload Models
class UploadRequest(BaseModel):
    """
    File upload request metadata
    """
    filename : str = Field(..., min_length = 1)
    file_size_bytes : int = Field(..., gt = 0)
    content_type : Optional[str] = Field(None)


    @field_validator("filename")
    @classmethod
    def validate_filename(cls, v: str) -> str:
        """
        Ensure filename is safe
        """
        # Remove path traversal attempts
        v = Path(v).name

        if not v or v.startswith("."):
            raise ValueError("Invalid filename")

        return v


class UploadResponse(BaseModel):
    """
    File upload response
    """
    document_id : str = Field(..., description = "Generated document ID")
    filename : str = Field(...)
    status : ProcessingStatus = Field(...)
    message : str = Field(...)
    upload_date : datetime = Field(default_factory = datetime.now)


class ProcessingProgress(BaseModel):
    """
    Real-time processing progress
    """
    document_id : str = Field(...)
    status : ProcessingStatus = Field(...)

    # Progress tracking
    progress_percentage : float = Field(0.0, ge = 0.0, le = 100.0)
    current_step : str = Field(..., description = "Current processing step")

    # Stats
    chunks_processed : int = Field(0, ge = 0)
    total_chunks : Optional[int] = Field(None)

    # Timing
    start_time : datetime = Field(...)
    elapsed_seconds : float = Field(0.0, ge = 0.0)
    estimated_remaining_seconds : Optional[float] = Field(None)

    # Messages
    log_messages : List[str] = Field(default_factory = list)
    error_message : Optional[str] = Field(None)


# Embedding Models
class EmbeddingRequest(BaseModel):
    """
    Request to generate embeddings
    """
    texts : List[str] = Field(..., min_length = 1, max_length = 1000)
    batch_size : Optional[int] = Field(32, ge = 1, le = 128)
    normalize : bool = Field(True, description = "Normalize embeddings to unit length")


class EmbeddingResponse(BaseModel):
    """
    Embedding generation response
    """
    embeddings : List[List[float]] = Field(...)
    dimension : int = Field(..., gt = 0)
    num_embeddings : int = Field(..., gt = 0)
    processing_time_ms : float = Field(..., ge = 0.0)


# Retrieval Models
class RetrievalRequest(BaseModel):
    """
    Request for document retrieval
    """
    query : str = Field(..., min_length = 1)
    top_k : int = Field(10, ge = 1, le = 100)

    # Retrieval method
    use_vector : bool = Field(True)
    use_bm25 : bool = Field(True)
    vector_weight : Optional[float] = Field(0.6, ge = 0.0, le = 1.0)

    # Filters
    document_ids : Optional[List[str]] = Field(None)
    min_score : Optional[float] = Field(None, ge = 0.0, le = 1.0)


class RetrievalResponse(BaseModel):
    """
    Document retrieval response
    """
    chunks : List[ChunkWithScore] = Field(...)
    retrieval_time_ms : float = Field(..., ge = 0.0)
    num_candidates : int = Field(..., ge = 0)


# System Models
class HealthCheck(BaseModel):
    """
    System health check response
    """
    status : Literal["healthy", "degraded", "unhealthy"] = Field(...)
    timestamp : datetime = Field(default_factory = datetime.now)

    # Component status
    ollama_available : bool = Field(...)
    vector_store_available : bool = Field(...)
    embedding_model_available : bool = Field(...)

    # Stats
    total_documents : int = Field(0, ge = 0)
    total_chunks : int = Field(0, ge = 0)

    # Version info
    version : str = Field(...)

    # Issues
    warnings : List[str] = Field(default_factory = list)
    errors : List[str] = Field(default_factory = list)


class SystemStats(BaseModel):
    """
    System statistics
    """
    # Document stats
    total_documents : int = Field(0, ge = 0)
    documents_by_type : Dict[str, int] = Field(default_factory = dict)
    total_file_size_mb : float = Field(0.0, ge = 0.0)

    # Chunk stats
    total_chunks : int = Field(0, ge = 0)
    avg_chunk_size : float = Field(0.0, ge = 0.0)

    # Query stats
    total_queries : int = Field(0, ge = 0)
    avg_query_time_ms : float = Field(0.0, ge = 0.0)
    avg_retrieval_score : float = Field(0.0, ge = 0.0)

    # Timestamp
    generated_at : datetime = Field(default_factory = datetime.now)


class ErrorResponse(BaseModel):
    """
    Standard error response
    """
    error : str = Field(..., description = "Error type")
    message : str = Field(..., description = "Human-readable error message")
    detail : Optional[Dict[str, Any]] = Field(None, description = "Additional error details")
    timestamp : datetime = Field(default_factory = datetime.now)
    request_id : Optional[str] = Field(None)


# Configuration Models
class ChunkingConfig(BaseModel):
    """
    Chunking configuration
    """
    strategy : ChunkingStrategy = Field(...)
    chunk_size : int = Field(..., gt = 0)
    overlap : int = Field(..., ge = 0)

    # Strategy-specific params
    semantic_threshold : Optional[float] = Field(None, ge = 0.0, le = 1.0)
    parent_size : Optional[int] = Field(None, gt = 0)
    child_size : Optional[int] = Field(None, gt = 0)


class RetrievalConfig(BaseModel):
    """
    Retrieval configuration
    """
    top_k : int = Field(10, ge = 1, le = 100)
    vector_weight : float = Field(0.6, ge = 0.0, le = 1.0)
    bm25_weight : float = Field(0.4, ge = 0.0, le = 1.0)
    enable_reranking : bool = Field(False)
    faiss_nprobe : int = Field(10, ge = 1, le = 100)


    @field_validator("bm25_weight")
    @classmethod
    def validate_weights(cls, v: float, info) -> float:
        """
        Ensure weights sum to 1.0
        """
        if ("vector_weight" in info.data):
            vector_weight = info.data["vector_weight"]

            if (abs(vector_weight + v - 1.0) > 0.01):
                raise ValueError("vector_weight + bm25_weight must equal 1.0")

        return v


# Chat Request
class ChatRequest(BaseModel):
    message : str
    session_id : Optional[str] = None


# Validation Utilities
def validate_document_id(document_id: str) -> bool:
    """
    Validate document ID format
    """
    # Format: doc_<timestamp>_<hash>
    pattern = r'^doc_\d{10,}_[a-f0-9]{8}$'

    return bool(re.match(pattern, document_id))


def validate_chunk_id(chunk_id: str) -> bool:
    """
    Validate chunk ID format
    """
    # Format: chunk_<doc_id>_<index>
    pattern = r'^chunk_doc_\d+_[a-f0-9]{8}_\d+$'

    return bool(re.match(pattern, chunk_id))


# RAGAS Evaluation Models
class RAGASEvaluationResult(BaseModel):
    """
    Single RAGAS evaluation result
    """
    model_config = ConfigDict(arbitrary_types_allowed = True)

    # Input data
    query : str = Field(..., description = "User query")
    answer : str = Field(..., description = "Generated answer")
    contexts : List[str] = Field(..., description = "Retrieved context chunks")
    ground_truth : Optional[str] = Field(None, description = "Reference answer (if available)")
    timestamp : str = Field(..., description = "Evaluation timestamp")

    # RAGAS metrics (without ground truth)
    answer_relevancy : float = Field(..., ge = 0.0, le = 1.0, description = "How well answer addresses question")
    faithfulness : float = Field(..., ge = 0.0, le = 1.0, description = "Is answer grounded in context")
    context_precision : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Are relevant chunks ranked high (requires ground truth)")
    context_utilization : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Context utilization score (without ground truth)")
    context_relevancy : float = Field(..., ge = 0.0, le = 1.0, description = "How relevant are retrieved chunks")

    # RAGAS metrics (requiring ground truth)
    context_recall : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Coverage of ground truth")
    answer_similarity : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Similarity to reference")
    answer_correctness : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Correctness vs reference")

    # Performance metadata
    retrieval_time_ms : int = Field(..., ge = 0, description = "Retrieval time in milliseconds")
    generation_time_ms : int = Field(..., ge = 0, description = "Generation time in milliseconds")
    total_time_ms : int = Field(..., ge = 0, description = "Total time in milliseconds")
    chunks_retrieved : int = Field(..., ge = 0, description = "Number of chunks retrieved")
    query_type : str = Field("rag", description = "Type of query: 'rag' or 'general'")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary
        """
        return self.model_dump()


    @property
    def has_ground_truth_metrics(self) -> bool:
        """
        Check if ground truth metrics are available
        """
        return any([self.context_recall is not None,
                    self.answer_similarity is not None,
                    self.answer_correctness is not None
                    ])

    @property
    def overall_score(self) -> float:
        """
        Calculate weighted overall score
        """
        scores = list()
        weights = list()

        # Always include these metrics
        scores.append(self.answer_relevancy)
        weights.append(0.4)

        scores.append(self.faithfulness)
        weights.append(0.3)

        scores.append(self.context_relevancy)
        weights.append(0.1)

        # Include context_precision OR context_utilization (but not both)
        if self.context_precision is not None:
            scores.append(self.context_precision)
            weights.append(0.2)

        elif self.context_utilization is not None:
            scores.append(self.context_utilization)
            weights.append(0.2)

        else:
            # If neither is available, adjust weights
            weights = [w * 1.25 for w in weights]  # Scale existing weights

        # Calculate weighted average
        if (sum(weights) > 0):
            weighted_sum = sum(s * w for s, w in zip(scores, weights))
            score = weighted_sum / sum(weights)

        else:
            score = sum(scores) / len(scores) if scores else 0.0

        return round(score, 3)


class RAGASStatistics(BaseModel):
    """
    Aggregate RAGAS statistics for a session
    """
    total_evaluations : int = Field(..., ge = 0, description = "Total number of evaluations")

    # Average metrics
    avg_answer_relevancy : float = Field(..., ge = 0.0, le = 1.0)
    avg_faithfulness : float = Field(..., ge = 0.0, le = 1.0)
    avg_context_precision : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Average context precision (requires ground truth)")
    avg_context_utilization : Optional[float] = Field(None, ge = 0.0, le = 1.0, description = "Average context utilization (without ground truth)")
    avg_context_relevancy : float = Field(..., ge = 0.0, le = 1.0)
    avg_overall_score : float = Field(..., ge = 0.0, le = 1.0)

    # Ground truth metrics (if available)
    avg_context_recall : Optional[float] = Field(None, ge = 0.0, le = 1.0)
    avg_answer_similarity : Optional[float] = Field(None, ge = 0.0, le = 1.0)
    avg_answer_correctness : Optional[float] = Field(None, ge = 0.0, le = 1.0)

    # Performance metrics
    avg_retrieval_time_ms : float = Field(..., ge = 0.0)
    avg_generation_time_ms : float = Field(..., ge = 0.0)
    avg_total_time_ms : float = Field(..., ge = 0.0)

    # Quality indicators
    min_score : float = Field(..., ge = 0.0, le = 1.0, description = "Lowest overall score")
    max_score : float = Field(..., ge = 0.0, le = 1.0, description = "Highest overall score")
    std_dev : float = Field(..., ge = 0.0, description = "Standard deviation of scores")

    # Session info
    session_start : datetime = Field(..., description = "When evaluation session started")
    last_updated : datetime = Field(..., description = "Last evaluation timestamp")


class RAGASExportData(BaseModel):
    """
    Complete RAGAS evaluation export data
    """
    export_timestamp : datetime = Field(default_factory = datetime.now)
    total_evaluations : int = Field(..., ge = 0)
    statistics : RAGASStatistics
    evaluations : List[RAGASEvaluationResult]

    # Configuration info
    ground_truth_enabled : bool = Field(...)
    ragas_version : str = Field(default = "0.1.9")

    @property
    def export_filename(self) -> str:
        """
        Generate export filename
        """
        timestamp = self.export_timestamp.strftime("%Y%m%d_%H%M%S")

        return f"ragas_evaluation_{timestamp}.json"
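An illustrative sketch of how these Pydantic models compose (not part of the commit); the import path and every field value below are invented for the example:

# Hypothetical usage; IDs and scores are made up
from config.models import DocumentChunk, ChunkWithScore, RetrievalConfig

chunk = DocumentChunk(chunk_id = "chunk_doc_1700000000_abcdef01_0",
                      document_id = "doc_1700000000_abcdef01",
                      text = "Example chunk text.",
                      chunk_index = 0, start_char = 0, end_char = 19,
                      page_number = 3, token_count = 5,
                      metadata = {"filename": "report.pdf"})

scored = ChunkWithScore(chunk = chunk, score = 0.87, rank = 1)
print(scored.citation)    # "doc_1700000000_abcdef01, file: report.pdf, page 3"

RetrievalConfig(vector_weight = 0.7, bm25_weight = 0.3)    # passes: weights sum to 1.0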
config/settings.py
ADDED
@@ -0,0 +1,247 @@
# DEPENDENCIES
import os
import time
import torch
from pathlib import Path
from pydantic import Field
from typing import Literal
from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """
    Application configuration with environment variable support

    Environment variables take precedence over defaults
    """
    # Hugging Face Space deployment mode detection
    IS_HF_SPACE : bool = Field(default = os.getenv("SPACE_ID") is not None, description = "Running in HF Space")

    # Application Settings
    APP_NAME : str = "AI Universal Knowledge Ingestion System"
    APP_VERSION : str = "1.0.0"
    DEBUG : bool = Field(default = False, description = "Enable debug mode")
    HOST : str = Field(default = "0.0.0.0", description = "API host")
    PORT : int = Field(default = int(os.getenv("PORT", 8000)), description = "API port (7860 for HF Spaces)")

    # LLM Provider Selection
    OLLAMA_ENABLED : bool = Field(default = os.getenv("OLLAMA_ENABLED", "true").lower() == "true", description = "Enable Ollama (set false for HF Spaces)")
    USE_OPENAI : bool = Field(default = os.getenv("USE_OPENAI", "false").lower() == "true", description = "Use OpenAI API instead of local LLM")


    # File Upload Settings
    MAX_FILE_SIZE_MB : int = Field(default = 100, description = "Max file size in MB")
    MAX_BATCH_FILES : int = Field(default = 10, description = "Max files per upload")
    ALLOWED_EXTENSIONS : list[str] = Field(default = ["pdf", "docx", "txt"], description = "Allowed file extensions")
    UPLOAD_DIR : Path = Field(default = Path("data/uploads"), description = "Directory for uploaded files")

    # Ollama LLM Settings
    OLLAMA_BASE_URL : str = Field(default = "http://localhost:11434", description = "Ollama API endpoint")
    OLLAMA_MODEL : str = Field(default = "mistral:7b", description = "Ollama model name")
    OLLAMA_TIMEOUT : int = Field(default = 120, description = "Ollama request timeout (seconds)")

    # Generation parameters
    DEFAULT_TEMPERATURE : float = Field(default = 0.1, ge = 0.0, le = 1.0, description = "LLM temperature (0=deterministic, 1=creative)")
    TOP_P : float = Field(default = 0.9, ge = 0.0, le = 1.0, description = "Nucleus sampling threshold")
    MAX_TOKENS : int = Field(default = 1000, description = "Max output tokens")
    CONTEXT_WINDOW : int = Field(default = 8192, description = "Model context window size")

    # OpenAI Settings
    OPENAI_API_KEY : Optional[str] = Field(default = os.getenv("OPENAI_API_KEY"), description = "OpenAI API secret key")
    OPENAI_MODEL : str = Field(default = "gpt-3.5-turbo", description = "OpenAI model name")

    # Embedding Settings
    EMBEDDING_MODEL : str = Field(default = "BAAI/bge-small-en-v1.5", description = "HuggingFace embedding model")
    EMBEDDING_DIMENSION : int = Field(default = 384, description = "Embedding vector dimension")
    EMBEDDING_DEVICE : Literal["cpu", "cuda", "mps"] = Field(default = "cpu", description = "Device for embedding generation")
    EMBEDDING_BATCH_SIZE : int = Field(default = 32, description = "Batch size for embedding generation")

    # Chunking Settings
    # Fixed chunking
    FIXED_CHUNK_SIZE : int = Field(default = 512, description = "Fixed chunk size in tokens")
    FIXED_CHUNK_OVERLAP : int = Field(default = 25, description = "Overlap between chunks")

    # Semantic chunking
    SEMANTIC_BREAKPOINT_THRESHOLD : float = Field(default = 0.80, description = "Percentile for semantic breakpoints")

    # Hierarchical chunking
    PARENT_CHUNK_SIZE : int = Field(default = 2048, description = "Parent chunk size")
    CHILD_CHUNK_SIZE : int = Field(default = 512, description = "Child chunk size")

    # Adaptive thresholds
    SMALL_DOC_THRESHOLD : int = Field(default = 1000, description = "Token threshold for fixed chunking")
    LARGE_DOC_THRESHOLD : int = Field(default = 500000, description = "Token threshold for hierarchical chunking")

    # Retrieval Settings
    # Vector search
    TOP_K_RETRIEVE : int = Field(default = 10, description = "Top chunks to retrieve")
    TOP_K_FINAL : int = Field(default = 5, description = "Final chunks after reranking")
    FAISS_NPROBE : int = Field(default = 10, description = "FAISS search probes")

    # Hybrid search weights
    VECTOR_WEIGHT : float = Field(default = 0.6, description = "Vector search weight")
    BM25_WEIGHT : float = Field(default = 0.4, description = "BM25 search weight")

    # BM25 parameters
    BM25_K1 : float = Field(default = 1.5, description = "BM25 term saturation")
    BM25_B : float = Field(default = 0.75, description = "BM25 length normalization")

    # Reranking
    ENABLE_RERANKING : bool = Field(default = True, description = "Enable cross-encoder reranking")
    RERANKER_MODEL : str = Field(default = "cross-encoder/ms-marco-MiniLM-L-6-v2", description = "Reranker model")

    # Storage Settings
    VECTOR_STORE_DIR : Path = Field(default = Path("data/vector_store"), description = "FAISS index storage")
    METADATA_DB_PATH : Path = Field(default = Path("data/metadata.db"), description = "SQLite metadata database")

    # Backup
    AUTO_BACKUP : bool = Field(default = True, description = "Enable auto-backup")
    BACKUP_INTERVAL : int = Field(default = 1000, description = "Backup every N documents")
    BACKUP_DIR : Path = Field(default = Path("data/backups"), description = "Backup directory")

    # Cache Settings
    ENABLE_CACHE : bool = Field(default = True, description = "Enable embedding cache")
    CACHE_TYPE : Literal["memory", "redis"] = Field(default = "memory", description = "Cache backend")
    CACHE_TTL : int = Field(default = 3600, description = "Cache TTL in seconds")
    CACHE_MAX_SIZE : int = Field(default = 1000, description = "Max cached items")

    # Redis (if used)
    REDIS_HOST : str = Field(default = "localhost", description = "Redis host")
    REDIS_PORT : int = Field(default = 6379, description = "Redis port")
    REDIS_DB : int = Field(default = 0, description = "Redis database number")

    # Logging Settings
    LOG_LEVEL : Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default = "INFO", description = "Logging level")
    LOG_DIR : Path = Field(default = Path("logs"), description = "Log file directory")
    LOG_FORMAT : str = Field(default = "%(asctime)s - %(name)s - %(levelname)s - %(message)s", description = "Log format string")
    LOG_ROTATION : str = Field(default = "500 MB", description = "Log rotation size")
    LOG_RETENTION : str = Field(default = "30 days", description = "Log retention period")

    # Evaluation Settings
    ENABLE_RAGAS : bool = Field(default = True, description = "Enable Ragas evaluation")
    RAGAS_ENABLE_GROUND_TRUTH : bool = Field(default = False, description = "Enable RAGAS metrics requiring ground truth")
    RAGAS_METRICS : list[str] = Field(default = ["answer_relevancy", "faithfulness", "context_utilization", "context_relevancy"], description = "Ragas metrics to compute (base metrics without ground truth)")
    RAGAS_GROUND_TRUTH_METRICS : list[str] = Field(default = ["context_precision", "context_recall", "answer_similarity", "answer_correctness"], description = "Ragas metrics requiring ground truth")
    RAGAS_EVALUATION_TIMEOUT : int = Field(default = 60, description = "RAGAS evaluation timeout in seconds")
    RAGAS_BATCH_SIZE : int = Field(default = 10, description = "Batch size for RAGAS evaluations")

    # Web Scraping Settings (for future)
    SCRAPING_ENABLED : bool = Field(default = False, description = "Enable web scraping")
    USER_AGENT : str = Field(default = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", description = "User agent for scraping")
    REQUEST_DELAY : float = Field(default = 2.0, description = "Delay between requests (seconds)")
    MAX_RETRIES : int = Field(default = 3, description = "Max scraping retries")

    # Performance Settings
    MAX_WORKERS : int = Field(default = 4, description = "Max parallel workers")
    ASYNC_BATCH_SIZE : int = Field(default = 10, description = "Async batch size")

    # Security Settings
    ENABLE_AUTH : bool = Field(default = False, description = "Enable authentication")
    SECRET_KEY : str = Field(default = os.getenv("SECRET_KEY", "dev-key-change-in-production"))

    FIXED_CHUNK_STRATEGY : str = Field(default = "fixed", description = "Default chunking strategy")


    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = True


    @field_validator("UPLOAD_DIR", "VECTOR_STORE_DIR", "LOG_DIR", "BACKUP_DIR", "METADATA_DB_PATH")
    @classmethod
    def create_directories(cls, v: Path) -> Path:
        """
        Ensure directories exist
        """
        if v.suffix:  # It's a file path (like metadata.db)
            v.parent.mkdir(parents = True, exist_ok = True)

        else:  # It's a directory
            v.mkdir(parents = True, exist_ok = True)

        return v


    @field_validator("VECTOR_WEIGHT", "BM25_WEIGHT")
    @classmethod
    def validate_weights_sum(cls, v: float, info) -> float:
        """
        Ensure vector and BM25 weights are valid
        """
        if ((info.field_name == "BM25_WEIGHT") and ("VECTOR_WEIGHT" in info.data)):
            vector_weight = info.data["VECTOR_WEIGHT"]

            if (abs(vector_weight + v - 1.0) > 0.01):
                raise ValueError(f"VECTOR_WEIGHT ({vector_weight}) + BM25_WEIGHT ({v}) must sum to 1.0")

        return v


    @property
    def max_file_size_bytes(self) -> int:
        """
        Convert MB to bytes
        """
        return self.MAX_FILE_SIZE_MB * 1024 * 1024


    @property
    def is_cuda_available(self) -> bool:
        """
        Check if CUDA device is requested and available
        """
        if self.EMBEDDING_DEVICE == "cuda":
            try:
                return torch.cuda.is_available()

            except ImportError:
                return False

        return False


    def get_ollama_url(self, endpoint: str) -> str:
        """
        Construct full Ollama API URL
        """
        return f"{self.OLLAMA_BASE_URL.rstrip('/')}/{endpoint.lstrip('/')}"


    @classmethod
    def get_timestamp_ms(cls) -> int:
        """
        Get current timestamp in milliseconds
        """
        return int(time.time() * 1000)


    def summary(self) -> dict:
        """
        Get configuration summary (excluding sensitive data)
        """
        return {"app_name"            : self.APP_NAME,
                "version"             : self.APP_VERSION,
                "ollama_model"        : self.OLLAMA_MODEL,
                "embedding_model"     : self.EMBEDDING_MODEL,
                "embedding_device"    : self.EMBEDDING_DEVICE,
                "max_file_size_mb"    : self.MAX_FILE_SIZE_MB,
                "allowed_extensions"  : self.ALLOWED_EXTENSIONS,
                "chunking_strategy"   : {"small_threshold" : self.SMALL_DOC_THRESHOLD, "large_threshold" : self.LARGE_DOC_THRESHOLD},
                "retrieval"           : {"top_k" : self.TOP_K_RETRIEVE, "hybrid_weights" : {"vector" : self.VECTOR_WEIGHT, "bm25" : self.BM25_WEIGHT}},
                "evaluation"          : {"ragas_enabled" : self.ENABLE_RAGAS, "ragas_ground_truth" : self.RAGAS_ENABLE_GROUND_TRUTH, "ragas_metrics" : self.RAGAS_METRICS},
                }


# Global settings instance
settings = Settings()


# Convenience function for getting settings
def get_settings() -> Settings:
    """
    Get global settings instance
    """
    return settings
docs/API.md
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AI Universal Knowledge Ingestion System - API Documentation
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
The AI Universal Knowledge Ingestion System is a production-grade RAG (Retrieval-Augmented Generation) platform that enables organizations to unlock knowledge from multiple document sources while maintaining complete data privacy and eliminating API costs.
|
| 5 |
+
|
| 6 |
+
**Base URL:** http://localhost:8000 (or your deployed domain)
|
| 7 |
+
|
| 8 |
+
**API Version:** v1.0.0
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## Authentication
|
| 13 |
+
Currently, the API operates without authentication for local development. For production deployments, consider implementing:
|
| 14 |
+
|
| 15 |
+
- API Key Authentication
|
| 16 |
+
- JWT Tokens
|
| 17 |
+
- OAuth2
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## Rate Limiting
|
| 22 |
+
- Default: 100 requests per minute per IP
|
| 23 |
+
- File Uploads: 10MB max per file, 50MB total per request
|
| 24 |
+
- Chat Endpoints: 30 requests per minute per session
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## Response Format
|
| 29 |
+
|
| 30 |
+
All API responses follow this standard format:
|
| 31 |
+
|
| 32 |
+
```json
|
| 33 |
+
{
|
| 34 |
+
"success": true,
|
| 35 |
+
"data": {...},
|
| 36 |
+
"message": "Operation completed successfully",
|
| 37 |
+
"timestamp": "2024-01-15T10:30:00Z"
|
| 38 |
+
}
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
Error responses:
|
| 42 |
+
|
| 43 |
+
```json
|
| 44 |
+
{
|
| 45 |
+
"success": false,
|
| 46 |
+
"error": "Error Type",
|
| 47 |
+
"message": "Human-readable error message",
|
| 48 |
+
"detail": {...},
|
| 49 |
+
"timestamp": "2024-01-15T10:30:00Z"
|
| 50 |
+
}
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## System Management Endpoints
|
| 56 |
+
|
| 57 |
+
### Get System Health
|
| 58 |
+
|
| 59 |
+
**GET** `/api/health`
|
| 60 |
+
|
| 61 |
+
Check system health and component status.
|
| 62 |
+
|
| 63 |
+
**Response:**
|
| 64 |
+
```json
|
| 65 |
+
{
|
| 66 |
+
"status": "healthy",
|
| 67 |
+
"timestamp": "2024-01-15T10:30:00Z",
|
| 68 |
+
"version": "1.0.0",
|
| 69 |
+
"components": {
|
| 70 |
+
"vector_store": true,
|
| 71 |
+
"llm": true,
|
| 72 |
+
"embeddings": true,
|
| 73 |
+
"retrieval": true,
|
| 74 |
+
"generation": true
|
| 75 |
+
},
|
| 76 |
+
"details": {
|
| 77 |
+
"overall": "healthy",
|
| 78 |
+
"vector_store": true,
|
| 79 |
+
"llm": true,
|
| 80 |
+
"embeddings": true,
|
| 81 |
+
"retrieval": true,
|
| 82 |
+
"generation": true
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
```
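
Clients can gate their logic on this endpoint before uploading or querying; a minimal sketch with `requests` (assuming the local base URL above):

```python
import requests

def is_system_healthy(base_url: str = "http://localhost:8000") -> bool:
    """Return True only if every core component reports healthy."""
    resp = requests.get(f"{base_url}/api/health", timeout=10)
    resp.raise_for_status()
    payload = resp.json()
    return payload["status"] == "healthy" and all(payload["components"].values())
```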
|
| 86 |
+
|
| 87 |
+
### Get System Information
|
| 88 |
+
|
| 89 |
+
**GET** `/api/system-info`
|
| 90 |
+
|
| 91 |
+
Get comprehensive system status and statistics.
|
| 92 |
+
|
| 93 |
+
**Response:**
|
| 94 |
+
```json
|
| 95 |
+
{
|
| 96 |
+
"system_state": {
|
| 97 |
+
"is_ready": true,
|
| 98 |
+
"processing_status": "ready",
|
| 99 |
+
"total_documents": 15,
|
| 100 |
+
"active_sessions": 3
|
| 101 |
+
},
|
| 102 |
+
"configuration": {
|
| 103 |
+
"inference_model": "mistral:7b",
|
| 104 |
+
"embedding_model": "BAAI/bge-small-en-v1.5",
|
| 105 |
+
"retrieval_top_k": 10,
|
| 106 |
+
"vector_weight": 0.6,
|
| 107 |
+
"bm25_weight": 0.4,
|
| 108 |
+
"temperature": 0.1,
|
| 109 |
+
"enable_reranking": true
|
| 110 |
+
},
|
| 111 |
+
"llm_provider": {
|
| 112 |
+
"provider": "ollama",
|
| 113 |
+
"model": "mistral:7b",
|
| 114 |
+
"status": "healthy"
|
| 115 |
+
},
|
| 116 |
+
"system_information": {
|
| 117 |
+
"vector_store_status": "Ready (145 chunks)",
|
| 118 |
+
"current_model": "mistral:7b",
|
| 119 |
+
"embedding_model": "BAAI/bge-small-en-v1.5",
|
| 120 |
+
"chunking_strategy": "adaptive",
|
| 121 |
+
"system_uptime_seconds": 3600
|
| 122 |
+
},
|
| 123 |
+
"timestamp": "2024-01-15T10:30:00Z"
|
| 124 |
+
}
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## Document Management Endpoints
|
| 130 |
+
|
| 131 |
+
### Upload Files
|
| 132 |
+
|
| 133 |
+
**POST** `/api/upload`
|
| 134 |
+
|
| 135 |
+
Upload multiple documents for processing.
|
| 136 |
+
|
| 137 |
+
**Form Data:**
|
| 138 |
+
- `files`: List of files (PDF, DOCX, TXT, ZIP) - max 2GB total
|
| 139 |
+
|
| 140 |
+
**Supported Formats:**
|
| 141 |
+
- PDF Documents (.pdf)
|
| 142 |
+
- Microsoft Word (.docx, .doc)
|
| 143 |
+
- Text Files (.txt, .md)
|
| 144 |
+
- ZIP Archives (.zip) - automatic extraction
|
| 145 |
+
|
| 146 |
+
**Response:**
|
| 147 |
+
```json
|
| 148 |
+
{
|
| 149 |
+
"success": true,
|
| 150 |
+
"message": "Successfully uploaded 3 files",
|
| 151 |
+
"files": [
|
| 152 |
+
{
|
| 153 |
+
"filename": "document_20240115_103000.pdf",
|
| 154 |
+
"original_name": "quarterly_report.pdf",
|
| 155 |
+
"size": 1542890,
|
| 156 |
+
"upload_time": "2024-01-15T10:30:00Z",
|
| 157 |
+
"file_path": "/uploads/document_20240115_103000.pdf",
|
| 158 |
+
"status": "uploaded"
|
| 159 |
+
}
|
| 160 |
+
]
|
| 161 |
+
}
|
| 162 |
+
```
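
A minimal multipart upload sketch with `requests` (the file names are placeholders):

```python
import requests

paths = ["quarterly_report.pdf", "contract.docx"]  # placeholder file names
files = [("files", (p, open(p, "rb"))) for p in paths]
try:
    resp = requests.post("http://localhost:8000/api/upload", files=files, timeout=300)
    print(resp.json()["message"])
finally:
    for _, (_, handle) in files:
        handle.close()
```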
|
| 163 |
+
|
| 164 |
+
### Start Processing
|
| 165 |
+
|
| 166 |
+
**POST** `/api/start-processing`
|
| 167 |
+
|
| 168 |
+
Start processing uploaded documents through the RAG pipeline.
|
| 169 |
+
|
| 170 |
+
**Pipeline Stages:**
|
| 171 |
+
1. Document parsing and text extraction
|
| 172 |
+
2. Adaptive chunking (fixed/semantic/hierarchical)
|
| 173 |
+
3. Embedding generation with BGE model
|
| 174 |
+
4. Vector indexing (FAISS + BM25)
|
| 175 |
+
5. Knowledge base compilation
|
| 176 |
+
|
| 177 |
+
**Response:**
|
| 178 |
+
```json
|
| 179 |
+
{
|
| 180 |
+
"success": true,
|
| 181 |
+
"message": "Processing completed successfully",
|
| 182 |
+
"status": "ready",
|
| 183 |
+
"documents_processed": 3,
|
| 184 |
+
"total_chunks": 245,
|
| 185 |
+
"chunking_statistics": {
|
| 186 |
+
"adaptive": 120,
|
| 187 |
+
"semantic": 80,
|
| 188 |
+
"hierarchical": 45
|
| 189 |
+
},
|
| 190 |
+
"index_stats": {
|
| 191 |
+
"total_chunks_indexed": 245,
|
| 192 |
+
"vector_index_size": 245,
|
| 193 |
+
"bm25_indexed": true,
|
| 194 |
+
"metadata_stored": true
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
### Get Processing Status
|
| 200 |
+
|
| 201 |
+
**GET** `/api/processing-status`
|
| 202 |
+
|
| 203 |
+
Monitor real-time processing progress.
|
| 204 |
+
|
| 205 |
+
**Response:**
|
| 206 |
+
```json
|
| 207 |
+
{
|
| 208 |
+
"status": "processing",
|
| 209 |
+
"progress": 65,
|
| 210 |
+
"current_step": "Generating embeddings for quarterly_report.pdf...",
|
| 211 |
+
"processed": 2,
|
| 212 |
+
"total": 3,
|
| 213 |
+
"details": {
|
| 214 |
+
"chunks_processed": 156,
|
| 215 |
+
"embeddings_generated": 156
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
```
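
Because processing runs asynchronously, clients typically poll this endpoint until it reports completion; a minimal polling sketch (the terminal status values checked here are assumptions based on the responses above):

```python
import time
import requests

def wait_for_processing(base_url: str = "http://localhost:8000",
                        interval: float = 2.0, timeout: float = 600.0) -> dict:
    """Poll /api/processing-status until processing finishes or the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = requests.get(f"{base_url}/api/processing-status", timeout=10).json()
        print(f"{status['progress']}% - {status['current_step']}")
        if status["status"] != "processing":   # assumed terminal states: "ready", "error"
            return status
        time.sleep(interval)
    raise TimeoutError("Document processing did not finish in time")
```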
|
| 219 |
+
|
| 220 |
+
---
|
| 221 |
+
|
| 222 |
+
## Chat & Query Endpoints
|
| 223 |
+
|
| 224 |
+
### Chat with Documents
|
| 225 |
+
|
| 226 |
+
**POST** `/api/chat`
|
| 227 |
+
|
| 228 |
+
Query your knowledge base with natural language questions. Includes automatic RAGAS evaluation if enabled.
|
| 229 |
+
|
| 230 |
+
**Request Body (JSON):**
|
| 231 |
+
```json
|
| 232 |
+
{
|
| 233 |
+
"message": "What were the Q3 revenue trends?",
|
| 234 |
+
"session_id": "session_1705314600"
|
| 235 |
+
}
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
**Response:**
|
| 239 |
+
```json
|
| 240 |
+
{
|
| 241 |
+
"session_id": "session_1705314600",
|
| 242 |
+
"response": "Based on the Q3 financial report, revenue increased by 15% quarter-over-quarter, reaching $45 million. The growth was primarily driven by enterprise sales and new market expansion. [1][2]",
|
| 243 |
+
"sources": [
|
| 244 |
+
{
|
| 245 |
+
"rank": 1,
|
| 246 |
+
"score": 0.894,
|
| 247 |
+
"document_id": "doc_1705300000_abc123",
|
| 248 |
+
"chunk_id": "chunk_doc_1705300000_abc123_0",
|
| 249 |
+
"text_preview": "Q3 Financial Highlights: Revenue growth of 15% QoQ reaching $45M...",
|
| 250 |
+
"page_number": 7,
|
| 251 |
+
"section_title": "Financial Performance",
|
| 252 |
+
"retrieval_method": "hybrid"
|
| 253 |
+
}
|
| 254 |
+
],
|
| 255 |
+
"metrics": {
|
| 256 |
+
"retrieval_time": 245,
|
| 257 |
+
"generation_time": 3100,
|
| 258 |
+
"total_time": 3345,
|
| 259 |
+
"chunks_retrieved": 8,
|
| 260 |
+
"chunks_used": 3,
|
| 261 |
+
"tokens_used": 487
|
| 262 |
+
},
|
| 263 |
+
"ragas_metrics": {
|
| 264 |
+
"answer_relevancy": 0.89,
|
| 265 |
+
"faithfulness": 0.94,
|
| 266 |
+
"context_utilization": 0.87,
|
| 267 |
+
"context_relevancy": 0.91,
|
| 268 |
+
"overall_score": 0.90,
|
| 269 |
+
"context_precision": null,
|
| 270 |
+
"context_recall": null,
|
| 271 |
+
"answer_similarity": null,
|
| 272 |
+
"answer_correctness": null
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
**Note:** Ground truth metrics (context_precision, context_recall, answer_similarity, answer_correctness) are null unless ground truth is provided and `RAGAS_ENABLE_GROUND_TRUTH=True`.
|
| 278 |
+
|
| 279 |
+
### Export Chat History
|
| 280 |
+
|
| 281 |
+
**GET** `/api/export-chat/{session_id}`
|
| 282 |
+
|
| 283 |
+
Export conversation history for analysis or reporting.
|
| 284 |
+
|
| 285 |
+
**Parameters:**
|
| 286 |
+
- `session_id`: string (required) - Session identifier
|
| 287 |
+
- `format`: string (optional) - Export format: `json` (default) or `csv`
|
| 288 |
+
|
| 289 |
+
**Response (JSON):**
|
| 290 |
+
```json
|
| 291 |
+
{
|
| 292 |
+
"session_id": "session_1705314600",
|
| 293 |
+
"export_time": "2024-01-15T11:00:00Z",
|
| 294 |
+
"total_messages": 5,
|
| 295 |
+
"history": [
|
| 296 |
+
{
|
| 297 |
+
"query": "What was the Q3 revenue growth?",
|
| 298 |
+
"response": "Revenue increased by 15% quarter-over-quarter...",
|
| 299 |
+
"sources": [...],
|
| 300 |
+
"timestamp": "2024-01-15T10:30:00Z",
|
| 301 |
+
"metrics": {
|
| 302 |
+
"total_time": 3345
|
| 303 |
+
},
|
| 304 |
+
"ragas_metrics": {
|
| 305 |
+
"answer_relevancy": 0.89,
|
| 306 |
+
"faithfulness": 0.94,
|
| 307 |
+
"overall_score": 0.90
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
]
|
| 311 |
+
}
|
| 312 |
+
```
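
A minimal download sketch for the CSV variant (the `format` query parameter comes from the parameter list above):

```python
import requests

session_id = "session_1705314600"  # example session from above
resp = requests.get(
    f"http://localhost:8000/api/export-chat/{session_id}",
    params={"format": "csv"},      # omit for the default JSON export
    timeout=30,
)
resp.raise_for_status()
with open(f"{session_id}.csv", "wb") as f:
    f.write(resp.content)
```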
|
| 313 |
+
|
| 314 |
+
---
|
| 315 |
+
|
| 316 |
+
## RAGAS Evaluation Endpoints
|
| 317 |
+
|
| 318 |
+
### Get RAGAS History
|
| 319 |
+
|
| 320 |
+
**GET** `/api/ragas/history`
|
| 321 |
+
|
| 322 |
+
Get complete RAGAS evaluation history for the current session.
|
| 323 |
+
|
| 324 |
+
**Response:**
|
| 325 |
+
```json
|
| 326 |
+
{
|
| 327 |
+
"success": true,
|
| 328 |
+
"total_count": 25,
|
| 329 |
+
"statistics": {
|
| 330 |
+
"total_evaluations": 25,
|
| 331 |
+
"avg_answer_relevancy": 0.876,
|
| 332 |
+
"avg_faithfulness": 0.912,
|
| 333 |
+
"avg_context_utilization": 0.845,
|
| 334 |
+
"avg_context_relevancy": 0.889,
|
| 335 |
+
"avg_overall_score": 0.881,
|
| 336 |
+
"avg_retrieval_time_ms": 235,
|
| 337 |
+
"avg_generation_time_ms": 3250,
|
| 338 |
+
"avg_total_time_ms": 3485,
|
| 339 |
+
"min_score": 0.723,
|
| 340 |
+
"max_score": 0.967,
|
| 341 |
+
"std_dev": 0.089,
|
| 342 |
+
"session_start": "2024-01-15T09:00:00Z",
|
| 343 |
+
"last_updated": "2024-01-15T11:00:00Z"
|
| 344 |
+
},
|
| 345 |
+
"history": [
|
| 346 |
+
{
|
| 347 |
+
"query": "What were the Q3 revenue trends?",
|
| 348 |
+
"answer": "Revenue increased by 15%...",
|
| 349 |
+
"contexts": ["Q3 Financial Highlights...", "Revenue breakdown..."],
|
| 350 |
+
"timestamp": "2024-01-15T10:30:00Z",
|
| 351 |
+
"answer_relevancy": 0.89,
|
| 352 |
+
"faithfulness": 0.94,
|
| 353 |
+
"context_utilization": 0.87,
|
| 354 |
+
"context_relevancy": 0.91,
|
| 355 |
+
"overall_score": 0.90,
|
| 356 |
+
"retrieval_time_ms": 245,
|
| 357 |
+
"generation_time_ms": 3100,
|
| 358 |
+
"total_time_ms": 3345,
|
| 359 |
+
"chunks_retrieved": 8
|
| 360 |
+
}
|
| 361 |
+
]
|
| 362 |
+
}
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
### Get RAGAS Statistics
|
| 366 |
+
|
| 367 |
+
**GET** `/api/ragas/statistics`
|
| 368 |
+
|
| 369 |
+
Get aggregate RAGAS statistics for the current session.
|
| 370 |
+
|
| 371 |
+
**Response:**
|
| 372 |
+
```json
|
| 373 |
+
{
|
| 374 |
+
"success": true,
|
| 375 |
+
"statistics": {
|
| 376 |
+
"total_evaluations": 25,
|
| 377 |
+
"avg_answer_relevancy": 0.876,
|
| 378 |
+
"avg_faithfulness": 0.912,
|
| 379 |
+
"avg_context_utilization": 0.845,
|
| 380 |
+
"avg_context_relevancy": 0.889,
|
| 381 |
+
"avg_overall_score": 0.881,
|
| 382 |
+
"avg_retrieval_time_ms": 235,
|
| 383 |
+
"avg_generation_time_ms": 3250,
|
| 384 |
+
"avg_total_time_ms": 3485,
|
| 385 |
+
"min_score": 0.723,
|
| 386 |
+
"max_score": 0.967,
|
| 387 |
+
"std_dev": 0.089,
|
| 388 |
+
"session_start": "2024-01-15T09:00:00Z",
|
| 389 |
+
"last_updated": "2024-01-15T11:00:00Z"
|
| 390 |
+
}
|
| 391 |
+
}
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
### Clear RAGAS History
|
| 395 |
+
|
| 396 |
+
**POST** `/api/ragas/clear`
|
| 397 |
+
|
| 398 |
+
Clear all RAGAS evaluation history and start a new session.
|
| 399 |
+
|
| 400 |
+
**Response:**
|
| 401 |
+
```json
|
| 402 |
+
{
|
| 403 |
+
"success": true,
|
| 404 |
+
"message": "RAGAS evaluation history cleared, new session started"
|
| 405 |
+
}
|
| 406 |
+
```
|
| 407 |
+
|
| 408 |
+
### Export RAGAS Data
|
| 409 |
+
|
| 410 |
+
**GET** `/api/ragas/export`
|
| 411 |
+
|
| 412 |
+
Export all RAGAS evaluation data as JSON.
|
| 413 |
+
|
| 414 |
+
**Response:** JSON file download containing:
|
| 415 |
+
```json
|
| 416 |
+
{
|
| 417 |
+
"export_timestamp": "2024-01-15T11:00:00Z",
|
| 418 |
+
"total_evaluations": 25,
|
| 419 |
+
"statistics": {...},
|
| 420 |
+
"evaluations": [...],
|
| 421 |
+
"ground_truth_enabled": false
|
| 422 |
+
}
|
| 423 |
+
```
|
| 424 |
+
|
| 425 |
+
### Get RAGAS Configuration
|
| 426 |
+
|
| 427 |
+
**GET** `/api/ragas/config`
|
| 428 |
+
|
| 429 |
+
Get current RAGAS configuration settings.
|
| 430 |
+
|
| 431 |
+
**Response:**
|
| 432 |
+
```json
|
| 433 |
+
{
|
| 434 |
+
"enabled": true,
|
| 435 |
+
"ground_truth_enabled": false,
|
| 436 |
+
"base_metrics": [
|
| 437 |
+
"answer_relevancy",
|
| 438 |
+
"faithfulness",
|
| 439 |
+
"context_utilization",
|
| 440 |
+
"context_relevancy"
|
| 441 |
+
],
|
| 442 |
+
"ground_truth_metrics": [
|
| 443 |
+
"context_precision",
|
| 444 |
+
"context_recall",
|
| 445 |
+
"answer_similarity",
|
| 446 |
+
"answer_correctness"
|
| 447 |
+
],
|
| 448 |
+
"evaluation_timeout": 60,
|
| 449 |
+
"batch_size": 10
|
| 450 |
+
}
|
| 451 |
+
```
|
| 452 |
+
|
| 453 |
+
---
|
| 454 |
+
|
| 455 |
+
## Analytics Endpoints
|
| 456 |
+
|
| 457 |
+
### Get System Analytics
|
| 458 |
+
|
| 459 |
+
**GET** `/api/analytics`
|
| 460 |
+
|
| 461 |
+
Get comprehensive system analytics and performance metrics with caching.
|
| 462 |
+
|
| 463 |
+
**Response:**
|
| 464 |
+
```json
|
| 465 |
+
{
|
| 466 |
+
"performance_metrics": {
|
| 467 |
+
"avg_response_time": 3485,
|
| 468 |
+
"min_response_time": 2100,
|
| 469 |
+
"max_response_time": 8900,
|
| 470 |
+
"total_queries": 127,
|
| 471 |
+
"queries_last_hour": 23,
|
| 472 |
+
"p95_response_time": 7200
|
| 473 |
+
},
|
| 474 |
+
"quality_metrics": {
|
| 475 |
+
"answer_relevancy": 0.876,
|
| 476 |
+
"faithfulness": 0.912,
|
| 477 |
+
"context_precision": 0.845,
|
| 478 |
+
"context_recall": null,
|
| 479 |
+
"overall_score": 0.878,
|
| 480 |
+
"avg_sources_per_query": 4.2,
|
| 481 |
+
"queries_with_sources": 125,
|
| 482 |
+
"confidence": "high",
|
| 483 |
+
"metrics_available": true
|
| 484 |
+
},
|
| 485 |
+
"system_information": {
|
| 486 |
+
"vector_store_status": "Ready (245 chunks)",
|
| 487 |
+
"current_model": "mistral:7b",
|
| 488 |
+
"embedding_model": "BAAI/bge-small-en-v1.5",
|
| 489 |
+
"chunking_strategy": "adaptive",
|
| 490 |
+
"system_uptime_seconds": 7200,
|
| 491 |
+
"last_updated": "2024-01-15T11:00:00Z"
|
| 492 |
+
},
|
| 493 |
+
"health_status": {
|
| 494 |
+
"overall": "healthy",
|
| 495 |
+
"llm": true,
|
| 496 |
+
"vector_store": true,
|
| 497 |
+
"embeddings": true,
|
| 498 |
+
"retrieval": true,
|
| 499 |
+
"generation": true
|
| 500 |
+
},
|
| 501 |
+
"chunking_statistics": {
|
| 502 |
+
"primary_strategy": "semantic",
|
| 503 |
+
"total_chunks": 245,
|
| 504 |
+
"strategies_used": {
|
| 505 |
+
"fixed": 98,
|
| 506 |
+
"semantic": 112,
|
| 507 |
+
"hierarchical": 35
|
| 508 |
+
}
|
| 509 |
+
},
|
| 510 |
+
"document_statistics": {
|
| 511 |
+
"total_documents": 15,
|
| 512 |
+
"total_chunks": 245,
|
| 513 |
+
"uploaded_files": 15,
|
| 514 |
+
"total_file_size_bytes": 52428800,
|
| 515 |
+
"total_file_size_mb": 50.0,
|
| 516 |
+
"avg_chunks_per_document": 16.3
|
| 517 |
+
},
|
| 518 |
+
"session_statistics": {
|
| 519 |
+
"total_sessions": 8,
|
| 520 |
+
"total_messages": 127,
|
| 521 |
+
"avg_messages_per_session": 15.9
|
| 522 |
+
},
|
| 523 |
+
"index_statistics": {
|
| 524 |
+
"total_chunks_indexed": 245,
|
| 525 |
+
"vector_index_size": 245,
|
| 526 |
+
"bm25_indexed": true
|
| 527 |
+
},
|
| 528 |
+
"calculated_at": "2024-01-15T11:00:00Z",
|
| 529 |
+
"cache_info": {
|
| 530 |
+
"from_cache": false,
|
| 531 |
+
"next_refresh_in": 30
|
| 532 |
+
}
|
| 533 |
+
}
|
| 534 |
+
```
|
| 535 |
+
|
| 536 |
+
### Refresh Analytics Cache
|
| 537 |
+
|
| 538 |
+
**GET** `/api/analytics/refresh`
|
| 539 |
+
|
| 540 |
+
Force refresh analytics cache and get fresh data.
|
| 541 |
+
|
| 542 |
+
**Response:**
|
| 543 |
+
```json
|
| 544 |
+
{
|
| 545 |
+
"success": true,
|
| 546 |
+
"message": "Analytics cache refreshed successfully",
|
| 547 |
+
"data": {
|
| 548 |
+
// Same structure as /api/analytics
|
| 549 |
+
}
|
| 550 |
+
}
|
| 551 |
+
```
|
| 552 |
+
|
| 553 |
+
### Get Detailed Analytics
|
| 554 |
+
|
| 555 |
+
**GET** `/api/analytics/detailed`
|
| 556 |
+
|
| 557 |
+
Get detailed analytics including session breakdowns and component performance.
|
| 558 |
+
|
| 559 |
+
**Response:**
|
| 560 |
+
```json
|
| 561 |
+
{
|
| 562 |
+
// All fields from /api/analytics, plus:
|
| 563 |
+
"detailed_sessions": [
|
| 564 |
+
{
|
| 565 |
+
"session_id": "session_1705314600",
|
| 566 |
+
"message_count": 12,
|
| 567 |
+
"first_message": "2024-01-15T09:00:00Z",
|
| 568 |
+
"last_message": "2024-01-15T10:45:00Z",
|
| 569 |
+
"total_response_time": 38500,
|
| 570 |
+
"avg_sources_per_query": 3.8
|
| 571 |
+
}
|
| 572 |
+
],
|
| 573 |
+
"component_performance": {
|
| 574 |
+
"retrieval": {
|
| 575 |
+
"avg_time_ms": 245,
|
| 576 |
+
"cache_hit_rate": 0.23
|
| 577 |
+
},
|
| 578 |
+
"embeddings": {
|
| 579 |
+
"model": "BAAI/bge-small-en-v1.5",
|
| 580 |
+
"dimension": 384,
|
| 581 |
+
"device": "cpu"
|
| 582 |
+
}
|
| 583 |
+
}
|
| 584 |
+
}
|
| 585 |
+
```
|
| 586 |
+
|
| 587 |
+
---
|
| 588 |
+
|
| 589 |
+
## Configuration Endpoints
|
| 590 |
+
|
| 591 |
+
### Get Current Configuration
|
| 592 |
+
|
| 593 |
+
**GET** `/api/configuration`
|
| 594 |
+
|
| 595 |
+
Retrieve current system configuration.
|
| 596 |
+
|
| 597 |
+
**Response:**
|
| 598 |
+
```json
|
| 599 |
+
{
|
| 600 |
+
"configuration": {
|
| 601 |
+
"inference_model": "mistral:7b",
|
| 602 |
+
"embedding_model": "BAAI/bge-small-en-v1.5",
|
| 603 |
+
"vector_weight": 0.6,
|
| 604 |
+
"bm25_weight": 0.4,
|
| 605 |
+
"temperature": 0.1,
|
| 606 |
+
"max_tokens": 1000,
|
| 607 |
+
"chunk_size": 512,
|
| 608 |
+
"chunk_overlap": 50,
|
| 609 |
+
"top_k_retrieve": 10,
|
| 610 |
+
"enable_reranking": true,
|
| 611 |
+
"is_ready": true,
|
| 612 |
+
"llm_healthy": true
|
| 613 |
+
},
|
| 614 |
+
"health": {
|
| 615 |
+
"overall": "healthy",
|
| 616 |
+
"llm": true,
|
| 617 |
+
"vector_store": true,
|
| 618 |
+
"embeddings": true,
|
| 619 |
+
"retrieval": true,
|
| 620 |
+
"generation": true
|
| 621 |
+
}
|
| 622 |
+
}
|
| 623 |
+
```
|
| 624 |
+
|
| 625 |
+
### Update Configuration
|
| 626 |
+
|
| 627 |
+
**POST** `/api/configuration`
|
| 628 |
+
|
| 629 |
+
Update system configuration parameters.
|
| 630 |
+
|
| 631 |
+
**Form Data:**
|
| 632 |
+
- `temperature`: float (0.0-1.0) - Generation temperature
|
| 633 |
+
- `max_tokens`: integer (100-4000) - Maximum response tokens
|
| 634 |
+
- `retrieval_top_k`: integer (1-50) - Number of chunks to retrieve
|
| 635 |
+
- `vector_weight`: float (0.0-1.0) - Weight for vector search
|
| 636 |
+
- `bm25_weight`: float (0.0-1.0) - Weight for keyword search
|
| 637 |
+
- `enable_reranking`: boolean - Enable cross-encoder reranking
|
| 638 |
+
- `session_id`: string (optional) - Session identifier for overrides
|
| 639 |
+
|
| 640 |
+
**Response:**
|
| 641 |
+
```json
|
| 642 |
+
{
|
| 643 |
+
"success": true,
|
| 644 |
+
"message": "Configuration updated successfully",
|
| 645 |
+
"updates": {
|
| 646 |
+
"temperature": 0.2,
|
| 647 |
+
"retrieval_top_k": 15
|
| 648 |
+
}
|
| 649 |
+
}
|
| 650 |
+
```
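
Note that this endpoint expects form fields rather than a JSON body; a minimal sketch with `requests` (field names taken from the list above):

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/configuration",
    data={                       # sent as form data, not JSON
        "temperature": 0.2,
        "retrieval_top_k": 15,
        "vector_weight": 0.6,
        "bm25_weight": 0.4,
        "enable_reranking": True,
    },
    timeout=30,
)
print(resp.json())
```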
|
| 651 |
+
|
| 652 |
+
---
|
| 653 |
+
|
| 654 |
+
## Error Handling
|
| 655 |
+
|
| 656 |
+
### Common HTTP Status Codes
|
| 657 |
+
|
| 658 |
+
- **200** - Success
|
| 659 |
+
- **400** - Bad Request (invalid parameters)
|
| 660 |
+
- **404** - Resource Not Found
|
| 661 |
+
- **500** - Internal Server Error
|
| 662 |
+
- **503** - Service Unavailable (component not ready)
|
| 663 |
+
|
| 664 |
+
### Error Response Examples
|
| 665 |
+
|
| 666 |
+
#### RAGAS Evaluation Disabled:
|
| 667 |
+
```json
|
| 668 |
+
{
|
| 669 |
+
"success": false,
|
| 670 |
+
"error": "RAGASDisabled",
|
| 671 |
+
"message": "RAGAS evaluation is not enabled. Set ENABLE_RAGAS=True in settings.",
|
| 672 |
+
"detail": {
|
| 673 |
+
"current_setting": "ENABLE_RAGAS=False"
|
| 674 |
+
},
|
| 675 |
+
"timestamp": "2024-01-15T10:30:00Z"
|
| 676 |
+
}
|
| 677 |
+
```
|
| 678 |
+
|
| 679 |
+
#### System Not Ready:
|
| 680 |
+
```json
|
| 681 |
+
{
|
| 682 |
+
"success": false,
|
| 683 |
+
"error": "SystemNotReady",
|
| 684 |
+
"message": "System not ready. Please upload and process documents first.",
|
| 685 |
+
"detail": {
|
| 686 |
+
"is_ready": false,
|
| 687 |
+
"documents_processed": 0
|
| 688 |
+
},
|
| 689 |
+
"timestamp": "2024-01-15T10:30:00Z"
|
| 690 |
+
}
|
| 691 |
+
```
|
| 692 |
+
|
| 693 |
+
#### LLM Service Unavailable:
|
| 694 |
+
```json
|
| 695 |
+
{
|
| 696 |
+
"success": false,
|
| 697 |
+
"error": "LLMUnavailable",
|
| 698 |
+
"message": "LLM service unavailable. Please ensure Ollama is running.",
|
| 699 |
+
"detail": {
|
| 700 |
+
"llm_healthy": false,
|
| 701 |
+
"suggestion": "Run 'ollama serve' in a separate terminal"
|
| 702 |
+
},
|
| 703 |
+
"timestamp": "2024-01-15T10:30:00Z"
|
| 704 |
+
}
|
| 705 |
+
```
|
| 706 |
+
|
| 707 |
+
---
|
| 708 |
+
|
| 709 |
+
## Best Practices
|
| 710 |
+
|
| 711 |
+
### 1. File Upload
|
| 712 |
+
|
| 713 |
+
- Use chunked upload for large files (>100MB)
|
| 714 |
+
- Compress documents into ZIP archives for multiple files
|
| 715 |
+
- Ensure documents are text-extractable (not scanned images without OCR)
|
| 716 |
+
|
| 717 |
+
### 2. Query Optimization
|
| 718 |
+
|
| 719 |
+
- Be specific and contextual in questions
|
| 720 |
+
- Use natural language - no special syntax required
|
| 721 |
+
- Break complex questions into multiple simpler queries
|
| 722 |
+
|
| 723 |
+
### 3. Session Management
|
| 724 |
+
|
| 725 |
+
- Reuse `session_id` for conversation continuity
|
| 726 |
+
- Sessions automatically expire after 24 hours of inactivity
|
| 727 |
+
- Export important conversations for long-term storage
|
| 728 |
+
|
| 729 |
+
### 4. RAGAS Evaluation
|
| 730 |
+
|
| 731 |
+
- Ensure OpenAI API key is configured for RAGAS to work
|
| 732 |
+
- Monitor evaluation metrics to track system quality
|
| 733 |
+
- Use analytics endpoints to identify quality trends
|
| 734 |
+
- Export evaluation data regularly for offline analysis
|
| 735 |
+
|
| 736 |
+
### 5. Performance Monitoring
|
| 737 |
+
|
| 738 |
+
- Monitor response times and token usage
|
| 739 |
+
- Use analytics endpoint for system health checks
|
| 740 |
+
- Set up alerts for quality metric degradation
|
| 741 |
+
- Enable caching for frequently accessed embeddings
|
| 742 |
+
|
| 743 |
+
### 6. Configuration Management
|
| 744 |
+
|
| 745 |
+
- Test configuration changes with a few queries first
|
| 746 |
+
- Monitor RAGAS metrics after configuration updates
|
| 747 |
+
- Use session-based overrides for experimentation
|
| 748 |
+
- Document optimal configurations for different use cases
|
| 749 |
+
|
| 750 |
+
---
|
| 751 |
+
|
| 752 |
+
## SDK Examples
|
| 753 |
+
|
| 754 |
+
### Python Client
|
| 755 |
+
|
| 756 |
+
```python
|
| 757 |
+
import requests
|
| 758 |
+
|
| 759 |
+
class KnowledgeBaseClient:
|
| 760 |
+
def __init__(self, base_url="http://localhost:8000"):
|
| 761 |
+
self.base_url = base_url
|
| 762 |
+
self.session_id = None
|
| 763 |
+
|
| 764 |
+
def upload_documents(self, file_paths):
|
| 765 |
+
files = [('files', open(fpath, 'rb')) for fpath in file_paths]
|
| 766 |
+
response = requests.post(f"{self.base_url}/api/upload", files=files)
|
| 767 |
+
return response.json()
|
| 768 |
+
|
| 769 |
+
def start_processing(self):
|
| 770 |
+
response = requests.post(f"{self.base_url}/api/start-processing")
|
| 771 |
+
return response.json()
|
| 772 |
+
|
| 773 |
+
def query(self, question):
|
| 774 |
+
data = {'message': question}
|
| 775 |
+
if self.session_id:
|
| 776 |
+
data['session_id'] = self.session_id
|
| 777 |
+
response = requests.post(f"{self.base_url}/api/chat", json=data)
|
| 778 |
+
result = response.json()
|
| 779 |
+
if not self.session_id:
|
| 780 |
+
self.session_id = result.get('session_id')
|
| 781 |
+
return result
|
| 782 |
+
|
| 783 |
+
def get_ragas_history(self):
|
| 784 |
+
response = requests.get(f"{self.base_url}/api/ragas/history")
|
| 785 |
+
return response.json()
|
| 786 |
+
|
| 787 |
+
def get_analytics(self):
|
| 788 |
+
response = requests.get(f"{self.base_url}/api/analytics")
|
| 789 |
+
return response.json()
|
| 790 |
+
|
| 791 |
+
# Usage
|
| 792 |
+
client = KnowledgeBaseClient()
|
| 793 |
+
|
| 794 |
+
# Upload and process
|
| 795 |
+
client.upload_documents(['report.pdf', 'contract.docx'])
|
| 796 |
+
client.start_processing()
|
| 797 |
+
|
| 798 |
+
# Query
|
| 799 |
+
result = client.query("What are the key findings?")
|
| 800 |
+
print(result['response'])
|
| 801 |
+
print(f"Quality Score: {result['ragas_metrics']['overall_score']}")
|
| 802 |
+
|
| 803 |
+
# Get analytics
|
| 804 |
+
analytics = client.get_analytics()
|
| 805 |
+
print(f"Avg Response Time: {analytics['performance_metrics']['avg_response_time']}ms")
|
| 806 |
+
```
|
| 807 |
+
|
| 808 |
+
### JavaScript Client
|
| 809 |
+
|
| 810 |
+
```javascript
|
| 811 |
+
class KnowledgeBaseClient {
|
| 812 |
+
constructor(baseUrl = 'http://localhost:8000') {
|
| 813 |
+
this.baseUrl = baseUrl;
|
| 814 |
+
this.sessionId = null;
|
| 815 |
+
}
|
| 816 |
+
|
| 817 |
+
async uploadDocuments(files) {
|
| 818 |
+
const formData = new FormData();
|
| 819 |
+
files.forEach(file => formData.append('files', file));
|
| 820 |
+
|
| 821 |
+
const response = await fetch(`${this.baseUrl}/api/upload`, {
|
| 822 |
+
method: 'POST',
|
| 823 |
+
body: formData
|
| 824 |
+
});
|
| 825 |
+
return await response.json();
|
| 826 |
+
}
|
| 827 |
+
|
| 828 |
+
async startProcessing() {
|
| 829 |
+
const response = await fetch(`${this.baseUrl}/api/start-processing`, {
|
| 830 |
+
method: 'POST'
|
| 831 |
+
});
|
| 832 |
+
return await response.json();
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
async query(question) {
|
| 836 |
+
const body = { message: question };
|
| 837 |
+
if (this.sessionId) body.session_id = this.sessionId;
|
| 838 |
+
|
| 839 |
+
const response = await fetch(`${this.baseUrl}/api/chat`, {
|
| 840 |
+
method: 'POST',
|
| 841 |
+
headers: { 'Content-Type': 'application/json' },
|
| 842 |
+
body: JSON.stringify(body)
|
| 843 |
+
});
|
| 844 |
+
|
| 845 |
+
const result = await response.json();
|
| 846 |
+
if (!this.sessionId) this.sessionId = result.session_id;
|
| 847 |
+
return result;
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
async getRagasHistory() {
|
| 851 |
+
const response = await fetch(`${this.baseUrl}/api/ragas/history`);
|
| 852 |
+
return await response.json();
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
async getAnalytics() {
|
| 856 |
+
const response = await fetch(`${this.baseUrl}/api/analytics`);
|
| 857 |
+
return await response.json();
|
| 858 |
+
}
|
| 859 |
+
}
|
| 860 |
+
|
| 861 |
+
// Usage
|
| 862 |
+
const client = new KnowledgeBaseClient();
|
| 863 |
+
|
| 864 |
+
// Query
|
| 865 |
+
const result = await client.query("What are the revenue trends?");
|
| 866 |
+
console.log(result.response);
|
| 867 |
+
console.log(`Quality: ${result.ragas_metrics.overall_score}`);
|
| 868 |
+
|
| 869 |
+
// Get RAGAS history
|
| 870 |
+
const history = await client.getRagasHistory();
|
| 871 |
+
console.log(`Total evaluations: ${history.total_count}`);
|
| 872 |
+
console.log(`Avg relevancy: ${history.statistics.avg_answer_relevancy}`);
|
| 873 |
+
```
|
| 874 |
+
|
| 875 |
+
---
|
| 876 |
+
|
| 877 |
+
## Support & Troubleshooting
|
| 878 |
+
|
| 879 |
+
### For API issues:
|
| 880 |
+
|
| 881 |
+
- Check system health endpoint first
|
| 882 |
+
- Verify document processing status
|
| 883 |
+
- Review error messages and suggested actions
|
| 884 |
+
- Check component readiness flags
|
| 885 |
+
|
| 886 |
+
### For RAGAS issues:
|
| 887 |
+
|
| 888 |
+
- Ensure OpenAI API key is configured
|
| 889 |
+
- Check RAGAS is enabled in settings
|
| 890 |
+
- Monitor evaluation timeout settings
|
| 891 |
+
- Review logs for detailed error messages
|
| 892 |
+
|
| 893 |
+
### For quality issues:
|
| 894 |
+
|
| 895 |
+
- Monitor RAGAS evaluation metrics
|
| 896 |
+
- Adjust retrieval and generation parameters
|
| 897 |
+
- Review source citations for context relevance
|
| 898 |
+
- Consider document preprocessing improvements
|
| 899 |
+
|
| 900 |
+
---
|
| 901 |
+
|
| 902 |
+
> **This API provides a complete RAG solution with multi-format document ingestion, intelligent retrieval, local LLM generation, and comprehensive RAGAS-based quality evaluation.**
|
| 903 |
+
|
| 904 |
+
---
|
docs/ARCHITECTURE.md
ADDED
|
@@ -0,0 +1,882 @@
|
| 1 |
+
# AI Universal Knowledge Ingestion System - Technical Architecture Document
|
| 2 |
+
|
| 3 |
+
## 1. System Overview
|
| 4 |
+
|
| 5 |
+
### 1.1 High-Level Architecture
|
| 6 |
+
|
| 7 |
+
```mermaid
|
| 8 |
+
graph TB
|
| 9 |
+
subgraph "Frontend Layer"
|
| 10 |
+
A[Web UI<br/>HTML/CSS/JS]
|
| 11 |
+
B[File Upload<br/>Drag & Drop]
|
| 12 |
+
C[Chat Interface<br/>Real-time]
|
| 13 |
+
D[Analytics Dashboard<br/>RAGAS Metrics]
|
| 14 |
+
end
|
| 15 |
+
|
| 16 |
+
subgraph "API Gateway"
|
| 17 |
+
E[FastAPI Server<br/>Python 3.11+]
|
| 18 |
+
end
|
| 19 |
+
|
| 20 |
+
subgraph "Core Processing Engine"
|
| 21 |
+
F[Ingestion Module]
|
| 22 |
+
G[Processing Module]
|
| 23 |
+
H[Retrieval Module]
|
| 24 |
+
I[Generation Module]
|
| 25 |
+
J[Evaluation Module]
|
| 26 |
+
end
|
| 27 |
+
|
| 28 |
+
subgraph "AI/ML Layer"
|
| 29 |
+
K[Ollama LLM<br/>Mistral-7B]
|
| 30 |
+
L[Embedding Model<br/>BGE-small-en]
|
| 31 |
+
M[FAISS Vector DB]
|
| 32 |
+
end
|
| 33 |
+
|
| 34 |
+
subgraph "Quality Assurance"
|
| 35 |
+
N[RAGAS Evaluator<br/>Real-time Metrics]
|
| 36 |
+
end
|
| 37 |
+
|
| 38 |
+
A --> E
|
| 39 |
+
E --> F
|
| 40 |
+
F --> G
|
| 41 |
+
G --> H
|
| 42 |
+
H --> I
|
| 43 |
+
I --> K
|
| 44 |
+
G --> L
|
| 45 |
+
L --> M
|
| 46 |
+
H --> M
|
| 47 |
+
I --> N
|
| 48 |
+
N --> E
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### 1.2 System Characteristics
|
| 52 |
+
|
| 53 |
+
| Aspect | Specification |
|
| 54 |
+
|--------|---------------|
|
| 55 |
+
| **Architecture Style** | Modular Microservices-inspired |
|
| 56 |
+
| **Deployment** | Docker Containerized |
|
| 57 |
+
| **Processing Model** | Async/Event-driven |
|
| 58 |
+
| **Data Flow** | Pipeline-based with Checkpoints |
|
| 59 |
+
| **Scalability** | Horizontal (Stateless API) + Vertical (GPU) |
|
| 60 |
+
| **Caching** | In-Memory LRU Cache |
|
| 61 |
+
| **Evaluation** | Real-time RAGAS Metrics |
|
| 62 |
+
|
| 63 |
+
---
|
| 64 |
+
|
| 65 |
+
## 2. Component Architecture
|
| 66 |
+
|
| 67 |
+
### 2.1 Ingestion Module
|
| 68 |
+
|
| 69 |
+
```mermaid
|
| 70 |
+
flowchart TD
|
| 71 |
+
A[User Input] --> B{Input Type Detection}
|
| 72 |
+
|
| 73 |
+
B -->|PDF/DOCX| D[Document Parser]
|
| 74 |
+
B -->|ZIP| E[Archive Extractor]
|
| 75 |
+
|
| 76 |
+
subgraph D [Document Processing]
|
| 77 |
+
D1[PyPDF2<br/>PDF Text]
|
| 78 |
+
D2[python-docx<br/>Word Docs]
|
| 79 |
+
D3[EasyOCR<br/>Scanned PDFs]
|
| 80 |
+
end
|
| 81 |
+
|
| 82 |
+
subgraph E [Archive Handling]
|
| 83 |
+
E1[zipfile<br/>Extraction]
|
| 84 |
+
E2[Recursive Processing]
|
| 85 |
+
E3[Size Validation<br/>2GB Max]
|
| 86 |
+
end
|
| 87 |
+
|
| 88 |
+
D --> F[Text Cleaning]
|
| 89 |
+
E --> F
|
| 90 |
+
|
| 91 |
+
F --> G[Encoding Normalization]
|
| 92 |
+
G --> H[Structure Preservation]
|
| 93 |
+
H --> I[Output: Cleaned Text<br/>+ Metadata]
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
#### Ingestion Specifications:
|
| 97 |
+
|
| 98 |
+
| Component | Technology | Configuration | Limits |
|
| 99 |
+
|-----------|------------|---------------|---------|
|
| 100 |
+
| **PDF Parser** | PyPDF2 + EasyOCR | OCR: English+Multilingual | 1000 pages max |
|
| 101 |
+
| **Document Parser** | python-docx | Preserve formatting | 50MB per file |
|
| 102 |
+
| **Archive Handler** | zipfile | Recursion depth: 5 | 2GB total, 10k files |
|
| 103 |
+
|
| 104 |
+
### 2.2 Processing Module
|
| 105 |
+
|
| 106 |
+
#### 2.2.1 Adaptive Chunking Strategy
|
| 107 |
+
|
| 108 |
+
```mermaid
|
| 109 |
+
flowchart TD
|
| 110 |
+
A[Input Text] --> B[Token Count Analysis]
|
| 111 |
+
B --> C{Document Size}
|
| 112 |
+
|
| 113 |
+
C -->|<50K tokens| D[Fixed-Size Chunking]
|
| 114 |
+
C -->|50K-500K tokens| E[Semantic Chunking]
|
| 115 |
+
C -->|>500K tokens| F[Hierarchical Chunking]
|
| 116 |
+
|
| 117 |
+
subgraph D [Strategy 1: Fixed]
|
| 118 |
+
D1[Chunk Size: 512 tokens]
|
| 119 |
+
D2[Overlap: 50 tokens]
|
| 120 |
+
D3[Method: Simple sliding window]
|
| 121 |
+
end
|
| 122 |
+
|
| 123 |
+
subgraph E [Strategy 2: Semantic]
|
| 124 |
+
E1[Breakpoint: 95th percentile similarity]
|
| 125 |
+
E2[Method: LlamaIndex SemanticSplitter]
|
| 126 |
+
E3[Preserve: Section boundaries]
|
| 127 |
+
end
|
| 128 |
+
|
| 129 |
+
subgraph F [Strategy 3: Hierarchical]
|
| 130 |
+
F1[Parent: 2048 tokens]
|
| 131 |
+
F2[Child: 512 tokens]
|
| 132 |
+
F3[Retrieval: Child → Parent expansion]
|
| 133 |
+
end
|
| 134 |
+
|
| 135 |
+
D --> G[Chunk Metadata]
|
| 136 |
+
E --> G
|
| 137 |
+
F --> G
|
| 138 |
+
|
| 139 |
+
G --> H[Embedding Generation]
|
| 140 |
+
```
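
The fixed-size strategy above is essentially a sliding window over tokens; a minimal illustrative sketch (whitespace tokenization stands in for the real tokenizer, and the 512/50 values come from the diagram):

```python
def fixed_size_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> list[str]:
    """Split text into overlapping fixed-size chunks with a sliding window."""
    tokens = text.split()              # placeholder for a real tokenizer
    step = chunk_size - overlap
    chunks = []
    for start in range(0, max(len(tokens), 1), step):
        window = tokens[start:start + chunk_size]
        if not window:
            break
        chunks.append(" ".join(window))
        if start + chunk_size >= len(tokens):
            break                      # last window already reached the end
    return chunks
```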
|
| 141 |
+
|
| 142 |
+
#### 2.2.2 Embedding Pipeline
|
| 143 |
+
|
| 144 |
+
```python
|
| 145 |
+
# Embedding Configuration
|
| 146 |
+
EMBEDDING_CONFIG = {
|
| 147 |
+
"model": "BAAI/bge-small-en-v1.5",
|
| 148 |
+
"dimensions": 384,
|
| 149 |
+
"batch_size": 32,
|
| 150 |
+
"normalize": True,
|
| 151 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 152 |
+
"max_sequence_length": 512
|
| 153 |
+
}
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
| Parameter | Value | Rationale |
|
| 157 |
+
|-----------|-------|-----------|
|
| 158 |
+
| **Model** | BAAI/bge-small-en-v1.5 | SOTA quality, 62.17 MTEB score |
|
| 159 |
+
| **Dimensions** | 384 | Optimal speed/accuracy balance |
|
| 160 |
+
| **Batch Size** | 32 | Memory efficiency on GPU/CPU |
|
| 161 |
+
| **Normalization** | L2 | Required for cosine similarity |
|
| 162 |
+
| **Speed** | 1000 docs/sec (CPU) | 10x faster than alternatives |
|
| 163 |
+
|
| 164 |
+
---
|
| 165 |
+
|
| 166 |
+
### 2.3 Storage Module Architecture
|
| 167 |
+
|
| 168 |
+
```mermaid
|
| 169 |
+
graph TB
|
| 170 |
+
subgraph "Storage Layer"
|
| 171 |
+
A[FAISS Vector Store]
|
| 172 |
+
B[BM25 Keyword Index]
|
| 173 |
+
C[SQLite Metadata]
|
| 174 |
+
D[LRU Cache<br/>In-Memory]
|
| 175 |
+
end
|
| 176 |
+
|
| 177 |
+
subgraph A [Vector Storage Architecture]
|
| 178 |
+
A1[IndexHNSW<br/>Large datasets]
|
| 179 |
+
A2[IndexIVFFlat<br/>Medium datasets]
|
| 180 |
+
A3[IndexFlatL2<br/>Small datasets]
|
| 181 |
+
end
|
| 182 |
+
|
| 183 |
+
subgraph B [Keyword Index]
|
| 184 |
+
B1[rank_bm25 Library]
|
| 185 |
+
B2[TF-IDF Weights]
|
| 186 |
+
B3[In-memory Index]
|
| 187 |
+
end
|
| 188 |
+
|
| 189 |
+
subgraph C [Metadata Management]
|
| 190 |
+
C1[Document Metadata]
|
| 191 |
+
C2[Chunk Relationships]
|
| 192 |
+
C3[User Sessions]
|
| 193 |
+
C4[RAGAS Evaluations]
|
| 194 |
+
end
|
| 195 |
+
|
| 196 |
+
subgraph D [Cache Layer]
|
| 197 |
+
D1[Query Embeddings]
|
| 198 |
+
D2[Frequent Results]
|
| 199 |
+
D3[LRU Eviction]
|
| 200 |
+
end
|
| 201 |
+
|
| 202 |
+
A --> E[Hybrid Retrieval]
|
| 203 |
+
B --> E
|
| 204 |
+
C --> E
|
| 205 |
+
D --> E
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
#### Vector Store Configuration
|
| 209 |
+
|
| 210 |
+
| Index Type | Use Case | Parameters | Performance |
|
| 211 |
+
|------------|----------|------------|-------------|
|
| 212 |
+
| **IndexFlatL2** | < 100K vectors | Exact search | O(n), High accuracy |
|
| 213 |
+
| **IndexIVFFlat** | 100K-1M vectors | nprobe: 10-20 | O(log n), Balanced |
|
| 214 |
+
| **IndexHNSW** | > 1M vectors | M: 16, efConstruction: 40 | O(log n), Fastest |
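
A minimal sketch of choosing an index by corpus size (assumes the `faiss` package is installed; parameters mirror the table):

```python
import faiss

def build_index(dim: int, n_vectors: int) -> faiss.Index:
    """Pick a FAISS index type based on the expected number of vectors."""
    if n_vectors < 100_000:
        return faiss.IndexFlatL2(dim)                      # exact search
    if n_vectors < 1_000_000:
        quantizer = faiss.IndexFlatL2(dim)
        return faiss.IndexIVFFlat(quantizer, dim, 1024)    # requires train() before add()
    index = faiss.IndexHNSWFlat(dim, 16)                   # M = 16
    index.hnsw.efConstruction = 40
    return index
```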
|
| 215 |
+
|
| 216 |
+
#### Caching Strategy
|
| 217 |
+
|
| 218 |
+
```python
|
| 219 |
+
# LRU Cache Configuration
|
| 220 |
+
CACHE_CONFIG = {
|
| 221 |
+
"max_size": 1000, # Maximum cached items
|
| 222 |
+
"ttl": 3600, # Time to live (seconds)
|
| 223 |
+
"eviction": "LRU", # Least Recently Used
|
| 224 |
+
"cache_embeddings": True,
|
| 225 |
+
"cache_results": True
|
| 226 |
+
}
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
**Benefits:**
|
| 230 |
+
- **Reduced latency**: 80% reduction for repeat queries
|
| 231 |
+
- **Resource efficiency**: Avoid re-computing embeddings
|
| 232 |
+
- **No external dependencies**: Pure Python implementation
|
| 233 |
+
- **Memory efficient**: LRU eviction prevents unbounded growth
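
A minimal sketch of the query-embedding side of this cache using `functools.lru_cache` (TTL handling is omitted for brevity; the embedding model name comes from the configuration above):

```python
from functools import lru_cache
from sentence_transformers import SentenceTransformer

_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

@lru_cache(maxsize=1000)  # mirrors CACHE_CONFIG["max_size"]; LRU eviction bounds memory
def cached_query_embedding(query: str) -> tuple:
    """Embed a query once and serve repeat queries from the in-memory cache."""
    vec = _model.encode(query, normalize_embeddings=True)
    return tuple(vec.tolist())
```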
|
| 234 |
+
|
| 235 |
+
---
|
| 236 |
+
|
| 237 |
+
### 2.4 Retrieval Module
|
| 238 |
+
|
| 239 |
+
#### 2.4.1 Hybrid Retrieval Pipeline
|
| 240 |
+
|
| 241 |
+
```mermaid
|
| 242 |
+
flowchart TD
|
| 243 |
+
A[User Query] --> B[Query Processing]
|
| 244 |
+
|
| 245 |
+
B --> C[Vector Embedding]
|
| 246 |
+
B --> D[Keyword Extraction]
|
| 247 |
+
|
| 248 |
+
C --> E[FAISS Search<br/>Top-K: 10]
|
| 249 |
+
D --> F[BM25 Search<br/>Top-K: 10]
|
| 250 |
+
|
| 251 |
+
E --> G[Reciprocal Rank Fusion]
|
| 252 |
+
F --> G
|
| 253 |
+
|
| 254 |
+
G --> H{Reranking Enabled?}
|
| 255 |
+
|
| 256 |
+
H -->|Yes| I[Cross-Encoder Reranking]
|
| 257 |
+
H -->|No| J[Final Top-5 Selection]
|
| 258 |
+
|
| 259 |
+
I --> J
|
| 260 |
+
|
| 261 |
+
J --> K[Context Assembly]
|
| 262 |
+
K --> L[Citation Formatting]
|
| 263 |
+
L --> M[Output: Context + Sources]
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
#### 2.4.2 Retrieval Algorithms
|
| 267 |
+
|
| 268 |
+
**Hybrid Fusion Formula:**
|
| 269 |
+
|
| 270 |
+
```text
|
| 271 |
+
RRF_score(doc) = vector_weight * (1 / (60 + vector_rank)) + bm25_weight * (1 / (60 + bm25_rank))
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
**Default Weights:**
|
| 275 |
+
- Vector Similarity: 60%
|
| 276 |
+
- BM25 Keyword: 40%
|
| 277 |
+
|
| 278 |
+
**BM25 Parameters:**
|
| 279 |
+
|
| 280 |
+
```python
|
| 281 |
+
BM25_CONFIG = {
|
| 282 |
+
"k1": 1.5, # Term frequency saturation
|
| 283 |
+
"b": 0.75, # Length normalization
|
| 284 |
+
"epsilon": 0.25 # Smoothing factor
|
| 285 |
+
}
|
| 286 |
+
```
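
A minimal sketch of the fusion step, implementing the RRF formula above with the default 0.6/0.4 weights (the inputs are assumed to be ranked lists of document IDs from each retriever):

```python
def rrf_fuse(vector_ranking, bm25_ranking,
             vector_weight=0.6, bm25_weight=0.4, k=60):
    """Combine two rankings with weighted Reciprocal Rank Fusion."""
    scores = {}
    for rank, doc_id in enumerate(vector_ranking, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + vector_weight / (k + rank)
    for rank, doc_id in enumerate(bm25_ranking, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + bm25_weight / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)
```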
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
|
| 290 |
+
### 2.5 Generation Module
|
| 291 |
+
|
| 292 |
+
#### 2.5.1 LLM Integration Architecture
|
| 293 |
+
|
| 294 |
+
```mermaid
|
| 295 |
+
graph TB
|
| 296 |
+
subgraph "Ollama Integration"
|
| 297 |
+
A[Ollama Server]
|
| 298 |
+
B[Mistral-7B-Instruct]
|
| 299 |
+
C[LLaMA-2-13B-Chat]
|
| 300 |
+
end
|
| 301 |
+
|
| 302 |
+
subgraph "Prompt Engineering"
|
| 303 |
+
D[System Prompt Template]
|
| 304 |
+
E[Context Formatting]
|
| 305 |
+
F[Citation Injection]
|
| 306 |
+
end
|
| 307 |
+
|
| 308 |
+
subgraph "Generation Control"
|
| 309 |
+
G[Temperature Controller]
|
| 310 |
+
H[Token Manager]
|
| 311 |
+
I[Streaming Handler]
|
| 312 |
+
end
|
| 313 |
+
|
| 314 |
+
A --> J[API Client]
|
| 315 |
+
B --> A
|
| 316 |
+
C --> A
|
| 317 |
+
|
| 318 |
+
D --> K[Prompt Assembly]
|
| 319 |
+
E --> K
|
| 320 |
+
F --> K
|
| 321 |
+
|
| 322 |
+
G --> L[Generation Parameters]
|
| 323 |
+
H --> L
|
| 324 |
+
I --> L
|
| 325 |
+
|
| 326 |
+
K --> M[LLM Request]
|
| 327 |
+
L --> M
|
| 328 |
+
M --> J
|
| 329 |
+
J --> N[Response Processing]
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
#### 2.5.2 LLM Configuration
|
| 333 |
+
|
| 334 |
+
| Parameter | Default Value | Range | Description |
|
| 335 |
+
|-----------|---------------|-------|-------------|
|
| 336 |
+
| **Model** | Mistral-7B-Instruct | - | Primary inference model |
|
| 337 |
+
| **Temperature** | 0.1 | 0.0-1.0 | Response creativity |
|
| 338 |
+
| **Max Tokens** | 1000 | 100-4000 | Response length limit |
|
| 339 |
+
| **Top-P** | 0.9 | 0.1-1.0 | Nucleus sampling |
|
| 340 |
+
| **Context Window** | 32K | - | Mistral model capacity |
|
| 341 |
+
|
| 342 |
+
---
|
| 343 |
+
|
| 344 |
+
### 2.6 RAGAS Evaluation Module
|
| 345 |
+
|
| 346 |
+
#### 2.6.1 RAGAS Evaluation Pipeline
|
| 347 |
+
|
| 348 |
+
```mermaid
|
| 349 |
+
flowchart LR
|
| 350 |
+
A[Query] --> B[Generated Answer]
|
| 351 |
+
C[Retrieved Context] --> B
|
| 352 |
+
|
| 353 |
+
B --> D[RAGAS Evaluator]
|
| 354 |
+
C --> D
|
| 355 |
+
|
| 356 |
+
D --> E[Answer Relevancy]
|
| 357 |
+
D --> F[Faithfulness]
|
| 358 |
+
D --> G[Context Utilization]
|
| 359 |
+
D --> H[Context Relevancy]
|
| 360 |
+
|
| 361 |
+
E --> I[Metrics Aggregation]
|
| 362 |
+
F --> I
|
| 363 |
+
G --> I
|
| 364 |
+
H --> I
|
| 365 |
+
|
| 366 |
+
I --> J[Analytics Dashboard]
|
| 367 |
+
I --> K[SQLite Storage]
|
| 368 |
+
I --> L[Session Statistics]
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
#### 2.6.2 Evaluation Metrics
|
| 372 |
+
|
| 373 |
+
| Metric | Target | Measurement Method | Importance |
|
| 374 |
+
|--------|--------|-------------------|------------|
|
| 375 |
+
| **Answer Relevancy** | > 0.85 | LLM-based evaluation | Core user satisfaction |
|
| 376 |
+
| **Faithfulness** | > 0.90 | Grounded in context check | Prevents hallucinations |
|
| 377 |
+
| **Context Utilization** | > 0.80 | How well context is used | Generation effectiveness |
|
| 378 |
+
| **Context Relevancy** | > 0.85 | Retrieved chunks relevance | Retrieval quality |
|
| 379 |
+
|
| 380 |
+
**Implementation Details:**
|
| 381 |
+
|
| 382 |
+
```python
|
| 383 |
+
# RAGAS Configuration
|
| 384 |
+
RAGAS_CONFIG = {
|
| 385 |
+
"enable_ragas": True,
|
| 386 |
+
"enable_ground_truth": False,
|
| 387 |
+
"base_metrics": [
|
| 388 |
+
"answer_relevancy",
|
| 389 |
+
"faithfulness",
|
| 390 |
+
"context_utilization",
|
| 391 |
+
"context_relevancy"
|
| 392 |
+
],
|
| 393 |
+
"ground_truth_metrics": [
|
| 394 |
+
"context_precision",
|
| 395 |
+
"context_recall",
|
| 396 |
+
"answer_similarity",
|
| 397 |
+
"answer_correctness"
|
| 398 |
+
],
|
| 399 |
+
"evaluation_timeout": 60,
|
| 400 |
+
"batch_size": 10
|
| 401 |
+
}
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
**Evaluation Flow:**
|
| 405 |
+
|
| 406 |
+
1. **Automatic Trigger**: Every query-response pair is evaluated
|
| 407 |
+
2. **Async Processing**: Evaluation runs in background (non-blocking)
|
| 408 |
+
3. **Storage**: Results stored in SQLite for analytics
|
| 409 |
+
4. **Aggregation**: Session-level statistics computed on-demand
|
| 410 |
+
5. **Export**: Full evaluation data available for download
|
| 411 |
+
|
| 412 |
+
---
|
| 413 |
+
|
| 414 |
+
## 3. Data Flow & Workflows
|
| 415 |
+
|
| 416 |
+
### 3.1 End-to-End Processing Pipeline
|
| 417 |
+
|
| 418 |
+
```mermaid
|
| 419 |
+
sequenceDiagram
|
| 420 |
+
participant U as User
|
| 421 |
+
participant F as Frontend
|
| 422 |
+
participant A as API Gateway
|
| 423 |
+
participant I as Ingestion
|
| 424 |
+
participant P as Processing
|
| 425 |
+
participant S as Storage
|
| 426 |
+
participant R as Retrieval
|
| 427 |
+
participant G as Generation
|
| 428 |
+
participant E as RAGAS Evaluator
|
| 429 |
+
|
| 430 |
+
U->>F: Upload Documents
|
| 431 |
+
F->>A: POST /api/upload
|
| 432 |
+
A->>I: Process Input Sources
|
| 433 |
+
|
| 434 |
+
Note over I: Parallel Processing
|
| 435 |
+
I->>I: Document Parsing
|
| 436 |
+
I->>P: Extracted Text + Metadata
|
| 437 |
+
|
| 438 |
+
P->>P: Adaptive Chunking
|
| 439 |
+
P->>P: Embedding Generation
|
| 440 |
+
P->>S: Store Vectors + Indexes
|
| 441 |
+
|
| 442 |
+
S->>F: Processing Complete
|
| 443 |
+
|
| 444 |
+
U->>F: Send Query
|
| 445 |
+
F->>A: POST /api/chat
|
| 446 |
+
|
| 447 |
+
A->>R: Hybrid Retrieval
|
| 448 |
+
R->>S: Vector + BM25 Search
|
| 449 |
+
S->>R: Top-K Chunks
|
| 450 |
+
|
| 451 |
+
R->>G: Context + Query
|
| 452 |
+
G->>G: LLM Generation
|
| 453 |
+
G->>F: Response + Citations
|
| 454 |
+
|
| 455 |
+
G->>E: Auto-evaluation (async)
|
| 456 |
+
E->>E: Compute RAGAS Metrics
|
| 457 |
+
E->>S: Store Evaluation Results
|
| 458 |
+
E->>F: Return Metrics
|
| 459 |
+
```
|
| 460 |
+
|
| 461 |
+
### 3.2 Real-time Query Processing
|
| 462 |
+
|
| 463 |
+
```mermaid
|
| 464 |
+
flowchart TD
|
| 465 |
+
A[User Query] --> B[Query Understanding]
|
| 466 |
+
B --> C[Check Cache]
|
| 467 |
+
|
| 468 |
+
C --> D{Cache Hit?}
|
| 469 |
+
D -->|Yes| E[Return Cached Embedding]
|
| 470 |
+
D -->|No| F[Generate Embedding]
|
| 471 |
+
|
| 472 |
+
F --> G[Store in Cache]
|
| 473 |
+
E --> H[FAISS Vector Search]
|
| 474 |
+
G --> H
|
| 475 |
+
|
| 476 |
+
B --> I[Keyword Extraction]
|
| 477 |
+
I --> J[BM25 Keyword Search]
|
| 478 |
+
|
| 479 |
+
H --> K[Reciprocal Rank Fusion]
|
| 480 |
+
J --> K
|
| 481 |
+
|
| 482 |
+
K --> L[Top-20 Candidates]
|
| 483 |
+
L --> M{Reranking Enabled?}
|
| 484 |
+
|
| 485 |
+
M -->|Yes| N[Cross-Encoder Reranking]
|
| 486 |
+
M -->|No| O[Select Top-5]
|
| 487 |
+
|
| 488 |
+
N --> O
|
| 489 |
+
O --> P[Context Assembly]
|
| 490 |
+
P --> Q[LLM Prompt Construction]
|
| 491 |
+
Q --> R[Ollama Generation]
|
| 492 |
+
R --> S[Citation Formatting]
|
| 493 |
+
S --> T[Response Streaming]
|
| 494 |
+
T --> U[User Display]
|
| 495 |
+
|
| 496 |
+
R --> V[Async RAGAS Evaluation]
|
| 497 |
+
V --> W[Compute Metrics]
|
| 498 |
+
W --> X[Store Results]
|
| 499 |
+
X --> Y[Update Dashboard]
|
| 500 |
+
```
|
| 501 |
+
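The Reciprocal Rank Fusion step in the flowchart can be illustrated with a short sketch. The smoothing constant `k = 60` and the function name are assumptions for illustration, not the retrieval module's actual implementation.

```python
# Illustrative Reciprocal Rank Fusion over the vector and BM25 rankings
def reciprocal_rank_fusion(vector_ids: list, bm25_ids: list, k: int = 60, top_n: int = 20) -> list:
    """Merge two ranked lists of chunk IDs into a single fused ranking."""
    scores = {}
    for ranking in (vector_ids, bm25_ids):
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    # Highest fused score first; keep the top-N candidates for optional reranking
    return sorted(scores, key=scores.get, reverse=True)[:top_n]


# Example: candidates = reciprocal_rank_fusion(faiss_hits, bm25_hits)  # -> top-20 chunk IDs
```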
|
| 502 |
+
---
|
| 503 |
+
|
| 504 |
+
## 4. Infrastructure & Deployment
|
| 505 |
+
|
| 506 |
+
### 4.1 Container Architecture
|
| 507 |
+
|
| 508 |
+
```mermaid
|
| 509 |
+
graph TB
|
| 510 |
+
subgraph "Docker Compose Stack"
|
| 511 |
+
A[Frontend Container<br/>nginx:alpine]
|
| 512 |
+
B[Backend Container<br/>python:3.11]
|
| 513 |
+
C[Ollama Container<br/>ollama/ollama]
|
| 514 |
+
end
|
| 515 |
+
|
| 516 |
+
subgraph "External Services"
|
| 517 |
+
D[FAISS Indices<br/>Persistent Volume]
|
| 518 |
+
E[SQLite Database<br/>Persistent Volume]
|
| 519 |
+
F[Log Files<br/>Persistent Volume]
|
| 520 |
+
end
|
| 521 |
+
|
| 522 |
+
A --> B
|
| 523 |
+
B --> C
|
| 524 |
+
B --> D
|
| 525 |
+
B --> E
|
| 526 |
+
B --> F
|
| 527 |
+
```
|
| 528 |
+
|
| 529 |
+
### 4.2 Resource Requirements
|
| 530 |
+
|
| 531 |
+
#### 4.2.1 Minimum Deployment
|
| 532 |
+
|
| 533 |
+
| Resource | Specification | Purpose |
|
| 534 |
+
|----------|---------------|---------|
|
| 535 |
+
| **CPU** | 4 cores | Document processing, embeddings |
|
| 536 |
+
| **RAM** | 8GB | Model loading, FAISS indices, cache |
|
| 537 |
+
| **Storage** | 20GB | Models, indices, documents |
|
| 538 |
+
| **GPU** | Optional | 2-3x speedup for inference |
|
| 539 |
+
|
| 540 |
+
#### 4.2.2 Production Deployment
|
| 541 |
+
|
| 542 |
+
| Resource | Specification | Purpose |
|
| 543 |
+
|----------|---------------|---------|
|
| 544 |
+
| **CPU** | 8+ cores | Concurrent processing |
|
| 545 |
+
| **RAM** | 16GB+ | Larger datasets, caching |
|
| 546 |
+
| **GPU** | RTX 3090/4090 | 20-30 tokens/sec inference |
|
| 547 |
+
| **Storage** | 100GB+ SSD | Fast vector search |
|
| 548 |
+
|
| 549 |
+
---
|
| 550 |
+
|
| 551 |
+
## 5. API Architecture
|
| 552 |
+
|
| 553 |
+
### 5.1 REST API Endpoints
|
| 554 |
+
|
| 555 |
+
```mermaid
|
| 556 |
+
graph TB
|
| 557 |
+
subgraph "System Management"
|
| 558 |
+
A[GET /api/health]
|
| 559 |
+
B[GET /api/system-info]
|
| 560 |
+
C[GET /api/configuration]
|
| 561 |
+
D[POST /api/configuration]
|
| 562 |
+
end
|
| 563 |
+
|
| 564 |
+
subgraph "Document Management"
|
| 565 |
+
E[POST /api/upload]
|
| 566 |
+
F[POST /api/start-processing]
|
| 567 |
+
G[GET /api/processing-status]
|
| 568 |
+
end
|
| 569 |
+
|
| 570 |
+
subgraph "Query & Chat"
|
| 571 |
+
H[POST /api/chat]
|
| 572 |
+
I[GET /api/export-chat/:session_id]
|
| 573 |
+
end
|
| 574 |
+
|
| 575 |
+
subgraph "RAGAS Evaluation"
|
| 576 |
+
J[GET /api/ragas/history]
|
| 577 |
+
K[GET /api/ragas/statistics]
|
| 578 |
+
L[POST /api/ragas/clear]
|
| 579 |
+
M[GET /api/ragas/export]
|
| 580 |
+
N[GET /api/ragas/config]
|
| 581 |
+
end
|
| 582 |
+
|
| 583 |
+
subgraph "Analytics"
|
| 584 |
+
O[GET /api/analytics]
|
| 585 |
+
P[GET /api/analytics/refresh]
|
| 586 |
+
Q[GET /api/analytics/detailed]
|
| 587 |
+
end
|
| 588 |
+
```
|
| 589 |
+
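As a rough sketch of the chat route's shape only (request model, response payload, and where the asynchronous RAGAS hand-off happens); the Pydantic fields and inline comments are assumptions, not the actual handlers in `app.py`.

```python
# Hypothetical sketch of the POST /api/chat contract
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class ChatRequest(BaseModel):
    session_id: str
    query: str


class ChatResponse(BaseModel):
    answer: str
    citations: list


@app.post("/api/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
    # Hybrid retrieval + LLM generation would run here (omitted in this sketch).
    answer, citations = "...", []
    # RAGAS evaluation is scheduled asynchronously so it never delays the reply.
    return ChatResponse(answer=answer, citations=citations)
```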
|
| 590 |
+
### 5.2 Request/Response Flow
|
| 591 |
+
|
| 592 |
+
```python
|
| 593 |
+
# Typical Chat Request Flow with RAGAS
|
| 594 |
+
REQUEST_FLOW = {
|
| 595 |
+
"authentication": "None (local deployment)",
|
| 596 |
+
"rate_limiting": "100 requests/minute per IP",
|
| 597 |
+
"validation": "Query length, session ID format",
|
| 598 |
+
"processing": "Async with progress tracking",
|
| 599 |
+
"response": "JSON with citations + metrics + RAGAS scores",
|
| 600 |
+
"caching": "LRU cache for embeddings",
|
| 601 |
+
"evaluation": "Automatic RAGAS metrics (async)"
|
| 602 |
+
}
|
| 603 |
+
```
|
| 604 |
+
|
| 605 |
+
---
|
| 606 |
+
|
| 607 |
+
## 6. Monitoring & Quality Assurance
|
| 608 |
+
|
| 609 |
+
### 6.1 RAGAS Integration
|
| 610 |
+
|
| 611 |
+
```mermaid
|
| 612 |
+
graph LR
|
| 613 |
+
A[API Gateway] --> B[Query Processing]
|
| 614 |
+
C[Retrieval Module] --> B
|
| 615 |
+
D[Generation Module] --> B
|
| 616 |
+
|
| 617 |
+
B --> E[RAGAS Evaluator]
|
| 618 |
+
|
| 619 |
+
E --> F[Analytics Dashboard]
|
| 620 |
+
|
| 621 |
+
F --> G[Answer Relevancy]
|
| 622 |
+
F --> H[Faithfulness]
|
| 623 |
+
F --> I[Context Utilization]
|
| 624 |
+
F --> J[Context Relevancy]
|
| 625 |
+
F --> K[Session Statistics]
|
| 626 |
+
```
|
| 627 |
+
|
| 628 |
+
### 6.2 Key Performance Indicators
|
| 629 |
+
|
| 630 |
+
| Category | Metric | Target | Alert Threshold |
|
| 631 |
+
|----------|--------|--------|-----------------|
|
| 632 |
+
| **Performance** | Query Latency (p95) | < 5s | > 10s |
|
| 633 |
+
| **Quality** | Answer Relevancy | > 0.85 | < 0.70 |
|
| 634 |
+
| **Quality** | Faithfulness | > 0.90 | < 0.80 |
|
| 635 |
+
| **Quality** | Context Utilization | > 0.80 | < 0.65 |
|
| 636 |
+
| **Quality** | Overall Score | > 0.85 | < 0.70 |
|
| 637 |
+
| **Reliability** | Uptime | > 99.5% | < 95% |
|
| 638 |
+
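A small sketch of how the alert thresholds above could be checked against a set of aggregated metrics; the dictionary values mirror the table, and the function name is illustrative.

```python
# Illustrative alert check against the thresholds in the KPI table
ALERT_THRESHOLDS = {
    "answer_relevancy": 0.70,
    "faithfulness": 0.80,
    "context_utilization": 0.65,
    "overall_score": 0.70,
}


def quality_alerts(metrics: dict) -> list:
    """Return the names of quality metrics that fall below their alert threshold."""
    return [name for name, floor in ALERT_THRESHOLDS.items() if metrics.get(name, 1.0) < floor]
```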
|
| 639 |
+
### 6.3 Analytics Dashboard Features
|
| 640 |
+
|
| 641 |
+
**Real-Time Metrics:**
|
| 642 |
+
- RAGAS evaluation table with all query-response pairs
|
| 643 |
+
- Session-level aggregate statistics
|
| 644 |
+
- Performance metrics (latency, throughput)
|
| 645 |
+
- Component health status
|
| 646 |
+
|
| 647 |
+
**Historical Analysis:**
|
| 648 |
+
- Quality trend over time
|
| 649 |
+
- Performance degradation detection
|
| 650 |
+
- Cache hit rate monitoring
|
| 651 |
+
- Resource utilization tracking
|
| 652 |
+
|
| 653 |
+
**Export Capabilities:**
|
| 654 |
+
- JSON export of all evaluation data
|
| 655 |
+
- CSV export for external analysis
|
| 656 |
+
- Session-based filtering
|
| 657 |
+
- Time-range queries
|
| 658 |
+
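The session-level statistics behind the dashboard can be sketched as a single aggregation query; the `evaluations` table and its per-metric columns are an assumed schema, not the exact one used by the evaluator.

```python
# Illustrative session-level aggregation over stored RAGAS results
import sqlite3


def session_statistics(db_path: str, session_id: str) -> dict:
    """Average each base RAGAS metric for one chat session."""
    sql = """
        SELECT AVG(answer_relevancy), AVG(faithfulness),
               AVG(context_utilization), AVG(context_relevancy)
        FROM evaluations
        WHERE session_id = ?
    """
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(sql, (session_id,)).fetchone()
    keys = ["answer_relevancy", "faithfulness", "context_utilization", "context_relevancy"]
    return dict(zip(keys, row))
```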
|
| 659 |
+
---
|
| 660 |
+
|
| 661 |
+
## 7. Technology Stack Details
|
| 662 |
+
|
| 663 |
+
### Complete Technology Matrix
|
| 664 |
+
|
| 665 |
+
| Layer | Component | Technology | Version | Purpose |
|
| 666 |
+
|-------|-----------|------------|---------|----------|
|
| 667 |
+
| **Frontend** | UI Framework | HTML5/CSS3/JS | - | Responsive interface |
|
| 668 |
+
| **Frontend** | Styling | Tailwind CSS | 3.3+ | Utility-first CSS |
|
| 669 |
+
| **Frontend** | Icons | Font Awesome | 6.0+ | Icon library |
|
| 670 |
+
| **Backend** | API Framework | FastAPI | 0.104+ | Async REST API |
|
| 671 |
+
| **Backend** | Python Version | Python | 3.11+ | Runtime |
|
| 672 |
+
| **AI/ML** | LLM Engine | Ollama | 0.1.20+ | Local LLM inference |
|
| 673 |
+
| **AI/ML** | Primary Model | Mistral-7B-Instruct | v0.2 | Text generation |
|
| 674 |
+
| **AI/ML** | Embeddings | sentence-transformers | 2.2.2+ | Vector embeddings |
|
| 675 |
+
| **AI/ML** | Embedding Model | BAAI/bge-small-en | v1.5 | Semantic search |
|
| 676 |
+
| **Vector DB** | Storage | FAISS | 1.7.4+ | Vector similarity |
|
| 677 |
+
| **Search** | Keyword | rank-bm25 | 0.2.1 | BM25 implementation |
|
| 678 |
+
| **Evaluation** | Quality | Ragas | 0.1.9 | RAG evaluation |
|
| 679 |
+
| **Document** | PDF | PyPDF2 | 3.0+ | PDF text extraction |
|
| 680 |
+
| **Document** | Word | python-docx | 1.1+ | DOCX processing |
|
| 681 |
+
| **OCR** | Text Recognition | EasyOCR | 1.7+ | Scanned documents |
|
| 682 |
+
| **Database** | Metadata | SQLite | 3.35+ | Local storage |
|
| 683 |
+
| **Cache** | In-memory | Python functools | - | LRU caching |
|
| 684 |
+
| **Deployment** | Container | Docker | 24.0+ | Containerization |
|
| 685 |
+
| **Deployment** | Orchestration | Docker Compose | 2.20+ | Multi-container |
|
| 686 |
+
|
| 687 |
+
---
|
| 688 |
+
|
| 689 |
+
## 8. Key Architectural Decisions
|
| 690 |
+
|
| 691 |
+
### 8.1 Why Local Caching Instead of Redis?
|
| 692 |
+
|
| 693 |
+
**Decision:** Use in-memory LRU cache with Python's `functools.lru_cache`
|
| 694 |
+
|
| 695 |
+
**Rationale:**
|
| 696 |
+
- **Simplicity**: No external service to manage
|
| 697 |
+
- **Performance**: Faster access (no network overhead)
|
| 698 |
+
- **MVP Focus**: Adequate for initial deployment
|
| 699 |
+
- **Resource Efficient**: No extra service process; only the application's own memory is used
|
| 700 |
+
- **Easy Migration**: Can upgrade to Redis later if needed
|
| 701 |
+
|
| 702 |
+
**Trade-offs:**
|
| 703 |
+
- Cache doesn't persist across restarts
|
| 704 |
+
- Can't share cache across multiple instances
|
| 705 |
+
- Limited by single-process memory
|
| 706 |
+
|
| 707 |
+
### 8.2 Why RAGAS for Evaluation?
|
| 708 |
+
|
| 709 |
+
**Decision:** Integrate RAGAS for real-time quality assessment
|
| 710 |
+
|
| 711 |
+
**Rationale:**
|
| 712 |
+
- **Automated Metrics**: No manual annotation required
|
| 713 |
+
- **Production-Ready**: Quantifiable quality scores
|
| 714 |
+
- **Real-Time**: Evaluate every query-response pair
|
| 715 |
+
- **Comprehensive**: Multiple dimensions of quality
|
| 716 |
+
- **Research-Backed**: Based on academic research
|
| 717 |
+
|
| 718 |
+
**Implementation Details:**
|
| 719 |
+
- OpenAI API key required for LLM-based metrics
|
| 720 |
+
- Async evaluation to avoid blocking responses
|
| 721 |
+
- SQLite storage for historical analysis
|
| 722 |
+
- Export capability for offline processing
|
| 723 |
+
|
| 724 |
+
### 8.3 Why No Web Scraping?
|
| 725 |
+
|
| 726 |
+
**Decision:** Removed web scraping from MVP
|
| 727 |
+
|
| 728 |
+
**Rationale:**
|
| 729 |
+
- **Complexity**: Anti-scraping mechanisms require maintenance
|
| 730 |
+
- **Reliability**: Website changes break scrapers
|
| 731 |
+
- **Legal**: Potential legal/ethical issues
|
| 732 |
+
- **Scope**: Focus on core RAG functionality first
|
| 733 |
+
|
| 734 |
+
**Alternative:**
|
| 735 |
+
- Users can save web pages as PDFs
|
| 736 |
+
- Future enhancement if market demands it
|
| 737 |
+
|
| 738 |
+
---
|
| 739 |
+
|
| 740 |
+
## 9. Performance Optimization Strategies
|
| 741 |
+
|
| 742 |
+
### 9.1 Embedding Cache Strategy
|
| 743 |
+
|
| 744 |
+
```python
|
| 745 |
+
# Cache Implementation
|
| 746 |
+
import numpy as np
from functools import lru_cache
|
| 747 |
+
|
| 748 |
+
@lru_cache(maxsize=1000)
|
| 749 |
+
def get_query_embedding(query: str) -> np.ndarray:
|
| 750 |
+
"""Cache query embeddings for repeat queries"""
|
| 751 |
+
return embedder.embed(query)
|
| 752 |
+
|
| 753 |
+
# Benefits:
|
| 754 |
+
# - 80% reduction in latency for repeat queries
|
| 755 |
+
# - No re-computation of identical queries
|
| 756 |
+
# - Automatic LRU eviction
|
| 757 |
+
```
|
| 758 |
+
|
| 759 |
+
### 9.2 Batch Processing
|
| 760 |
+
|
| 761 |
+
```python
|
| 762 |
+
# Batch Embedding Generation
|
| 763 |
+
from typing import List

import numpy as np

BATCH_SIZE = 32
|
| 764 |
+
|
| 765 |
+
def embed_chunks_batch(chunks: List[str]) -> List[np.ndarray]:
|
| 766 |
+
embeddings = []
|
| 767 |
+
for i in range(0, len(chunks), BATCH_SIZE):
|
| 768 |
+
batch = chunks[i:i+BATCH_SIZE]
|
| 769 |
+
batch_embeddings = embedder.embed_batch(batch)
|
| 770 |
+
embeddings.extend(batch_embeddings)
|
| 771 |
+
return embeddings
|
| 772 |
+
```
|
| 773 |
+
|
| 774 |
+
### 9.3 Async Processing
|
| 775 |
+
|
| 776 |
+
```python
|
| 777 |
+
# Async Document Processing
|
| 778 |
+
import asyncio
|
| 779 |
+
|
| 780 |
+
async def process_documents_async(documents: List[Path]):
|
| 781 |
+
tasks = [process_single_document(doc) for doc in documents]
|
| 782 |
+
results = await asyncio.gather(*tasks)
|
| 783 |
+
return results
|
| 784 |
+
```
|
| 785 |
+
|
| 786 |
+
---
|
| 787 |
+
|
| 788 |
+
## 10. Security Considerations
|
| 789 |
+
|
| 790 |
+
### 10.1 Data Privacy
|
| 791 |
+
|
| 792 |
+
- **On-Premise Processing**: All data stays local
|
| 793 |
+
- **No External APIs**: Except OpenAI for RAGAS (configurable)
|
| 794 |
+
- **Local LLM**: Ollama runs entirely on-premise
|
| 795 |
+
- **Encrypted Storage**: Optional SQLite encryption
|
| 796 |
+
|
| 797 |
+
### 10.2 Input Validation
|
| 798 |
+
|
| 799 |
+
```python
|
| 800 |
+
# File Upload Validation
|
| 801 |
+
from pathlib import Path
from fastapi import UploadFile

MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
|
| 802 |
+
ALLOWED_EXTENSIONS = {'.pdf', '.docx', '.txt', '.zip'}
|
| 803 |
+
|
| 804 |
+
def validate_upload(file: UploadFile):
|
| 805 |
+
# Check extension
|
| 806 |
+
if Path(file.filename).suffix not in ALLOWED_EXTENSIONS:
|
| 807 |
+
raise ValueError("Unsupported file type")
|
| 808 |
+
|
| 809 |
+
# Check size
|
| 810 |
+
if file.size > MAX_FILE_SIZE:
|
| 811 |
+
raise ValueError("File too large")
|
| 812 |
+
|
| 813 |
+
# Scan for malicious content (optional)
|
| 814 |
+
# scan_for_malware(file)
|
| 815 |
+
```
|
| 816 |
+
|
| 817 |
+
### 10.3 Rate Limiting
|
| 818 |
+
|
| 819 |
+
```python
|
| 820 |
+
# Simple rate limiting
|
| 821 |
+
from fastapi import Request, HTTPException
|
| 822 |
+
from collections import defaultdict
|
| 823 |
+
from datetime import datetime, timedelta
|
| 824 |
+
|
| 825 |
+
rate_limits = defaultdict(list)
|
| 826 |
+
|
| 827 |
+
def check_rate_limit(request: Request, limit: int = 100):
|
| 828 |
+
ip = request.client.host
|
| 829 |
+
now = datetime.now()
|
| 830 |
+
|
| 831 |
+
# Clean old requests
|
| 832 |
+
rate_limits[ip] = [
|
| 833 |
+
ts for ts in rate_limits[ip]
|
| 834 |
+
if now - ts < timedelta(minutes=1)
|
| 835 |
+
]
|
| 836 |
+
|
| 837 |
+
# Check limit
|
| 838 |
+
if len(rate_limits[ip]) >= limit:
|
| 839 |
+
raise HTTPException(429, "Rate limit exceeded")
|
| 840 |
+
|
| 841 |
+
rate_limits[ip].append(now)
|
| 842 |
+
```
|
| 843 |
+
|
| 844 |
+
---
|
| 845 |
+
|
| 846 |
+
## Conclusion
|
| 847 |
+
|
| 848 |
+
This architecture document provides a comprehensive technical blueprint for the AI Universal Knowledge Ingestion System. The modular design, clear separation of concerns, and production-ready considerations make this system suitable for enterprise deployment while maintaining flexibility for future enhancements.
|
| 849 |
+
|
| 850 |
+
### Key Architectural Strengths
|
| 851 |
+
|
| 852 |
+
1. **Modularity**: Each component is independent and replaceable
|
| 853 |
+
2. **Scalability**: Horizontal scaling through stateless API design
|
| 854 |
+
3. **Performance**: Intelligent caching and batch processing
|
| 855 |
+
4. **Quality**: Real-time RAGAS evaluation for continuous monitoring
|
| 856 |
+
5. **Privacy**: Complete on-premise processing with local LLM
|
| 857 |
+
6. **Simplicity**: Minimal external dependencies (no Redis, no web scraping)
|
| 858 |
+
|
| 859 |
+
### Future Enhancements
|
| 860 |
+
|
| 861 |
+
**Short-term:**
|
| 862 |
+
- Redis cache for multi-instance deployments
|
| 863 |
+
- Advanced monitoring dashboard
|
| 864 |
+
- User authentication and authorization
|
| 865 |
+
- API rate limiting enhancements
|
| 866 |
+
|
| 867 |
+
**Long-term:**
|
| 868 |
+
- Distributed processing with Celery
|
| 869 |
+
- Web scraping module (optional)
|
| 870 |
+
- Fine-tuned domain-specific embeddings
|
| 871 |
+
- Multi-tenant support
|
| 872 |
+
- Advanced analytics and reporting
|
| 873 |
+
|
| 874 |
+
---
|
| 875 |
+
|
| 876 |
+
Document Version: 1.0
|
| 877 |
+
Last Updated: November 2025
|
| 878 |
+
Author: Satyaki Mitra
|
| 879 |
+
|
| 880 |
+
---
|
| 881 |
+
|
| 882 |
+
> This document is part of the AI Universal Knowledge Ingestion System technical documentation suite.
|
document_parser/__init__.py
ADDED
|
File without changes
|
document_parser/docx_parser.py
ADDED
|
@@ -0,0 +1,425 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import docx
|
| 3 |
+
import hashlib
|
| 4 |
+
from typing import List
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from docx.table import Table
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from docx.document import Document
|
| 10 |
+
from config.models import DocumentType
|
| 11 |
+
from docx.text.paragraph import Paragraph
|
| 12 |
+
from utils.text_cleaner import TextCleaner
|
| 13 |
+
from config.models import DocumentMetadata
|
| 14 |
+
from config.logging_config import get_logger
|
| 15 |
+
from utils.error_handler import handle_errors
|
| 16 |
+
from utils.error_handler import DOCXParseError
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Setup Logging
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DOCXParser:
|
| 24 |
+
"""
|
| 25 |
+
Comprehensive DOCX parsing with structure preservation: Handles paragraphs, tables, headers, and footers
|
| 26 |
+
"""
|
| 27 |
+
def __init__(self):
|
| 28 |
+
self.logger = logger
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@handle_errors(error_type = DOCXParseError, log_error = True, reraise = True)
|
| 32 |
+
def parse(self, file_path: Path, extract_metadata: bool = True, clean_text: bool = True, include_tables: bool = True, include_headers_footers: bool = False) -> tuple[str, Optional[DocumentMetadata]]:
|
| 33 |
+
"""
|
| 34 |
+
Parse DOCX and extract text and metadata
|
| 35 |
+
|
| 36 |
+
Arguments:
|
| 37 |
+
----------
|
| 38 |
+
file_path { Path } : Path to DOCX file
|
| 39 |
+
|
| 40 |
+
extract_metadata { bool } : Extract document metadata
|
| 41 |
+
|
| 42 |
+
clean_text { bool } : Clean extracted text
|
| 43 |
+
|
| 44 |
+
include_tables { bool } : Include table content
|
| 45 |
+
|
| 46 |
+
include_headers_footers { bool } : Include headers and footers
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
--------
|
| 50 |
+
{ tuple } : Tuple of (extracted_text, metadata)
|
| 51 |
+
|
| 52 |
+
Raises:
|
| 53 |
+
-------
|
| 54 |
+
DOCXParseError : If parsing fails
|
| 55 |
+
"""
|
| 56 |
+
file_path = Path(file_path)
|
| 57 |
+
|
| 58 |
+
if not file_path.exists():
|
| 59 |
+
raise DOCXParseError(str(file_path), original_error = FileNotFoundError(f"DOCX file not found: {file_path}"))
|
| 60 |
+
|
| 61 |
+
self.logger.info(f"Parsing DOCX: {file_path}")
|
| 62 |
+
|
| 63 |
+
try:
|
| 64 |
+
# Open document
|
| 65 |
+
doc = docx.Document(file_path)
|
| 66 |
+
|
| 67 |
+
# Extract text content
|
| 68 |
+
text_parts = list()
|
| 69 |
+
|
| 70 |
+
# Extract paragraphs
|
| 71 |
+
paragraph_text = self._extract_paragraphs(doc = doc)
|
| 72 |
+
|
| 73 |
+
text_parts.append(paragraph_text)
|
| 74 |
+
|
| 75 |
+
# Extract tables
|
| 76 |
+
if include_tables:
|
| 77 |
+
table_text = self._extract_tables(doc)
|
| 78 |
+
|
| 79 |
+
if table_text:
|
| 80 |
+
text_parts.append("\n[TABLES]\n" + table_text)
|
| 81 |
+
|
| 82 |
+
# Extract headers and footers
|
| 83 |
+
if include_headers_footers:
|
| 84 |
+
header_footer_text = self._extract_headers_footers(doc)
|
| 85 |
+
|
| 86 |
+
if header_footer_text:
|
| 87 |
+
text_parts.append("\n[HEADERS/FOOTERS]\n" + header_footer_text)
|
| 88 |
+
|
| 89 |
+
# Combine all text
|
| 90 |
+
text_content = "\n".join(text_parts)
|
| 91 |
+
|
| 92 |
+
# Extract metadata
|
| 93 |
+
metadata = None
|
| 94 |
+
|
| 95 |
+
if extract_metadata:
|
| 96 |
+
metadata = self._extract_metadata(doc, file_path)
|
| 97 |
+
|
| 98 |
+
# Clean text
|
| 99 |
+
if clean_text:
|
| 100 |
+
text_content = TextCleaner.clean(text_content,
|
| 101 |
+
remove_html = False,
|
| 102 |
+
normalize_whitespace = True,
|
| 103 |
+
preserve_structure = True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.logger.info(f"Successfully parsed DOCX: {len(text_content)} characters, {len(doc.paragraphs)} paragraphs")
|
| 107 |
+
|
| 108 |
+
return text_content, metadata
|
| 109 |
+
|
| 110 |
+
except Exception as e:
|
| 111 |
+
self.logger.error(f"Failed to parse DOCX {file_path}: {str(e)}")
|
| 112 |
+
|
| 113 |
+
raise DOCXParseError(str(file_path), original_error = e)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _extract_paragraphs(self, doc: Document) -> str:
|
| 117 |
+
"""
|
| 118 |
+
Extract text from paragraphs, preserving structure
|
| 119 |
+
|
| 120 |
+
Arguments:
|
| 121 |
+
----------
|
| 122 |
+
doc { Document } : Document object
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
--------
|
| 126 |
+
{ str } : Combined paragraph text
|
| 127 |
+
"""
|
| 128 |
+
text_parts = list()
|
| 129 |
+
|
| 130 |
+
for i, para in enumerate(doc.paragraphs):
|
| 131 |
+
text = para.text.strip()
|
| 132 |
+
|
| 133 |
+
if not text:
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
# Detect headings
|
| 137 |
+
if para.style.name.startswith('Heading'):
|
| 138 |
+
heading_level = para.style.name.replace('Heading', '').strip()
|
| 139 |
+
text_parts.append(f"\n[HEADING {heading_level}] {text}\n")
|
| 140 |
+
|
| 141 |
+
else:
|
| 142 |
+
text_parts.append(text)
|
| 143 |
+
|
| 144 |
+
return "\n".join(text_parts)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _extract_tables(self, doc: Document) -> str:
|
| 148 |
+
"""
|
| 149 |
+
Extract text from tables
|
| 150 |
+
|
| 151 |
+
Arguments:
|
| 152 |
+
----------
|
| 153 |
+
doc { Document } : Document object
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
--------
|
| 157 |
+
{ str } : Combined table text
|
| 158 |
+
"""
|
| 159 |
+
if not doc.tables:
|
| 160 |
+
return ""
|
| 161 |
+
|
| 162 |
+
table_parts = list()
|
| 163 |
+
|
| 164 |
+
for table_idx, table in enumerate(doc.tables):
|
| 165 |
+
table_text = self._parse_table(table)
|
| 166 |
+
|
| 167 |
+
if table_text:
|
| 168 |
+
table_parts.append(f"\n[TABLE {table_idx + 1}]\n{table_text}")
|
| 169 |
+
|
| 170 |
+
return "\n".join(table_parts)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _parse_table(self, table: Table) -> str:
|
| 174 |
+
"""
|
| 175 |
+
Parse a single table into text
|
| 176 |
+
|
| 177 |
+
Arguments:
|
| 178 |
+
----------
|
| 179 |
+
table { Table } : Table object
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
--------
|
| 183 |
+
{ str } : Table text
|
| 184 |
+
"""
|
| 185 |
+
rows_text = list()
|
| 186 |
+
|
| 187 |
+
for row in table.rows:
|
| 188 |
+
cells_text = list()
|
| 189 |
+
for cell in row.cells:
|
| 190 |
+
cell_text = cell.text.strip()
|
| 191 |
+
|
| 192 |
+
cells_text.append(cell_text)
|
| 193 |
+
|
| 194 |
+
# Join cells with pipe separator for readability
|
| 195 |
+
rows_text.append(" | ".join(cells_text))
|
| 196 |
+
|
| 197 |
+
return "\n".join(rows_text)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _extract_headers_footers(self, doc: Document) -> str:
|
| 201 |
+
"""
|
| 202 |
+
Extract headers and footers
|
| 203 |
+
|
| 204 |
+
Arguments:
|
| 205 |
+
----------
|
| 206 |
+
doc { Document } : Document object
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
--------
|
| 210 |
+
{ str } : Headers and footers text
|
| 211 |
+
"""
|
| 212 |
+
parts = list()
|
| 213 |
+
|
| 214 |
+
# Extract from each section
|
| 215 |
+
for section in doc.sections:
|
| 216 |
+
# Header
|
| 217 |
+
if section.header:
|
| 218 |
+
header_text = self._extract_paragraphs_from_element(element = section.header)
|
| 219 |
+
|
| 220 |
+
if header_text:
|
| 221 |
+
parts.append(f"[HEADER]\n{header_text}")
|
| 222 |
+
|
| 223 |
+
# Footer
|
| 224 |
+
if section.footer:
|
| 225 |
+
footer_text = self._extract_paragraphs_from_element(element = section.footer)
|
| 226 |
+
|
| 227 |
+
if footer_text:
|
| 228 |
+
parts.append(f"[FOOTER]\n{footer_text}")
|
| 229 |
+
|
| 230 |
+
return "\n".join(parts)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@staticmethod
|
| 234 |
+
def _extract_paragraphs_from_element(element) -> str:
|
| 235 |
+
"""
|
| 236 |
+
Extract paragraphs from header/footer element
|
| 237 |
+
"""
|
| 238 |
+
parts = list()
|
| 239 |
+
|
| 240 |
+
for para in element.paragraphs:
|
| 241 |
+
text = para.text.strip()
|
| 242 |
+
if text:
|
| 243 |
+
parts.append(text)
|
| 244 |
+
|
| 245 |
+
return "\n".join(parts)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _extract_metadata(self, doc: Document, file_path: Path) -> DocumentMetadata:
|
| 249 |
+
"""
|
| 250 |
+
Extract metadata from DOCX
|
| 251 |
+
|
| 252 |
+
Arguments:
|
| 253 |
+
----------
|
| 254 |
+
doc { Document } : Document object
|
| 255 |
+
|
| 256 |
+
file_path { Path } : Path to DOCX file
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
--------
|
| 260 |
+
{ DocumentMetadata } : DocumentMetadata object
|
| 261 |
+
"""
|
| 262 |
+
# Get core properties
|
| 263 |
+
core_props = doc.core_properties
|
| 264 |
+
|
| 265 |
+
# Extract fields
|
| 266 |
+
title = core_props.title or file_path.stem
|
| 267 |
+
author = core_props.author
|
| 268 |
+
created_date = core_props.created
|
| 269 |
+
modified_date = core_props.modified
|
| 270 |
+
|
| 271 |
+
# Get file size
|
| 272 |
+
file_size = file_path.stat().st_size
|
| 273 |
+
|
| 274 |
+
# Generate document ID
|
| 275 |
+
doc_hash = hashlib.md5(str(file_path).encode()).hexdigest()
|
| 276 |
+
doc_id = f"doc_{int(datetime.now().timestamp())}_{doc_hash}"
|
| 277 |
+
|
| 278 |
+
# Count paragraphs and estimate pages
|
| 279 |
+
num_paragraphs = len(doc.paragraphs)
|
| 280 |
+
|
| 281 |
+
# Rough estimate: 500 words per page, 5-10 words per paragraph
|
| 282 |
+
estimated_pages = max(1, num_paragraphs // 50)
|
| 283 |
+
|
| 284 |
+
# Create metadata object
|
| 285 |
+
metadata = DocumentMetadata(document_id = doc_id,
|
| 286 |
+
filename = file_path.name,
|
| 287 |
+
file_path = file_path,
|
| 288 |
+
document_type = DocumentType.DOCX,
|
| 289 |
+
title = title,
|
| 290 |
+
author = author,
|
| 291 |
+
created_date = created_date,
|
| 292 |
+
modified_date = modified_date,
|
| 293 |
+
file_size_bytes = file_size,
|
| 294 |
+
num_pages = estimated_pages,
|
| 295 |
+
extra = {"num_paragraphs" : num_paragraphs,
|
| 296 |
+
"num_tables" : len(doc.tables),
|
| 297 |
+
"num_sections" : len(doc.sections),
|
| 298 |
+
"category" : core_props.category,
|
| 299 |
+
"comments" : core_props.comments,
|
| 300 |
+
"keywords" : core_props.keywords,
|
| 301 |
+
"subject" : core_props.subject,
|
| 302 |
+
}
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
return metadata
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_paragraph_count(self, file_path: Path) -> int:
|
| 309 |
+
"""
|
| 310 |
+
Get number of paragraphs in document
|
| 311 |
+
|
| 312 |
+
Arguments:
|
| 313 |
+
----------
|
| 314 |
+
file_path { Path } : Path to DOCX file
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
--------
|
| 318 |
+
{ int } : Number of paragraphs
|
| 319 |
+
"""
|
| 320 |
+
try:
|
| 321 |
+
doc = docx.Document(file_path)
|
| 322 |
+
return len(doc.paragraphs)
|
| 323 |
+
|
| 324 |
+
except Exception as e:
|
| 325 |
+
self.logger.error(f"Failed to get paragraph count: {repr(e)}")
|
| 326 |
+
raise DOCXParseError(str(file_path), original_error = e)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def extract_section(self, file_path: Path, section_index: int, clean_text: bool = True) -> str:
|
| 330 |
+
"""
|
| 331 |
+
Extract text from a specific section
|
| 332 |
+
|
| 333 |
+
Arguments:
|
| 334 |
+
----------
|
| 335 |
+
file_path { Path } : Path to DOCX file
|
| 336 |
+
|
| 337 |
+
section_index { int } : Section index (0-indexed)
|
| 338 |
+
|
| 339 |
+
clean_text { bool } : Clean extracted text
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
--------
|
| 343 |
+
{ str } : Section text
|
| 344 |
+
"""
|
| 345 |
+
try:
|
| 346 |
+
doc = docx.Document(file_path)
|
| 347 |
+
|
| 348 |
+
if ((section_index < 0) or (section_index >= len(doc.sections))):
|
| 349 |
+
raise ValueError(f"Section index {section_index} out of range (0-{len(doc.sections)-1})")
|
| 350 |
+
|
| 351 |
+
# Note: Extracting text by section is not straightforward in python-docx
|
| 352 |
+
section = doc.sections[section_index]
|
| 353 |
+
|
| 354 |
+
# For now, we'll extract the entire document
|
| 355 |
+
text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
|
| 356 |
+
|
| 357 |
+
if clean_text:
|
| 358 |
+
text = TextCleaner.clean(text)
|
| 359 |
+
|
| 360 |
+
return text
|
| 361 |
+
|
| 362 |
+
except Exception as e:
|
| 363 |
+
self.logger.error(f"Failed to extract section: {repr(e)}")
|
| 364 |
+
raise DOCXParseError(str(file_path), original_error = e)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def extract_heading_sections(self, file_path: Path, clean_text: bool = True) -> dict[str, str]:
|
| 368 |
+
"""
|
| 369 |
+
Extract text organized by headings
|
| 370 |
+
|
| 371 |
+
Arguments:
|
| 372 |
+
----------
|
| 373 |
+
file_path { Path } : Path to DOCX file
|
| 374 |
+
|
| 375 |
+
clean_text { bool } : Clean extracted text
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
--------
|
| 379 |
+
{ dict } : Dictionary mapping heading text to content
|
| 380 |
+
"""
|
| 381 |
+
try:
|
| 382 |
+
doc = docx.Document(file_path)
|
| 383 |
+
sections = dict()
|
| 384 |
+
current_content = list()
|
| 385 |
+
current_heading = "Introduction"
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
for para in doc.paragraphs:
|
| 389 |
+
text = para.text.strip()
|
| 390 |
+
|
| 391 |
+
if not text:
|
| 392 |
+
continue
|
| 393 |
+
|
| 394 |
+
# Check if it's a heading
|
| 395 |
+
if para.style.name.startswith('Heading'):
|
| 396 |
+
# Save previous section
|
| 397 |
+
if current_content:
|
| 398 |
+
section_text = "\n".join(current_content)
|
| 399 |
+
|
| 400 |
+
if clean_text:
|
| 401 |
+
section_text = TextCleaner.clean(section_text)
|
| 402 |
+
|
| 403 |
+
sections[current_heading] = section_text
|
| 404 |
+
|
| 405 |
+
# Start new section
|
| 406 |
+
current_heading = text
|
| 407 |
+
current_content = list()
|
| 408 |
+
|
| 409 |
+
else:
|
| 410 |
+
current_content.append(text)
|
| 411 |
+
|
| 412 |
+
# Save last section
|
| 413 |
+
if current_content:
|
| 414 |
+
section_text = "\n".join(current_content)
|
| 415 |
+
|
| 416 |
+
if clean_text:
|
| 417 |
+
section_text = TextCleaner.clean(section_text)
|
| 418 |
+
|
| 419 |
+
sections[current_heading] = section_text
|
| 420 |
+
|
| 421 |
+
return sections
|
| 422 |
+
|
| 423 |
+
except Exception as e:
|
| 424 |
+
self.logger.error(f"Failed to extract heading sections: {repr(e)}")
|
| 425 |
+
raise DOCXParseError(str(file_path), original_error = e)
|
document_parser/ocr_engine.py
ADDED
|
@@ -0,0 +1,781 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import os
|
| 3 |
+
import cv2
|
| 4 |
+
import fitz
|
| 5 |
+
import easyocr
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from typing import List
|
| 10 |
+
from typing import Dict
|
| 11 |
+
from typing import Tuple
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from PIL import ImageFilter
|
| 14 |
+
from typing import Optional
|
| 15 |
+
from PIL import ImageEnhance
|
| 16 |
+
from paddleocr import PaddleOCR
|
| 17 |
+
from config.settings import get_settings
|
| 18 |
+
from utils.error_handler import OCRException
|
| 19 |
+
from config.logging_config import get_logger
|
| 20 |
+
from utils.error_handler import handle_errors
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Setup Settings and Logging
|
| 24 |
+
settings = get_settings()
|
| 25 |
+
logger = get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class OCREngine:
|
| 29 |
+
"""
|
| 30 |
+
OCR engine with layout preservation - maintains document structure and formatting
|
| 31 |
+
"""
|
| 32 |
+
def __init__(self, use_paddle: bool = True, lang: str = 'en', gpu: bool = False):
|
| 33 |
+
"""
|
| 34 |
+
Initialize OCR engine
|
| 35 |
+
|
| 36 |
+
Arguments:
|
| 37 |
+
----------
|
| 38 |
+
use_paddle { bool } : Use PaddleOCR as primary (better accuracy)
|
| 39 |
+
|
| 40 |
+
lang { str } : Language code ('en', 'es', 'fr', 'de', etc.)
|
| 41 |
+
|
| 42 |
+
gpu { bool } : Use GPU acceleration if available
|
| 43 |
+
"""
|
| 44 |
+
self.logger = logger
|
| 45 |
+
self.use_paddle = use_paddle
|
| 46 |
+
self.lang = lang
|
| 47 |
+
self.gpu = gpu
|
| 48 |
+
self.paddle_ocr = None
|
| 49 |
+
self.easy_ocr = None
|
| 50 |
+
self._initialized = False
|
| 51 |
+
|
| 52 |
+
self._initialize_engines()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _initialize_engines(self):
|
| 56 |
+
"""
|
| 57 |
+
Initialize OCR engines with proper error handling
|
| 58 |
+
"""
|
| 59 |
+
if self.use_paddle:
|
| 60 |
+
try:
|
| 61 |
+
self.paddle_ocr = PaddleOCR(use_angle_cls = True,
|
| 62 |
+
lang = self.lang,
|
| 63 |
+
use_gpu = self.gpu,
|
| 64 |
+
show_log = False,
|
| 65 |
+
det_db_thresh = 0.3,
|
| 66 |
+
det_db_box_thresh = 0.5,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
self.logger.info("PaddleOCR initialized successfully")
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
self.logger.warning(f"PaddleOCR not available: {repr(e)}. Falling back to EasyOCR.")
|
| 73 |
+
self.use_paddle = False
|
| 74 |
+
|
| 75 |
+
if not self.use_paddle:
|
| 76 |
+
try:
|
| 77 |
+
self.easy_ocr = easyocr.Reader([self.lang], gpu = self.gpu)
|
| 78 |
+
self.logger.info("EasyOCR initialized successfully")
|
| 79 |
+
|
| 80 |
+
except Exception as e:
|
| 81 |
+
self.logger.error(f"Failed to initialize EasyOCR: {repr(e)}")
|
| 82 |
+
raise OCRException(f"OCR engine initialization failed: {repr(e)}")
|
| 83 |
+
|
| 84 |
+
self._initialized = True
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@handle_errors(error_type=OCRException, log_error=True, reraise=True)
|
| 88 |
+
def extract_text_from_pdf(self, pdf_path: Path, pages: Optional[List[int]] = None, preserve_layout: bool = True) -> str:
|
| 89 |
+
"""
|
| 90 |
+
Extract text from PDF using OCR with layout preservation
|
| 91 |
+
|
| 92 |
+
Arguments:
|
| 93 |
+
----------
|
| 94 |
+
pdf_path { Path } : Path to PDF file
|
| 95 |
+
|
| 96 |
+
pages { list } : Specific pages to OCR (None = all pages)
|
| 97 |
+
|
| 98 |
+
preserve_layout { bool } : Preserve document layout and structure
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
--------
|
| 102 |
+
{ str } : Extracted text with preserved formatting
|
| 103 |
+
"""
|
| 104 |
+
pdf_path = Path(pdf_path)
|
| 105 |
+
|
| 106 |
+
self.logger.info(f"Starting OCR extraction from PDF: {pdf_path}")
|
| 107 |
+
|
| 108 |
+
if not pdf_path.exists():
|
| 109 |
+
raise OCRException(f"PDF file not found: {pdf_path}")
|
| 110 |
+
|
| 111 |
+
# Convert PDF pages to high-quality images
|
| 112 |
+
images = self._pdf_to_images(pdf_path = pdf_path,
|
| 113 |
+
pages = pages,
|
| 114 |
+
dpi = 300,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.logger.info(f"Converted {len(images)} pages to images for OCR")
|
| 118 |
+
|
| 119 |
+
# OCR each image with layout preservation
|
| 120 |
+
all_text = list()
|
| 121 |
+
|
| 122 |
+
for i, image in enumerate(images):
|
| 123 |
+
page_num = pages[i] if pages else i + 1
|
| 124 |
+
|
| 125 |
+
self.logger.info(f"Processing page {page_num}...")
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
if preserve_layout:
|
| 129 |
+
# Extract text with layout information
|
| 130 |
+
page_text = self._extract_text_with_layout(image = image,
|
| 131 |
+
page_num = page_num,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
else:
|
| 135 |
+
# Simple extraction without layout
|
| 136 |
+
img_array = np.array(image)
|
| 137 |
+
page_text = self._ocr_image(img_array)
|
| 138 |
+
|
| 139 |
+
if page_text and page_text.strip():
|
| 140 |
+
all_text.append(f"[PAGE {page_num}]\n{page_text}")
|
| 141 |
+
self.logger.info(f"✓ Extracted {len(page_text)} characters from page {page_num}")
|
| 142 |
+
|
| 143 |
+
else:
|
| 144 |
+
self.logger.warning(f"No text extracted from page {page_num}")
|
| 145 |
+
|
| 146 |
+
except Exception as e:
|
| 147 |
+
self.logger.error(f"OCR failed for page {page_num}: {repr(e)}")
|
| 148 |
+
all_text.append(f"[PAGE {page_num}]\n[OCR FAILED: {str(e)}]")
|
| 149 |
+
|
| 150 |
+
combined_text = "\n\n".join(all_text)
|
| 151 |
+
self.logger.info(f"OCR completed: {len(combined_text)} total characters extracted")
|
| 152 |
+
|
| 153 |
+
return combined_text
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _extract_text_with_layout(self, image: Image.Image, page_num: int) -> str:
|
| 157 |
+
"""
|
| 158 |
+
Extract text while preserving document layout and structure
|
| 159 |
+
|
| 160 |
+
Arguments:
|
| 161 |
+
----------
|
| 162 |
+
image { Image.Image } : PIL Image
|
| 163 |
+
|
| 164 |
+
page_num { int } : Page number
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
--------
|
| 168 |
+
{ str } : Formatted text with layout preserved
|
| 169 |
+
"""
|
| 170 |
+
img_array = np.array(image)
|
| 171 |
+
|
| 172 |
+
# Get OCR results with bounding boxes
|
| 173 |
+
if (self.use_paddle and self.paddle_ocr):
|
| 174 |
+
text_blocks = self._ocr_with_layout_paddle(image_array = img_array)
|
| 175 |
+
|
| 176 |
+
elif self.easy_ocr:
|
| 177 |
+
text_blocks = self._ocr_with_layout_easyocr(image_array = img_array)
|
| 178 |
+
|
| 179 |
+
else:
|
| 180 |
+
return ""
|
| 181 |
+
|
| 182 |
+
if not text_blocks:
|
| 183 |
+
return ""
|
| 184 |
+
|
| 185 |
+
# Organize text blocks into reading order with layout preservation
|
| 186 |
+
formatted_text = self._reconstruct_layout(text_blocks = text_blocks,
|
| 187 |
+
image_size = image.size,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
return formatted_text
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _ocr_with_layout_paddle(self, image_array: np.ndarray) -> List[Dict]:
|
| 194 |
+
"""
|
| 195 |
+
OCR using PaddleOCR and return structured text blocks with positions
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
--------
|
| 199 |
+
{ list } : {'text': str, 'bbox': [...], 'confidence': float}
|
| 200 |
+
"""
|
| 201 |
+
try:
|
| 202 |
+
result = self.paddle_ocr.ocr(image_array, cls=True)
|
| 203 |
+
|
| 204 |
+
if not result or not result[0]:
|
| 205 |
+
return []
|
| 206 |
+
|
| 207 |
+
text_blocks = list()
|
| 208 |
+
|
| 209 |
+
for line in result[0]:
|
| 210 |
+
if (line and (len(line) >= 2)):
|
| 211 |
+
bbox = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
| 212 |
+
text_info = line[1]
|
| 213 |
+
|
| 214 |
+
if (isinstance(text_info, (list, tuple)) and (len(text_info) >= 2)):
|
| 215 |
+
text = text_info[0]
|
| 216 |
+
confidence = text_info[1]
|
| 217 |
+
|
| 218 |
+
elif isinstance(text_info, str):
|
| 219 |
+
text = text_info
|
| 220 |
+
confidence = 1.0
|
| 221 |
+
|
| 222 |
+
else:
|
| 223 |
+
continue
|
| 224 |
+
|
| 225 |
+
if ((confidence > 0.5) and text and text.strip()):
|
| 226 |
+
# Calculate bounding box coordinates
|
| 227 |
+
x_coords = [point[0] for point in bbox]
|
| 228 |
+
y_coords = [point[1] for point in bbox]
|
| 229 |
+
|
| 230 |
+
text_blocks.append({'text' : text.strip(),
|
| 231 |
+
'bbox' : {'x1': min(x_coords),
|
| 232 |
+
'y1': min(y_coords),
|
| 233 |
+
'x2': max(x_coords),
|
| 234 |
+
'y2': max(y_coords)
|
| 235 |
+
},
|
| 236 |
+
'confidence' : confidence,
|
| 237 |
+
'center_y' : (min(y_coords) + max(y_coords)) / 2,
|
| 238 |
+
'center_x' : (min(x_coords) + max(x_coords)) / 2,
|
| 239 |
+
})
|
| 240 |
+
|
| 241 |
+
return text_blocks
|
| 242 |
+
|
| 243 |
+
except Exception as e:
|
| 244 |
+
self.logger.error(f"PaddleOCR layout extraction failed: {repr(e)}")
|
| 245 |
+
return []
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _ocr_with_layout_easyocr(self, image_array: np.ndarray) -> List[Dict]:
|
| 249 |
+
"""
|
| 250 |
+
OCR using EasyOCR and return structured text blocks with positions
|
| 251 |
+
"""
|
| 252 |
+
try:
|
| 253 |
+
result = self.easy_ocr.readtext(image_array, paragraph=False)
|
| 254 |
+
|
| 255 |
+
if not result:
|
| 256 |
+
return []
|
| 257 |
+
|
| 258 |
+
text_blocks = list()
|
| 259 |
+
|
| 260 |
+
for detection in result:
|
| 261 |
+
bbox = detection[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
| 262 |
+
text = detection[1]
|
| 263 |
+
confidence = detection[2]
|
| 264 |
+
|
| 265 |
+
if ((confidence > 0.5) and text and text.strip()):
|
| 266 |
+
x_coords = [point[0] for point in bbox]
|
| 267 |
+
y_coords = [point[1] for point in bbox]
|
| 268 |
+
|
| 269 |
+
text_blocks.append({'text' : text.strip(),
|
| 270 |
+
'bbox' : {'x1' : min(x_coords),
|
| 271 |
+
'y1' : min(y_coords),
|
| 272 |
+
'x2' : max(x_coords),
|
| 273 |
+
'y2' : max(y_coords),
|
| 274 |
+
},
|
| 275 |
+
'confidence' : confidence,
|
| 276 |
+
'center_y' : (min(y_coords) + max(y_coords)) / 2,
|
| 277 |
+
'center_x' : (min(x_coords) + max(x_coords)) / 2,
|
| 278 |
+
})
|
| 279 |
+
|
| 280 |
+
return text_blocks
|
| 281 |
+
|
| 282 |
+
except Exception as e:
|
| 283 |
+
self.logger.error(f"EasyOCR layout extraction failed: {repr(e)}")
|
| 284 |
+
return []
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _reconstruct_layout(self, text_blocks: List[Dict], image_size: Tuple[int, int]) -> str:
|
| 288 |
+
"""
|
| 289 |
+
Reconstruct document layout from text blocks
|
| 290 |
+
|
| 291 |
+
Strategy:
|
| 292 |
+
1. Group text blocks into lines (similar Y coordinates)
|
| 293 |
+
2. Detect columns, tables, lists
|
| 294 |
+
3. Sort lines top to bottom
|
| 295 |
+
4. Within each line, sort left to right
|
| 296 |
+
5. Detect paragraphs, headings, and lists
|
| 297 |
+
6. Add appropriate spacing and formatting
|
| 298 |
+
"""
|
| 299 |
+
if not text_blocks:
|
| 300 |
+
return ""
|
| 301 |
+
|
| 302 |
+
# Sort all blocks by Y position first
|
| 303 |
+
sorted_blocks = sorted(text_blocks, key = lambda x: (x['center_y'], x['center_x']))
|
| 304 |
+
|
| 305 |
+
# Detect multi-column layout
|
| 306 |
+
columns = self._detect_columns(text_blocks = text_blocks,
|
| 307 |
+
image_size = image_size,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Group into lines (blocks with similar Y coordinates)
|
| 311 |
+
lines = list()
|
| 312 |
+
|
| 313 |
+
current_line = [sorted_blocks[0]]
|
| 314 |
+
|
| 315 |
+
# pixels
|
| 316 |
+
line_height_threshold = 25
|
| 317 |
+
|
| 318 |
+
for block in sorted_blocks[1:]:
|
| 319 |
+
# Check if this block is on the same line as the previous one
|
| 320 |
+
y_diff = abs(block['center_y'] - current_line[-1]['center_y'])
|
| 321 |
+
|
| 322 |
+
if (y_diff < line_height_threshold):
|
| 323 |
+
current_line.append(block)
|
| 324 |
+
|
| 325 |
+
else:
|
| 326 |
+
# Sort current line by X position and add to lines
|
| 327 |
+
current_line.sort(key = lambda x: x['center_x'])
|
| 328 |
+
lines.append(current_line)
|
| 329 |
+
|
| 330 |
+
current_line = [block]
|
| 331 |
+
|
| 332 |
+
# Don't forget the last line
|
| 333 |
+
if current_line:
|
| 334 |
+
current_line.sort(key = lambda x: x['center_x'])
|
| 335 |
+
lines.append(current_line)
|
| 336 |
+
|
| 337 |
+
# Reconstruct text with formatting
|
| 338 |
+
formatted_lines = list()
|
| 339 |
+
prev_y = 0
|
| 340 |
+
prev_indent = 0
|
| 341 |
+
|
| 342 |
+
for i, line_blocks in enumerate(lines):
|
| 343 |
+
# Calculate line metrics
|
| 344 |
+
current_y = line_blocks[0]['center_y']
|
| 345 |
+
vertical_gap = current_y - prev_y if (prev_y > 0) else 0
|
| 346 |
+
|
| 347 |
+
# Detect indentation (left margin)
|
| 348 |
+
line_left_margin = line_blocks[0]['bbox']['x1']
|
| 349 |
+
|
| 350 |
+
# Combine text blocks in this line with proper spacing
|
| 351 |
+
line_text = self._combine_line_blocks(line_blocks = line_blocks)
|
| 352 |
+
|
| 353 |
+
# Clean the text
|
| 354 |
+
line_text = self._clean_ocr_text(text = line_text)
|
| 355 |
+
|
| 356 |
+
# Skip if empty after cleaning
|
| 357 |
+
if not line_text.strip():
|
| 358 |
+
continue
|
| 359 |
+
|
| 360 |
+
# Skip likely page numbers or artifacts (single numbers, very short text)
|
| 361 |
+
if self._is_page_artifact(line_text):
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
# Add extra newline for paragraph breaks (large vertical gaps)
|
| 365 |
+
# Threshold for paragraph break
|
| 366 |
+
if (vertical_gap > 35):
|
| 367 |
+
formatted_lines.append("")
|
| 368 |
+
|
| 369 |
+
# Detect and format different line types
|
| 370 |
+
if (self._is_heading(line_text, line_blocks)):
|
| 371 |
+
# Heading - add extra spacing
|
| 372 |
+
formatted_lines.append(f"\n{line_text}")
|
| 373 |
+
|
| 374 |
+
elif (self._is_bullet_point(line_text)):
|
| 375 |
+
# Bullet point or list item
|
| 376 |
+
formatted_lines.append(f" {line_text}")
|
| 377 |
+
|
| 378 |
+
elif (self._is_table_row(line_blocks)):
|
| 379 |
+
# Table row - preserve spacing between columns
|
| 380 |
+
formatted_lines.append(self._format_table_row(line_blocks))
|
| 381 |
+
|
| 382 |
+
else:
|
| 383 |
+
# Regular paragraph text
|
| 384 |
+
formatted_lines.append(line_text)
|
| 385 |
+
|
| 386 |
+
prev_y = current_y
|
| 387 |
+
prev_indent = line_left_margin
|
| 388 |
+
|
| 389 |
+
return "\n".join(formatted_lines)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _combine_line_blocks(self, line_blocks: List[Dict]) -> str:
|
| 393 |
+
"""
|
| 394 |
+
Combine text blocks in a line with intelligent spacing
|
| 395 |
+
"""
|
| 396 |
+
if (len(line_blocks) == 1):
|
| 397 |
+
return line_blocks[0]['text']
|
| 398 |
+
|
| 399 |
+
result = list()
|
| 400 |
+
|
| 401 |
+
for i, block in enumerate(line_blocks):
|
| 402 |
+
result.append(block['text'])
|
| 403 |
+
|
| 404 |
+
# Add space between blocks if they're not touching
|
| 405 |
+
if (i < len(line_blocks) - 1):
|
| 406 |
+
next_block = line_blocks[i + 1]
|
| 407 |
+
gap = next_block['bbox']['x1'] - block['bbox']['x2']
|
| 408 |
+
|
| 409 |
+
# If gap is significant, add spacing
|
| 410 |
+
if (gap > 20): # Threshold for adding extra space
|
| 411 |
+
# Double space for columns/tables
|
| 412 |
+
result.append(" ")
|
| 413 |
+
|
| 414 |
+
elif (gap > 5):
|
| 415 |
+
# Normal space
|
| 416 |
+
result.append(" ")
|
| 417 |
+
|
| 418 |
+
return "".join(result)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def _clean_ocr_text(self, text: str) -> str:
|
| 422 |
+
"""
|
| 423 |
+
Clean OCR artifacts and normalize text
|
| 424 |
+
"""
|
| 425 |
+
# Replace common OCR errors
|
| 426 |
+
replacements = {'\u2018' : "'", # Left smart quote to straight apostrophe
|
| 427 |
+
'\u2019' : "'", # Right smart quote
|
| 428 |
+
'\u201c' : '"', # Left smart double quote
|
| 429 |
+
'\u201d' : '"', # Right smart double quote
|
| 430 |
+
'—' : '-',
|
| 431 |
+
'–' : '-',
|
| 432 |
+
'…' : '...',
|
| 433 |
+
'\u00a0' : ' ', # Non-breaking space
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
for old, new in replacements.items():
|
| 437 |
+
text = text.replace(old, new)
|
| 438 |
+
|
| 439 |
+
# Fix common OCR mistakes
|
| 440 |
+
text = text.replace('l ', 'I ') # lowercase L to I at start of sentence
|
| 441 |
+
text = text.replace(' l ', ' I ') # lowercase L to I
|
| 442 |
+
|
| 443 |
+
# Remove extra spaces
|
| 444 |
+
text = ' '.join(text.split())
|
| 445 |
+
|
| 446 |
+
return text
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _is_page_artifact(self, text: str) -> bool:
|
| 450 |
+
"""
|
| 451 |
+
Detect page numbers, headers, footers, and other artifacts
|
| 452 |
+
"""
|
| 453 |
+
text = text.strip()
|
| 454 |
+
|
| 455 |
+
# Empty or very short
|
| 456 |
+
if (len(text) < 2):
|
| 457 |
+
return True
|
| 458 |
+
|
| 459 |
+
# Just a number (likely page number)
|
| 460 |
+
        if (text.isdigit() and (len(text) <= 3)):
            return True

        # Common footer patterns
        footer_patterns = ['page', 'of', 'for informational purposes', 'confidential', 'draft', 'version']
        text_lower = text.lower()

        if ((len(text) < 50) and (any(pattern in text_lower for pattern in footer_patterns))):
            # This is actually useful - don't skip
            return False

        # Very short isolated text (likely artifact)
        if ((len(text) <= 3) and not text.isalnum()):
            return True

        return False


    def _is_bullet_point(self, text: str) -> bool:
        """
        Detect if text is a bullet point or list item
        """
        text = text.strip()

        # Check for common bullet markers
        bullet_markers = ['•', '·', '-', '○', '◦', '*', '►', '▪']

        if (text and (text[0] in bullet_markers)):
            return True

        # Check for numbered lists
        if (len(text) > 2):

            # Pattern: "1. ", "a) ", "i. "
            if (text[0].isdigit() and text[1] in '.):'):
                return True

            if (text[0].isalpha() and len(text) > 1 and text[1] in '.):'):
                return True

        return False


    def _is_table_row(self, line_blocks: List[Dict]) -> bool:
        """
        Detect if a line is part of a table (multiple separated columns)
        """
        if (len(line_blocks) < 2):
            return False

        # Calculate gaps between blocks
        gaps = list()

        for i in range(len(line_blocks) - 1):
            gap = line_blocks[i + 1]['bbox']['x1'] - line_blocks[i]['bbox']['x2']
            gaps.append(gap)

        # If there are significant gaps, likely a table
        significant_gaps = sum(1 for gap in gaps if gap > 30)

        return (significant_gaps >= 1) and (len(line_blocks) >= 2)


    def _format_table_row(self, line_blocks: List[Dict]) -> str:
        """
        Format a table row with proper column alignment
        """
        cells = list()

        for block in line_blocks:
            cells.append(block['text'].strip())

        # Join with a separator for better readability
        return (" | ".join(cells))


    def _detect_columns(self, text_blocks: List[Dict], image_size: Tuple[int, int]) -> List[Dict]:
        """
        Detect multi-column layout
        """
        # Group blocks by X position to detect columns
        if not text_blocks:
            return []

        # Return single column
        return [{'x_start': 0, 'x_end': image_size[0]}]


    def _is_heading(self, text: str, blocks: List[Dict]) -> bool:
        """
        Detect if a line is likely a heading

        Heuristics:
        - All uppercase or Title Case
        - Shorter than typical paragraph lines
        - Often centered or left-aligned
        - Larger font (if detectable from bbox height)
        """
        words = text.split()
        if not words:
            return False

        # Skip very short text (likely artifacts)
        if len(text) < 3:
            return False

        # Check for common heading keywords
        heading_keywords = ['summary', 'introduction', 'conclusion', 'analysis', 'report', 'overview', 'chapter', 'section', 'terms', 'points', 'protections', 'category', 'breakdown', 'recommendation', 'clause']
        text_lower = text.lower()
        has_heading_keyword = any(keyword in text_lower for keyword in heading_keywords)

        # All caps or mostly caps
        caps_ratio = sum(1 for w in words if w.isupper() and len(w) > 1) / len(words)

        # Title case (each word starts with capital)
        title_case_ratio = sum(1 for w in words if w and w[0].isupper()) / len(words)

        # Short lines might be headings
        is_short = len(text) < 100

        # Check if text is likely a heading
        is_likely_heading = ((caps_ratio > 0.7 and is_short) or  # Mostly uppercase and short
                             (title_case_ratio > 0.8 and is_short and has_heading_keyword) or  # Title case with keywords
                             (has_heading_keyword and is_short and title_case_ratio > 0.5)  # Keywords + some capitals
                             )

        # Check font size (larger bounding box height indicates heading)
        if blocks:
            avg_height = sum(b['bbox']['y2'] - b['bbox']['y1'] for b in blocks) / len(blocks)

            # Headings often have larger font (taller bbox)
            if (avg_height > 25):  # Threshold for heading font size
                is_likely_heading = is_likely_heading or (is_short and title_case_ratio > 0.5)

        return is_likely_heading


    def _pdf_to_images(self, pdf_path: Path, pages: Optional[List[int]] = None, dpi: int = 300) -> List[Image.Image]:
        """
        Convert PDF pages to high-quality images
        """
        try:
            doc = fitz.open(str(pdf_path))
            images = list()

            if pages is None:
                pages_to_process = range(len(doc))

            else:
                pages_to_process = [p-1 for p in pages if (0 < p <= len(doc))]

            for page_num in pages_to_process:
                page = doc[page_num]

                # High-quality conversion
                zoom = dpi / 72.0
                mat = fitz.Matrix(zoom, zoom)
                pix = page.get_pixmap(matrix = mat, alpha = False)

                # Convert to PIL Image
                img_data = pix.tobytes("png")
                image = Image.open(BytesIO(img_data))

                if (image.mode != 'RGB'):
                    image = image.convert('RGB')

                images.append(image)

            doc.close()
            return images

        except Exception as e:
            raise OCRException(f"Failed to convert PDF to images: {repr(e)}")


    def _ocr_image(self, image_array: np.ndarray) -> str:
        """
        Simple OCR without layout preservation
        """
        if self.use_paddle and self.paddle_ocr:
            try:
                result = self._ocr_with_paddle_simple(image_array)

                if result:
                    return result

            except Exception as e:
                self.logger.debug(f"PaddleOCR failed: {repr(e)}")

        if self.easy_ocr:
            try:
                result = self._ocr_with_easyocr_simple(image_array)

                if result:
                    return result

            except Exception as e:
                self.logger.debug(f"EasyOCR failed: {repr(e)}")

        return ""


    def _ocr_with_paddle_simple(self, image_array: np.ndarray) -> str:
        """
        Simple PaddleOCR extraction
        """
        result = self.paddle_ocr.ocr(image_array, cls=True)

        if not result or not result[0]:
            return ""

        texts = list()

        for line in result[0]:
            if (line and (len(line) >= 2)):
                text_info = line[1]
                if isinstance(text_info, (list, tuple)):
                    text, conf = text_info[0], text_info[1]

                else:
                    text, conf = text_info, 1.0

                if ((conf > 0.5) and text):
                    texts.append(text.strip())

        return "\n".join(texts)


    def _ocr_with_easyocr_simple(self, image_array: np.ndarray) -> str:
        """
        Simple EasyOCR extraction
        """
        result = self.easy_ocr.readtext(image_array)

        if not result:
            return ""

        texts = list()

        for detection in result:
            text, conf = detection[1], detection[2]
            if ((conf > 0.5) and text):
                texts.append(text.strip())

        return "\n".join(texts)


    @handle_errors(error_type = OCRException, log_error = True, reraise = True)
    def extract_text_from_image(self, image_path: Path, preserve_layout: bool = True) -> str:
        """
        Extract text from image file
        """
        image_path = Path(image_path)

        self.logger.info(f"Extracting text from image: {image_path}")

        if not image_path.exists():
            raise OCRException(f"Image file not found: {image_path}")

        image = Image.open(image_path)

        if (image.mode != 'RGB'):
            image = image.convert('RGB')

        if preserve_layout:
            text = self._extract_text_with_layout(image, page_num=1)

        else:
            img_array = np.array(image)
            text = self._ocr_image(img_array)

        self.logger.info(f"Image OCR completed: {len(text)} characters extracted")
        return text


    def get_supported_languages(self) -> List[str]:
        """
        Get list of supported languages
        """
        return ['en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'zh', 'ja', 'ko', 'ar']


    def get_engine_info(self) -> dict:
        """
        Get information about OCR engine configuration
        """
        return {"primary_engine"       : "PaddleOCR" if self.use_paddle else "EasyOCR",
                "language"             : self.lang,
                "gpu_enabled"          : self.gpu,
                "initialized"          : self._initialized,
                "layout_preservation"  : True,
                "supported_languages"  : self.get_supported_languages(),
                }


# Global OCR instance
_global_ocr_engine = None


def get_ocr_engine() -> OCREngine:
    """
    Get global OCR engine instance (singleton)
    """
    global _global_ocr_engine

    if _global_ocr_engine is None:
        _global_ocr_engine = OCREngine()

    return _global_ocr_engine


def extract_text_with_ocr(file_path: Path, preserve_layout: bool = True, **kwargs) -> str:
    """
    Convenience function for OCR text extraction with layout preservation
    """
    ocr_engine = get_ocr_engine()

    if (file_path.suffix.lower() == '.pdf'):
        return ocr_engine.extract_text_from_pdf(file_path, preserve_layout=preserve_layout, **kwargs)

    else:
        return ocr_engine.extract_text_from_image(file_path, preserve_layout=preserve_layout, **kwargs)
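
A minimal usage sketch for the module above (illustrative only, not part of the committed file): it relies solely on the get_ocr_engine() and extract_text_with_ocr() helpers defined here, and assumes the package is importable as document_parser.ocr_engine with at least one OCR backend (PaddleOCR or EasyOCR) installed; the sample file path is hypothetical.

# Illustrative usage - not part of ocr_engine.py
from pathlib import Path
from document_parser.ocr_engine import get_ocr_engine, extract_text_with_ocr

engine = get_ocr_engine()                # singleton, reused across calls
print(engine.get_engine_info())          # primary engine, language, GPU flag, supported languages

# Works for both PDFs and images; layout preservation is on by default
text = extract_text_with_ocr(Path("samples/scanned_invoice.pdf"), preserve_layout=True)  # hypothetical sample file
print(f"{len(text)} characters extracted")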
document_parser/parser_factory.py
ADDED
|
@@ -0,0 +1,534 @@
# DEPENDENCIES
from enum import Enum
from typing import List
from pathlib import Path
from typing import Union
from typing import Optional
from utils.helpers import IDGenerator
from config.models import DocumentType
from config.models import DocumentMetadata
from utils.file_handler import FileHandler
from utils.error_handler import RAGException
from config.logging_config import get_logger
from document_parser.pdf_parser import PDFParser
from document_parser.txt_parser import TXTParser
from document_parser.ocr_engine import OCREngine
from document_parser.docx_parser import DOCXParser
from utils.error_handler import InvalidFileTypeError
from document_parser.zip_handler import ArchiveHandler


# Setup Logging
logger = get_logger(__name__)


class ParserFactory:
    """
    Factory class for creating appropriate document parsers: implements Factory pattern for extensible parser selection
    """
    def __init__(self):
        self.logger = logger

        # Initialize parsers (reusable instances)
        self._parsers = {DocumentType.PDF  : PDFParser(),
                         DocumentType.DOCX : DOCXParser(),
                         DocumentType.TXT  : TXTParser(),
                         }

        # Initialize helper components
        self._ocr_engine = None
        self._archive_handler = None

        # File extension to DocumentType mapping
        self._extension_mapping = {'pdf'  : DocumentType.PDF,
                                   'docx' : DocumentType.DOCX,
                                   'doc'  : DocumentType.DOCX,
                                   'txt'  : DocumentType.TXT,
                                   'text' : DocumentType.TXT,
                                   'md'   : DocumentType.TXT,
                                   'log'  : DocumentType.TXT,
                                   'csv'  : DocumentType.TXT,
                                   'json' : DocumentType.TXT,
                                   'xml'  : DocumentType.TXT,
                                   'png'  : DocumentType.IMAGE,
                                   'jpg'  : DocumentType.IMAGE,
                                   'jpeg' : DocumentType.IMAGE,
                                   'gif'  : DocumentType.IMAGE,
                                   'bmp'  : DocumentType.IMAGE,
                                   'tiff' : DocumentType.IMAGE,
                                   'webp' : DocumentType.IMAGE,
                                   'zip'  : DocumentType.ARCHIVE,
                                   'tar'  : DocumentType.ARCHIVE,
                                   'gz'   : DocumentType.ARCHIVE,
                                   'tgz'  : DocumentType.ARCHIVE,
                                   'rar'  : DocumentType.ARCHIVE,
                                   '7z'   : DocumentType.ARCHIVE,
                                   }
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_parser(self, file_path: Path):
|
| 70 |
+
"""
|
| 71 |
+
Get appropriate parser for file
|
| 72 |
+
|
| 73 |
+
Arguments:
|
| 74 |
+
----------
|
| 75 |
+
file_path { Path } : Path to document
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
--------
|
| 79 |
+
{ object } : Parser instance or handler
|
| 80 |
+
|
| 81 |
+
Raises:
|
| 82 |
+
-------
|
| 83 |
+
InvalidFileTypeError : If file type not supported
|
| 84 |
+
"""
|
| 85 |
+
doc_type = self.detect_document_type(file_path = file_path)
|
| 86 |
+
|
| 87 |
+
# Handle special types (image, archive)
|
| 88 |
+
if (doc_type == DocumentType.IMAGE):
|
| 89 |
+
return self._get_ocr_engine()
|
| 90 |
+
|
| 91 |
+
elif (doc_type == DocumentType.ARCHIVE):
|
| 92 |
+
return self._get_archive_handler()
|
| 93 |
+
|
| 94 |
+
# Handle standard document types
|
| 95 |
+
elif doc_type in self._parsers:
|
| 96 |
+
return self._parsers[doc_type]
|
| 97 |
+
|
| 98 |
+
else:
|
| 99 |
+
raise InvalidFileTypeError(file_type = str(doc_type),
|
| 100 |
+
allowed_types = [t.value for t in self._parsers.keys()] + [DocumentType.IMAGE.value, DocumentType.ARCHIVE.value],
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def detect_document_type(self, file_path: Path) -> Union[DocumentType, str]:
|
| 105 |
+
"""
|
| 106 |
+
Detect document type from file extension and content
|
| 107 |
+
|
| 108 |
+
Arguments:
|
| 109 |
+
----------
|
| 110 |
+
file_path { Path } : Path to document
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
--------
|
| 114 |
+
{ Union } : DocumentType enum or string for special types
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
-------
|
| 118 |
+
InvalidFileTypeError : If type cannot be determined
|
| 119 |
+
"""
|
| 120 |
+
file_path = Path(file_path)
|
| 121 |
+
|
| 122 |
+
# Get extension
|
| 123 |
+
extension = file_path.suffix.lstrip('.').lower()
|
| 124 |
+
|
| 125 |
+
# Check if extension is mapped
|
| 126 |
+
if extension in self._extension_mapping:
|
| 127 |
+
doc_type = self._extension_mapping[extension]
|
| 128 |
+
|
| 129 |
+
self.logger.debug(f"Detected type {doc_type} from extension .{extension}")
|
| 130 |
+
|
| 131 |
+
return doc_type
|
| 132 |
+
|
| 133 |
+
# Try detecting from file content
|
| 134 |
+
detected_type = FileHandler.detect_file_type(file_path)
|
| 135 |
+
|
| 136 |
+
if (detected_type and (detected_type in self._extension_mapping)):
|
| 137 |
+
doc_type = self._extension_mapping[detected_type]
|
| 138 |
+
|
| 139 |
+
self.logger.debug(f"Detected type {doc_type} from content")
|
| 140 |
+
|
| 141 |
+
return doc_type
|
| 142 |
+
|
| 143 |
+
raise InvalidFileTypeError(file_type = extension, allowed_types = list(self._extension_mapping.keys()))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def parse(self, file_path: Union[str, Path], extract_metadata: bool = True, clean_text: bool = True, **kwargs) -> tuple[str, Optional[DocumentMetadata]]:
|
| 147 |
+
"""
|
| 148 |
+
Parse document using appropriate parser
|
| 149 |
+
|
| 150 |
+
Arguments:
|
| 151 |
+
----------
|
| 152 |
+
file_path { Path } : Path to document
|
| 153 |
+
|
| 154 |
+
extract_metadata { bool } : Extract document metadata
|
| 155 |
+
|
| 156 |
+
clean_text { bool } : Clean extracted text
|
| 157 |
+
|
| 158 |
+
**kwargs : Additional parser-specific arguments
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
--------
|
| 162 |
+
{ tuple } : Tuple of (extracted_text, metadata)
|
| 163 |
+
|
| 164 |
+
Raises:
|
| 165 |
+
-------
|
| 166 |
+
InvalidFileTypeError : If file type not supported
|
| 167 |
+
|
| 168 |
+
RAGException : If parsing fails
|
| 169 |
+
"""
|
| 170 |
+
file_path = Path(file_path)
|
| 171 |
+
|
| 172 |
+
self.logger.info(f"Parsing document: {file_path}")
|
| 173 |
+
|
| 174 |
+
# Get appropriate parser/handler
|
| 175 |
+
parser = self.get_parser(file_path)
|
| 176 |
+
|
| 177 |
+
# Handle different parser types
|
| 178 |
+
if isinstance(parser, (PDFParser, DOCXParser, TXTParser)):
|
| 179 |
+
# Standard document parser
|
| 180 |
+
text, metadata = parser.parse(file_path,
|
| 181 |
+
extract_metadata = extract_metadata,
|
| 182 |
+
clean_text = clean_text,
|
| 183 |
+
**kwargs
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
elif isinstance(parser, OCREngine):
|
| 187 |
+
# Image file - use OCR
|
| 188 |
+
text = parser.extract_text_from_image(file_path)
|
| 189 |
+
metadata = self._create_image_metadata(file_path) if extract_metadata else None
|
| 190 |
+
|
| 191 |
+
elif isinstance(parser, ArchiveHandler):
|
| 192 |
+
# Archive file - extract and parse contents
|
| 193 |
+
return self._parse_archive(file_path = file_path,
|
| 194 |
+
extract_metadata = extract_metadata,
|
| 195 |
+
clean_text = clean_text,
|
| 196 |
+
**kwargs
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
else:
|
| 200 |
+
raise InvalidFileTypeError(file_type = file_path.suffix, allowed_types = self.get_supported_extensions())
|
| 201 |
+
|
| 202 |
+
self.logger.info(f"Successfully parsed {file_path.name}: {len(text)} chars, type={metadata.document_type if metadata else 'unknown'}")
|
| 203 |
+
|
| 204 |
+
return text, metadata
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _get_ocr_engine(self) -> OCREngine:
|
| 208 |
+
"""
|
| 209 |
+
Get OCR engine instance (lazy initialization)
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
--------
|
| 213 |
+
{ OCREngine } : OCR engine instance
|
| 214 |
+
"""
|
| 215 |
+
if self._ocr_engine is None:
|
| 216 |
+
self._ocr_engine = OCREngine()
|
| 217 |
+
|
| 218 |
+
return self._ocr_engine
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _get_archive_handler(self) -> ArchiveHandler:
|
| 222 |
+
"""
|
| 223 |
+
Get archive handler instance (lazy initialization)
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
--------
|
| 227 |
+
{ ArchiveHandler } : Archive handler instance
|
| 228 |
+
"""
|
| 229 |
+
if self._archive_handler is None:
|
| 230 |
+
self._archive_handler = ArchiveHandler()
|
| 231 |
+
|
| 232 |
+
return self._archive_handler
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _create_image_metadata(self, file_path: Path) -> DocumentMetadata:
|
| 236 |
+
"""
|
| 237 |
+
Create metadata for image file
|
| 238 |
+
|
| 239 |
+
Arguments:
|
| 240 |
+
----------
|
| 241 |
+
file_path { Path } : Path to image file
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
--------
|
| 245 |
+
{ DocumentMetadata } : DocumentMetadata object
|
| 246 |
+
"""
|
| 247 |
+
stat = file_path.stat()
|
| 248 |
+
|
| 249 |
+
return DocumentMetadata(document_id = IDGenerator.generate_document_id(),
|
| 250 |
+
filename = file_path.name,
|
| 251 |
+
file_path = file_path,
|
| 252 |
+
document_type = DocumentType.IMAGE,
|
| 253 |
+
file_size_bytes = stat.st_size,
|
| 254 |
+
created_date = stat.st_ctime,
|
| 255 |
+
modified_date = stat.st_mtime,
|
| 256 |
+
extra = {"file_type": "image"},
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _parse_archive(self, file_path: Path, extract_metadata: bool = True, clean_text: bool = True, **kwargs) -> tuple[str, Optional[DocumentMetadata]]:
|
| 261 |
+
"""
|
| 262 |
+
Parse archive file: extract contents and parse all supported files
|
| 263 |
+
|
| 264 |
+
Arguments:
|
| 265 |
+
----------
|
| 266 |
+
file_path { Path } : Path to archive file
|
| 267 |
+
|
| 268 |
+
extract_metadata { bool } : Extract document metadata
|
| 269 |
+
|
| 270 |
+
clean_text { bool } : Clean extracted text
|
| 271 |
+
|
| 272 |
+
**kwargs : Additional arguments
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
--------
|
| 276 |
+
{ tuple } : Tuple of (combined_text, metadata)
|
| 277 |
+
"""
|
| 278 |
+
archive_handler = self._get_archive_handler()
|
| 279 |
+
|
| 280 |
+
# Extract archive contents
|
| 281 |
+
extracted_files = archive_handler.extract_archive(file_path)
|
| 282 |
+
|
| 283 |
+
# Parse all extracted files
|
| 284 |
+
combined_text = ""
|
| 285 |
+
all_metadata = list()
|
| 286 |
+
|
| 287 |
+
for extracted_file in extracted_files:
|
| 288 |
+
if self.is_supported(extracted_file):
|
| 289 |
+
try:
|
| 290 |
+
file_text, file_metadata = self.parse(extracted_file,
|
| 291 |
+
extract_metadata = extract_metadata,
|
| 292 |
+
clean_text = clean_text,
|
| 293 |
+
**kwargs
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
combined_text += f"\n\n[FILE: {extracted_file.name}]\n{file_text}"
|
| 297 |
+
|
| 298 |
+
if file_metadata:
|
| 299 |
+
all_metadata.append(file_metadata)
|
| 300 |
+
|
| 301 |
+
except Exception as e:
|
| 302 |
+
self.logger.warning(f"Failed to parse extracted file {extracted_file}: {repr(e)}")
|
| 303 |
+
continue
|
| 304 |
+
|
| 305 |
+
# Create combined metadata
|
| 306 |
+
combined_metadata = None
|
| 307 |
+
|
| 308 |
+
if extract_metadata and all_metadata:
|
| 309 |
+
combined_metadata = DocumentMetadata(document_id = IDGenerator.generate_document_id(),
|
| 310 |
+
filename = file_path.name,
|
| 311 |
+
file_path = file_path,
|
| 312 |
+
document_type = DocumentType.ARCHIVE,
|
| 313 |
+
file_size_bytes = file_path.stat().st_size,
|
| 314 |
+
extra = {"archive_contents" : len(extracted_files),
|
| 315 |
+
"parsed_files" : len(all_metadata),
|
| 316 |
+
"contained_documents" : [meta.document_id for meta in all_metadata],
|
| 317 |
+
}
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
return combined_text.strip(), combined_metadata
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def is_supported(self, file_path: Path) -> bool:
|
| 324 |
+
"""
|
| 325 |
+
Check if file type is supported.
|
| 326 |
+
|
| 327 |
+
Arguments:
|
| 328 |
+
----------
|
| 329 |
+
file_path { Path } : Path to document
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
--------
|
| 333 |
+
{ bool } : True if supported
|
| 334 |
+
"""
|
| 335 |
+
try:
|
| 336 |
+
self.detect_document_type(file_path = file_path)
|
| 337 |
+
return True
|
| 338 |
+
|
| 339 |
+
except InvalidFileTypeError:
|
| 340 |
+
return False
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def get_supported_extensions(self) -> list[str]:
|
| 344 |
+
"""
|
| 345 |
+
Get list of supported file extensions.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
--------
|
| 349 |
+
{ list } : List of extensions (without dot)
|
| 350 |
+
"""
|
| 351 |
+
return list(self._extension_mapping.keys())
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def register_parser(self, doc_type: DocumentType, parser_instance, extensions: Optional[list[str]] = None):
|
| 355 |
+
"""
|
| 356 |
+
Register a new parser type (for extensibility)
|
| 357 |
+
|
| 358 |
+
Arguments:
|
| 359 |
+
----------
|
| 360 |
+
doc_type { DocumentType } : Document type enum
|
| 361 |
+
|
| 362 |
+
parser_instance : Parser instance
|
| 363 |
+
|
| 364 |
+
extensions { list } : File extensions to map to this parser
|
| 365 |
+
"""
|
| 366 |
+
self._parsers[doc_type] = parser_instance
|
| 367 |
+
|
| 368 |
+
if extensions:
|
| 369 |
+
for ext in extensions:
|
| 370 |
+
self._extension_mapping[ext.lstrip('.')] = doc_type
|
| 371 |
+
|
| 372 |
+
self.logger.info(f"Registered parser for {doc_type}")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def batch_parse(self, file_paths: list[Path], extract_metadata: bool = True, clean_text: bool = True, skip_errors: bool = True) -> list[tuple[Path, str, Optional[DocumentMetadata]]]:
|
| 376 |
+
"""
|
| 377 |
+
Parse multiple documents.
|
| 378 |
+
|
| 379 |
+
Arguments:
|
| 380 |
+
----------
|
| 381 |
+
file_paths { list } : List of file paths
|
| 382 |
+
|
| 383 |
+
extract_metadata { bool } : Extract metadata
|
| 384 |
+
|
| 385 |
+
clean_text { str } : Clean text
|
| 386 |
+
|
| 387 |
+
skip_errors { bool } : Skip files that fail to parse
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
--------
|
| 391 |
+
{ list } : List of (file_path, text, metadata) tuples
|
| 392 |
+
"""
|
| 393 |
+
results = list()
|
| 394 |
+
|
| 395 |
+
for file_path in file_paths:
|
| 396 |
+
try:
|
| 397 |
+
text, metadata = self.parse(file_path,
|
| 398 |
+
extract_metadata = extract_metadata,
|
| 399 |
+
clean_text = clean_text,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
results.append((file_path, text, metadata))
|
| 403 |
+
|
| 404 |
+
except Exception as e:
|
| 405 |
+
self.logger.error(f"Failed to parse {file_path}: {repr(e)}")
|
| 406 |
+
|
| 407 |
+
if not skip_errors:
|
| 408 |
+
raise
|
| 409 |
+
# Add placeholder for failed file
|
| 410 |
+
results.append((file_path, "", None))
|
| 411 |
+
|
| 412 |
+
self.logger.info(f"Batch parsed {len(results)}/{len(file_paths)} files successfully")
|
| 413 |
+
|
| 414 |
+
return results
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def parse_directory(self, directory: Path, recursive: bool = False, pattern: str = "*", **kwargs) -> list[tuple[Path, str, Optional[DocumentMetadata]]]:
|
| 418 |
+
"""
|
| 419 |
+
Parse all supported documents in a directory
|
| 420 |
+
|
| 421 |
+
Arguments:
|
| 422 |
+
----------
|
| 423 |
+
directory { Path } : Directory path
|
| 424 |
+
|
| 425 |
+
recursive { bool } : Search recursively
|
| 426 |
+
|
| 427 |
+
pattern { str } : File pattern (glob)
|
| 428 |
+
|
| 429 |
+
**kwargs : Additional parse arguments
|
| 430 |
+
|
| 431 |
+
Returns:
|
| 432 |
+
--------
|
| 433 |
+
{ list } : List of (file_path, text, metadata) tuples
|
| 434 |
+
"""
|
| 435 |
+
directory = Path(directory)
|
| 436 |
+
|
| 437 |
+
# Get all files
|
| 438 |
+
all_files = FileHandler.list_files(directory, pattern=pattern, recursive=recursive)
|
| 439 |
+
|
| 440 |
+
# Filter to supported types
|
| 441 |
+
supported_files = [f for f in all_files if self.is_supported(f)]
|
| 442 |
+
|
| 443 |
+
self.logger.info(f"Found {len(supported_files)} supported files in {directory} ({len(all_files) - len(supported_files)} unsupported)")
|
| 444 |
+
|
| 445 |
+
# Parse all files
|
| 446 |
+
return self.batch_parse(supported_files, **kwargs)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def get_parser_info(self) -> dict:
|
| 450 |
+
"""
|
| 451 |
+
Get information about registered parsers
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
--------
|
| 455 |
+
{ dict } : Dictionary with parser information
|
| 456 |
+
"""
|
| 457 |
+
info = {"supported_types" : [t.value for t in self._parsers.keys()] + ['image', 'archive'],
|
| 458 |
+
"supported_extensions" : self.get_supported_extensions(),
|
| 459 |
+
"parser_classes" : {t.value: type(p).__name__ for t, p in self._parsers.items()},
|
| 460 |
+
"special_handlers" : {"image" : "OCREngine",
|
| 461 |
+
"archive" : "ArchiveHandler"},
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
return info
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# Global factory instance
|
| 468 |
+
_factory = None
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def get_parser_factory() -> ParserFactory:
|
| 472 |
+
"""
|
| 473 |
+
Get global parser factory instance (singleton)
|
| 474 |
+
|
| 475 |
+
Returns:
|
| 476 |
+
--------
|
| 477 |
+
{ ParserFactory } : ParserFactory instance
|
| 478 |
+
"""
|
| 479 |
+
global _factory
|
| 480 |
+
|
| 481 |
+
if _factory is None:
|
| 482 |
+
_factory = ParserFactory()
|
| 483 |
+
|
| 484 |
+
return _factory
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# Convenience functions
|
| 488 |
+
def parse_document(file_path: Union[str, Path], **kwargs) -> tuple[str, Optional[DocumentMetadata]]:
|
| 489 |
+
"""
|
| 490 |
+
Convenience function to parse a document
|
| 491 |
+
|
| 492 |
+
Arguments:
|
| 493 |
+
----------
|
| 494 |
+
file_path { Path } : Path to document
|
| 495 |
+
|
| 496 |
+
**kwargs : Additional arguments
|
| 497 |
+
|
| 498 |
+
Returns:
|
| 499 |
+
--------
|
| 500 |
+
{ tuple } : Tuple of (text, metadata)
|
| 501 |
+
"""
|
| 502 |
+
factory = get_parser_factory()
|
| 503 |
+
|
| 504 |
+
return factory.parse(file_path, **kwargs)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def is_supported_file(file_path: Union[str, Path]) -> bool:
|
| 508 |
+
"""
|
| 509 |
+
Check if file is supported
|
| 510 |
+
|
| 511 |
+
Arguments:
|
| 512 |
+
----------
|
| 513 |
+
file_path { Path } : Path to file
|
| 514 |
+
|
| 515 |
+
Returns:
|
| 516 |
+
--------
|
| 517 |
+
{ bool } : True if supported
|
| 518 |
+
"""
|
| 519 |
+
factory = get_parser_factory()
|
| 520 |
+
|
| 521 |
+
return factory.is_supported(Path(file_path))
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def get_supported_extensions() -> list[str]:
|
| 525 |
+
"""
|
| 526 |
+
Get list of supported extensions.
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
--------
|
| 530 |
+
{ list } : List of extensions
|
| 531 |
+
"""
|
| 532 |
+
factory = get_parser_factory()
|
| 533 |
+
|
| 534 |
+
return factory.get_supported_extensions()
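
A quick usage sketch for the factory above (illustrative only, assuming the repo's utils/ and config/ packages are importable as laid out in this commit; the sample filename is hypothetical). It uses only the module-level convenience functions defined in parser_factory.py.

# Illustrative usage - not part of parser_factory.py
from pathlib import Path
from document_parser.parser_factory import parse_document, is_supported_file, get_supported_extensions

print(get_supported_extensions())              # e.g. ['pdf', 'docx', 'doc', 'txt', ...]

path = Path("contracts/lease_agreement.docx")  # hypothetical input file
if is_supported_file(path):
    text, metadata = parse_document(path, extract_metadata=True, clean_text=True)
    print(metadata.document_type, len(text))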
|
document_parser/pdf_parser.py
ADDED
|
@@ -0,0 +1,908 @@
# DEPENDENCIES
import hashlib
from typing import Any
from typing import List
from typing import Dict
from typing import Tuple
from pathlib import Path
from typing import Optional
from datetime import datetime
from config.models import DocumentType
from utils.text_cleaner import TextCleaner
from config.models import DocumentMetadata
from config.logging_config import get_logger
from utils.error_handler import PDFParseError
from utils.error_handler import handle_errors
from document_parser.ocr_engine import OCREngine

try:
    import fitz
    PYMUPDF_AVAILABLE = True

except ImportError:
    PYMUPDF_AVAILABLE = False

try:
    import PyPDF2
    from PyPDF2 import PdfReader
    PYPdf2_AVAILABLE = True

except ImportError:
    PYPdf2_AVAILABLE = False


# Setup Logging
logger = get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PDFParser:
|
| 39 |
+
"""
|
| 40 |
+
Comprehensive PDF parsing with metadata extraction: Uses PyMuPDF (fitz) as primary parser with PyPDF2 fallback
|
| 41 |
+
|
| 42 |
+
Handles various PDF formats including encrypted and scanned documents
|
| 43 |
+
"""
|
| 44 |
+
def __init__(self, prefer_pymupdf: bool = True):
|
| 45 |
+
"""
|
| 46 |
+
Initialize PDF parser.
|
| 47 |
+
|
| 48 |
+
Arguments:
|
| 49 |
+
----------
|
| 50 |
+
prefer_pymupdf { bool } : Use PyMuPDF as primary parser if available
|
| 51 |
+
"""
|
| 52 |
+
self.logger = logger
|
| 53 |
+
self.prefer_pymupdf = prefer_pymupdf and PYMUPDF_AVAILABLE
|
| 54 |
+
self.ocr_engine = None
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
from document_parser.ocr_engine import OCREngine
|
| 58 |
+
self.ocr_available = True
|
| 59 |
+
|
| 60 |
+
except ImportError:
|
| 61 |
+
self.ocr_available = False
|
| 62 |
+
self.logger.warning("OCR engine not available - scanned PDFs may not be processed")
|
| 63 |
+
|
| 64 |
+
if (not PYMUPDF_AVAILABLE and not PYPdf2_AVAILABLE):
|
| 65 |
+
raise ImportError("Neither PyMuPDF nor PyPDF2 are available. Please install at least one.")
|
| 66 |
+
|
| 67 |
+
self.logger.info(f"PDF Parser initialized - Primary: {'PyMuPDF' if self.prefer_pymupdf else 'PyPDF2'}, PyMuPDF available: {PYMUPDF_AVAILABLE}, PyPDF2 available: {PYPdf2_AVAILABLE}")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@handle_errors(error_type=PDFParseError, log_error = True, reraise = True)
|
| 71 |
+
def parse(self, file_path: Path, extract_metadata: bool = True, clean_text: bool = True, password: Optional[str] = None) -> tuple[str, Optional[DocumentMetadata]]:
|
| 72 |
+
"""
|
| 73 |
+
Parse PDF and extract text and metadata : tries PyMuPDF first, falls back to PyPDF2 if needed
|
| 74 |
+
|
| 75 |
+
Arguments:
|
| 76 |
+
----------
|
| 77 |
+
file_path { Path } : Path to PDF file
|
| 78 |
+
|
| 79 |
+
extract_metadata { bool } : Extract document metadata
|
| 80 |
+
|
| 81 |
+
clean_text { bool } : Clean extracted text
|
| 82 |
+
|
| 83 |
+
password { str } : Password for encrypted PDFs
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
--------
|
| 87 |
+
{ tuple } : Tuple of (extracted_text, metadata)
|
| 88 |
+
|
| 89 |
+
Raises:
|
| 90 |
+
-------
|
| 91 |
+
PDFParseError : If parsing fails
|
| 92 |
+
"""
|
| 93 |
+
file_path = Path(file_path)
|
| 94 |
+
|
| 95 |
+
if not file_path.exists():
|
| 96 |
+
raise PDFParseError(str(file_path), original_error = FileNotFoundError(f"PDF file not found: {file_path}"))
|
| 97 |
+
|
| 98 |
+
self.logger.info(f"Parsing PDF: {file_path}")
|
| 99 |
+
|
| 100 |
+
# Try PyMuPDF first if preferred and available
|
| 101 |
+
if (self.prefer_pymupdf and PYMUPDF_AVAILABLE):
|
| 102 |
+
try:
|
| 103 |
+
parsed_text = self._parse_with_pymupdf(file_path = file_path,
|
| 104 |
+
extract_metadata = extract_metadata,
|
| 105 |
+
clean_text = clean_text,
|
| 106 |
+
password = password,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return parsed_text
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
self.logger.warning(f"PyMuPDF parsing failed for {file_path}, falling back to PyPDF2: {repr(e)}")
|
| 113 |
+
|
| 114 |
+
# Fall back to PyPDF2
|
| 115 |
+
if PYPdf2_AVAILABLE:
|
| 116 |
+
try:
|
| 117 |
+
parsed_text = self._parse_with_pypdf2(file_path = file_path,
|
| 118 |
+
extract_metadata = extract_metadata,
|
| 119 |
+
clean_text = clean_text,
|
| 120 |
+
password = password,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
return parsed_text
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
self.logger.error(f"PyPDF2 parsing also failed for {file_path}: {repr(e)}")
|
| 127 |
+
raise PDFParseError(str(file_path), original_error = e)
|
| 128 |
+
|
| 129 |
+
else:
|
| 130 |
+
raise PDFParseError(str(file_path), original_error = RuntimeError("No PDF parsing libraries available"))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _parse_with_pymupdf(self, file_path: Path, extract_metadata: bool = True, clean_text: bool = True, password: Optional[str] = None) -> tuple[str, Optional[DocumentMetadata]]:
|
| 134 |
+
"""
|
| 135 |
+
Parse PDF using PyMuPDF (fitz) with OCR fallback for scanned documents
|
| 136 |
+
"""
|
| 137 |
+
self.logger.debug(f"Using PyMuPDF for parsing: {file_path}")
|
| 138 |
+
|
| 139 |
+
doc = None
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
# Open PDF with PyMuPDF
|
| 143 |
+
self.logger.debug(f"Opening document: {file_path}")
|
| 144 |
+
doc = fitz.open(str(file_path))
|
| 145 |
+
|
| 146 |
+
self.logger.debug(f"Document opened successfully, {len(doc)} pages")
|
| 147 |
+
|
| 148 |
+
# Handle encrypted PDFs
|
| 149 |
+
if (doc.needs_pass and password):
|
| 150 |
+
if not doc.authenticate(password):
|
| 151 |
+
raise PDFParseError(str(file_path), original_error = ValueError("Invalid password for encrypted PDF"))
|
| 152 |
+
|
| 153 |
+
elif (doc.needs_pass and not password):
|
| 154 |
+
raise PDFParseError(str(file_path), original_error = ValueError("PDF is encrypted but no password provided"))
|
| 155 |
+
|
| 156 |
+
# Extract text with per-page OCR fallback
|
| 157 |
+
text_content = self._extract_text_with_pymupdf(doc = doc,
|
| 158 |
+
file_path = file_path,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Extract metadata
|
| 162 |
+
metadata = None
|
| 163 |
+
if extract_metadata:
|
| 164 |
+
metadata = self._extract_metadata_with_pymupdf(doc = doc,
|
| 165 |
+
file_path = file_path,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Clean text
|
| 169 |
+
if clean_text:
|
| 170 |
+
text_content = TextCleaner.clean(text_content,
|
| 171 |
+
remove_html = True,
|
| 172 |
+
normalize_whitespace = True,
|
| 173 |
+
preserve_structure = True,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
self.logger.info(f"Successfully parsed PDF with PyMuPDF: {len(text_content)} characters, {len(doc)} pages")
|
| 177 |
+
return text_content, metadata
|
| 178 |
+
|
| 179 |
+
except Exception as e:
|
| 180 |
+
self.logger.error(f"PyMuPDF parsing failed for {file_path}: {repr(e)}")
|
| 181 |
+
raise
|
| 182 |
+
|
| 183 |
+
finally:
|
| 184 |
+
# Always close the document in finally block
|
| 185 |
+
if doc:
|
| 186 |
+
self.logger.debug("Closing PyMuPDF document")
|
| 187 |
+
doc.close()
|
| 188 |
+
|
| 189 |
+
def _parse_with_pypdf2(self, file_path: Path, extract_metadata: bool = True, clean_text: bool = True, password: Optional[str] = None) -> tuple[str, Optional[DocumentMetadata]]:
|
| 190 |
+
"""
|
| 191 |
+
Parse PDF using PyPDF2
|
| 192 |
+
|
| 193 |
+
Arguments:
|
| 194 |
+
----------
|
| 195 |
+
file_path { Path } : Path to PDF file
|
| 196 |
+
|
| 197 |
+
extract_metadata { bool } : Extract document metadata
|
| 198 |
+
|
| 199 |
+
clean_text { bool } : Clean extracted text
|
| 200 |
+
|
| 201 |
+
password { str } : Password for encrypted PDFs
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
--------
|
| 205 |
+
{ tuple } : Tuple of (extracted_text, metadata)
|
| 206 |
+
"""
|
| 207 |
+
self.logger.debug(f"Using PyPDF2 for parsing: {file_path}")
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
# Open PDF with PyPDF2
|
| 211 |
+
with open(file_path, 'rb') as pdf_file:
|
| 212 |
+
reader = PdfReader(pdf_file)
|
| 213 |
+
|
| 214 |
+
# Handle encrypted PDFs
|
| 215 |
+
if reader.is_encrypted:
|
| 216 |
+
if password:
|
| 217 |
+
reader.decrypt(password)
|
| 218 |
+
self.logger.info("Successfully decrypted PDF with PyPDF2")
|
| 219 |
+
|
| 220 |
+
else:
|
| 221 |
+
raise PDFParseError(str(file_path), original_error = ValueError("PDF is encrypted but no password provided"))
|
| 222 |
+
|
| 223 |
+
# Extract text from all pages
|
| 224 |
+
text_content = self._extract_text_with_pypdf2(reader = reader)
|
| 225 |
+
|
| 226 |
+
# Extract metadata
|
| 227 |
+
metadata = None
|
| 228 |
+
|
| 229 |
+
if extract_metadata:
|
| 230 |
+
metadata = self._extract_metadata_with_pypdf2(reader = reader,
|
| 231 |
+
file_path = file_path,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Clean text
|
| 235 |
+
if clean_text:
|
| 236 |
+
text_content = TextCleaner.clean(text_content,
|
| 237 |
+
remove_html = True,
|
| 238 |
+
normalize_whitespace = True,
|
| 239 |
+
preserve_structure = True,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
self.logger.info(f"Successfully parsed PDF with PyPDF2: {len(text_content)} characters, {len(reader.pages)} pages")
|
| 243 |
+
|
| 244 |
+
return text_content, metadata
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
self.logger.error(f"PyPDF2 parsing failed for {file_path}: {repr(e)}")
|
| 248 |
+
raise
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _extract_text_with_pymupdf(self, doc: "fitz.Document", file_path: Path = None) -> str:
|
| 252 |
+
"""
|
| 253 |
+
Extract text from all pages using PyMuPDF with per-page OCR fallback.
|
| 254 |
+
|
| 255 |
+
Arguments:
|
| 256 |
+
----------
|
| 257 |
+
doc : PyMuPDF document object
|
| 258 |
+
|
| 259 |
+
file_path : Path to PDF file (for OCR fallback)
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
--------
|
| 263 |
+
{ str } : Combined text from all pages
|
| 264 |
+
"""
|
| 265 |
+
text_parts = list()
|
| 266 |
+
|
| 267 |
+
for page_num in range(len(doc)):
|
| 268 |
+
try:
|
| 269 |
+
page = doc[page_num]
|
| 270 |
+
page_text = page.get_text()
|
| 271 |
+
|
| 272 |
+
if page_text and page_text.strip():
|
| 273 |
+
# Add page marker for citation purposes
|
| 274 |
+
text_parts.append(f"\n[PAGE {page_num + 1}]\n{page_text}")
|
| 275 |
+
self.logger.debug(f"Extracted {len(page_text)} chars from page {page_num + 1} with PyMuPDF")
|
| 276 |
+
|
| 277 |
+
else:
|
| 278 |
+
# No text extracted - this page might be scanned
|
| 279 |
+
self.logger.warning(f"No text extracted from page {page_num + 1} with PyMuPDF (might be scanned)")
|
| 280 |
+
|
| 281 |
+
# Try OCR for this specific page if available
|
| 282 |
+
if self.ocr_available and file_path:
|
| 283 |
+
try:
|
| 284 |
+
self.logger.info(f"Attempting OCR for page {page_num + 1}")
|
| 285 |
+
ocr_text = self._extract_page_text_with_ocr(file_path, page_num + 1)
|
| 286 |
+
|
| 287 |
+
if ocr_text and ocr_text.strip():
|
| 288 |
+
text_parts.append(f"\n[PAGE {page_num + 1} - OCR]\n{ocr_text}")
|
| 289 |
+
self.logger.info(f"OCR extracted {len(ocr_text)} chars from page {page_num + 1}")
|
| 290 |
+
|
| 291 |
+
else:
|
| 292 |
+
text_parts.append(f"\n[PAGE {page_num + 1} - NO TEXT]\n")
|
| 293 |
+
self.logger.warning(f"OCR also failed to extract text from page {page_num + 1}")
|
| 294 |
+
|
| 295 |
+
except Exception as ocr_error:
|
| 296 |
+
self.logger.warning(f"OCR failed for page {page_num + 1}: {repr(ocr_error)}")
|
| 297 |
+
text_parts.append(f"\n[PAGE {page_num + 1} - OCR FAILED]\n")
|
| 298 |
+
|
| 299 |
+
else:
|
| 300 |
+
# No OCR available or no file_path provided
|
| 301 |
+
text_parts.append(f"\n[PAGE {page_num + 1} - NO TEXT]\n")
|
| 302 |
+
|
| 303 |
+
except Exception as e:
|
| 304 |
+
self.logger.warning(f"Error extracting text from page {page_num + 1} with PyMuPDF: {repr(e)}")
|
| 305 |
+
text_parts.append(f"\n[PAGE {page_num + 1} - ERROR: {str(e)}]\n")
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
return "\n".join(text_parts)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _extract_text_with_pypdf2(self, reader: PdfReader) -> str:
|
| 312 |
+
"""
|
| 313 |
+
Extract text from all pages using PyPDF2
|
| 314 |
+
|
| 315 |
+
Arguments:
|
| 316 |
+
----------
|
| 317 |
+
reader { PdfReader } : PdfReader object
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
--------
|
| 321 |
+
{ str } : Combined text from all pages
|
| 322 |
+
"""
|
| 323 |
+
text_parts = list()
|
| 324 |
+
num_pages = len(reader.pages)
|
| 325 |
+
|
| 326 |
+
for page_num in range(num_pages):
|
| 327 |
+
try:
|
| 328 |
+
page = reader.pages[page_num]
|
| 329 |
+
page_text = page.extract_text()
|
| 330 |
+
|
| 331 |
+
if page_text and page_text.strip():
|
| 332 |
+
# Add page marker for citation purposes
|
| 333 |
+
text_parts.append(f"\n[PAGE {page_num + 1}]\n{page_text}")
|
| 334 |
+
self.logger.debug(f"Extracted {len(page_text)} chars from page {page_num + 1} with PyPDF2")
|
| 335 |
+
|
| 336 |
+
else:
|
| 337 |
+
self.logger.warning(f"No text extracted from page {page_num + 1} with PyPDF2 (might be scanned)")
|
| 338 |
+
|
| 339 |
+
except Exception as e:
|
| 340 |
+
self.logger.warning(f"Error extracting text from page {page_num + 1} with PyPDF2: {repr(e)}")
|
| 341 |
+
continue
|
| 342 |
+
|
| 343 |
+
return "\n".join(text_parts)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _extract_metadata_with_pymupdf(self, doc: "fitz.Document", file_path: Path) -> DocumentMetadata:
|
| 347 |
+
"""
|
| 348 |
+
Extract metadata using PyMuPDF
|
| 349 |
+
|
| 350 |
+
Arguments:
|
| 351 |
+
-----------
|
| 352 |
+
doc { fitz.Document } : PyMuPDF document object
|
| 353 |
+
|
| 354 |
+
file_path { Path } : Path to PDF file
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
--------
|
| 358 |
+
{ DocumentMetadata } : DocumentMetadata object
|
| 359 |
+
"""
|
| 360 |
+
# Get PDF metadata
|
| 361 |
+
pdf_metadata = doc.metadata
|
| 362 |
+
|
| 363 |
+
# Extract common fields
|
| 364 |
+
title = pdf_metadata.get('title', '').strip()
|
| 365 |
+
author = pdf_metadata.get('author', '').strip()
|
| 366 |
+
|
| 367 |
+
# Parse dates
|
| 368 |
+
created_date = self._parse_pdf_date(pdf_metadata.get('creationDate'))
|
| 369 |
+
modified_date = self._parse_pdf_date(pdf_metadata.get('modDate'))
|
| 370 |
+
|
| 371 |
+
# Get file size
|
| 372 |
+
file_size = file_path.stat().st_size
|
| 373 |
+
|
| 374 |
+
# Count pages
|
| 375 |
+
num_pages = len(doc)
|
| 376 |
+
|
| 377 |
+
# Generate document ID
|
| 378 |
+
doc_hash = hashlib.md5(str(file_path).encode()).hexdigest()
|
| 379 |
+
doc_id = f"doc_{int(datetime.now().timestamp())}_{doc_hash}"
|
| 380 |
+
|
| 381 |
+
# Create metadata object
|
| 382 |
+
metadata = DocumentMetadata(document_id = doc_id,
|
| 383 |
+
filename = file_path.name,
|
| 384 |
+
file_path = file_path,
|
| 385 |
+
document_type = DocumentType.PDF,
|
| 386 |
+
title = title or file_path.stem,
|
| 387 |
+
author = author,
|
| 388 |
+
created_date = created_date,
|
| 389 |
+
modified_date = modified_date,
|
| 390 |
+
file_size_bytes = file_size,
|
| 391 |
+
num_pages = num_pages,
|
| 392 |
+
extra = {"pdf_version" : pdf_metadata.get('producer', ''),
|
| 393 |
+
"pdf_metadata" : {k: str(v) for k, v in pdf_metadata.items() if v},
|
| 394 |
+
"parser_used" : "pymupdf"
|
| 395 |
+
}
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
return metadata
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _extract_metadata_with_pypdf2(self, reader: PdfReader, file_path: Path) -> DocumentMetadata:
|
| 402 |
+
"""
|
| 403 |
+
Extract metadata using PyPDF2
|
| 404 |
+
|
| 405 |
+
Arguments:
|
| 406 |
+
----------
|
| 407 |
+
reader { PdfReader } : PdfReader object
|
| 408 |
+
|
| 409 |
+
file_path { Path } : Path to PDF file
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
--------
|
| 413 |
+
{ DocumentMetadata } : DocumentMetadata object
|
| 414 |
+
"""
|
| 415 |
+
# Get PDF metadata
|
| 416 |
+
pdf_info = reader.metadata if reader.metadata else {}
|
| 417 |
+
|
| 418 |
+
# Extract common fields
|
| 419 |
+
title = self._get_metadata_field(pdf_info, ['/Title', 'title'])
|
| 420 |
+
author = self._get_metadata_field(pdf_info, ['/Author', 'author'])
|
| 421 |
+
|
| 422 |
+
# Parse dates
|
| 423 |
+
created_date = self._parse_pdf_date(self._get_metadata_field(pdf_info, ['/CreationDate', 'creation_date']))
|
| 424 |
+
modified_date = self._parse_pdf_date(self._get_metadata_field(pdf_info, ['/ModDate', 'mod_date']))
|
| 425 |
+
|
| 426 |
+
# Get file size
|
| 427 |
+
file_size = file_path.stat().st_size
|
| 428 |
+
|
| 429 |
+
# Count pages
|
| 430 |
+
num_pages = len(reader.pages)
|
| 431 |
+
|
| 432 |
+
# Generate document ID
|
| 433 |
+
doc_hash = hashlib.md5(str(file_path).encode()).hexdigest()
|
| 434 |
+
doc_id = f"doc_{int(datetime.now().timestamp())}_{doc_hash}"
|
| 435 |
+
|
| 436 |
+
# Create metadata object
|
| 437 |
+
metadata = DocumentMetadata(document_id = doc_id,
|
| 438 |
+
filename = file_path.name,
|
| 439 |
+
file_path = file_path,
|
| 440 |
+
document_type = DocumentType.PDF,
|
| 441 |
+
title = title or file_path.stem,
|
| 442 |
+
author = author,
|
| 443 |
+
created_date = created_date,
|
| 444 |
+
modified_date = modified_date,
|
| 445 |
+
file_size_bytes = file_size,
|
| 446 |
+
num_pages = num_pages,
|
| 447 |
+
extra = {"pdf_version" : self._get_metadata_field(pdf_info, ['/Producer', 'producer']),
|
| 448 |
+
"pdf_metadata" : {k: str(v) for k, v in pdf_info.items() if v},
|
| 449 |
+
"parser_used" : "pypdf2",
|
| 450 |
+
}
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
return metadata
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _extract_text_with_ocr(self, file_path: Path) -> str:
|
| 457 |
+
"""
|
| 458 |
+
Extract text from scanned PDF using OCR
|
| 459 |
+
"""
|
| 460 |
+
if not self.ocr_available:
|
| 461 |
+
raise PDFParseError(str(file_path), original_error = RuntimeError("OCR engine not available"))
|
| 462 |
+
|
| 463 |
+
if self.ocr_engine is None:
|
| 464 |
+
self.ocr_engine = OCREngine()
|
| 465 |
+
|
| 466 |
+
return self.ocr_engine.extract_text_from_pdf(file_path)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def _extract_page_text_with_ocr(self, file_path: Path, page_number: int) -> str:
|
| 470 |
+
"""
|
| 471 |
+
Extract text from a specific page using OCR
|
| 472 |
+
|
| 473 |
+
Arguments:
|
| 474 |
+
----------
|
| 475 |
+
file_path { Path } : Path to PDF file
|
| 476 |
+
|
| 477 |
+
page_number { int } : Page number (1-indexed)
|
| 478 |
+
|
| 479 |
+
Returns:
|
| 480 |
+
--------
|
| 481 |
+
{ str } : Extracted text from the page
|
| 482 |
+
"""
|
| 483 |
+
if not self.ocr_available:
|
| 484 |
+
raise PDFParseError(str(file_path), original_error = RuntimeError("OCR engine not available"))
|
| 485 |
+
|
| 486 |
+
if self.ocr_engine is None:
|
| 487 |
+
self.ocr_engine = OCREngine()
|
| 488 |
+
|
| 489 |
+
try:
|
| 490 |
+
# Use OCR engine to extract text from specific page
|
| 491 |
+
return self.ocr_engine.extract_text_from_pdf(pdf_path = file_path,
|
| 492 |
+
pages = [page_number],
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
except Exception as e:
|
| 496 |
+
self.logger.error(f"OCR failed for page {page_number}: {repr(e)}")
|
| 497 |
+
return ""
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
@staticmethod
|
| 501 |
+
def _get_metadata_field(metadata: Dict, field_names: List[str]) -> Optional[str]:
|
| 502 |
+
"""
|
| 503 |
+
Get metadata field with fallback names
|
| 504 |
+
|
| 505 |
+
Arguments:
|
| 506 |
+
----------
|
| 507 |
+
metadata { dict } : Metadata dictionary
|
| 508 |
+
|
| 509 |
+
field_names { list } : List of possible field names
|
| 510 |
+
|
| 511 |
+
Returns:
|
| 512 |
+
--------
|
| 513 |
+
{ str } : Field value or None
|
| 514 |
+
"""
|
| 515 |
+
for field_name in field_names:
|
| 516 |
+
if field_name in metadata:
|
| 517 |
+
value = metadata[field_name]
|
| 518 |
+
|
| 519 |
+
if value:
|
| 520 |
+
return str(value).strip()
|
| 521 |
+
|
| 522 |
+
return None
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
@staticmethod
|
| 526 |
+
def _parse_pdf_date(date_str: Optional[str]) -> Optional[datetime]:
|
| 527 |
+
"""
|
| 528 |
+
Parse PDF date format : PDF dates are in format: D:YYYYMMDDHHmmSSOHH'mm'
|
| 529 |
+
|
| 530 |
+
Arguments:
|
| 531 |
+
----------
|
| 532 |
+
date_str { str } : PDF date string
|
| 533 |
+
|
| 534 |
+
Returns:
|
| 535 |
+
--------
|
| 536 |
+
{ datetime } : Datetime object or None
|
| 537 |
+
"""
|
| 538 |
+
if not date_str:
|
| 539 |
+
return None
|
| 540 |
+
|
| 541 |
+
try:
|
| 542 |
+
# Remove 'D:' prefix if present
|
| 543 |
+
if date_str.startswith('D:'):
|
| 544 |
+
date_str = date_str[2:]
|
| 545 |
+
|
| 546 |
+
# Parse basic format: YYYYMMDDHHMMSS
|
| 547 |
+
date_str = date_str[:14]
|
| 548 |
+
|
| 549 |
+
return datetime.strptime(date_str, '%Y%m%d%H%M%S')
|
| 550 |
+
|
| 551 |
+
except Exception:
|
| 552 |
+
return None
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def extract_page_text(self, file_path: Path, page_number: int, clean_text: bool = True) -> str:
|
| 556 |
+
"""
|
| 557 |
+
Extract text from a specific page
|
| 558 |
+
|
| 559 |
+
Arguments:
|
| 560 |
+
----------
|
| 561 |
+
file_path { Path } : Path to PDF file
|
| 562 |
+
|
| 563 |
+
page_number { int } : Page number (1-indexed)
|
| 564 |
+
|
| 565 |
+
clean_text { bool } : Clean extracted text
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
--------
|
| 569 |
+
{ str } : Page text
|
| 570 |
+
"""
|
| 571 |
+
# Try PyMuPDF first if preferred and available
|
| 572 |
+
if self.prefer_pymupdf and PYMUPDF_AVAILABLE:
|
| 573 |
+
try:
|
| 574 |
+
page_text = self._extract_page_text_pymupdf(file_path = file_path,
|
| 575 |
+
page_number = page_number,
|
| 576 |
+
clean_text = clean_text,
|
| 577 |
+
)
|
| 578 |
+
return page_text
|
| 579 |
+
|
| 580 |
+
except Exception as e:
|
| 581 |
+
self.logger.warning(f"PyMuPDF page extraction failed, falling back to PyPDF2: {repr(e)}")
|
| 582 |
+
|
| 583 |
+
# Fall back to PyPDF2
|
| 584 |
+
if PYPdf2_AVAILABLE:
|
| 585 |
+
page_text = self._extract_page_text_pypdf2(file_path = file_path,
|
| 586 |
+
page_number = page_number,
|
| 587 |
+
clean_text = clean_text,
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
return page_text
|
| 591 |
+
|
| 592 |
+
else:
|
| 593 |
+
raise PDFParseError(str(file_path), original_error = RuntimeError("No PDF parsing libraries available"))
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def _extract_page_text_pymupdf(self, file_path: Path, page_number: int, clean_text: bool = True) -> str:
|
| 597 |
+
"""
|
| 598 |
+
Extract page text using PyMuPDF
|
| 599 |
+
"""
|
| 600 |
+
doc = None
|
| 601 |
+
try:
|
| 602 |
+
doc = fitz.open(str(file_path))
|
| 603 |
+
num_pages = len(doc)
|
| 604 |
+
|
| 605 |
+
if ((page_number < 1) or (page_number > num_pages)):
|
| 606 |
+
raise ValueError(f"Page number {page_number} out of range (1-{num_pages})")
|
| 607 |
+
|
| 608 |
+
page = doc[page_number - 1]
|
| 609 |
+
page_text = page.get_text()
|
| 610 |
+
|
| 611 |
+
if clean_text:
|
| 612 |
+
page_text = TextCleaner.clean(page_text)
|
| 613 |
+
|
| 614 |
+
return page_text
|
| 615 |
+
|
| 616 |
+
except Exception as e:
|
| 617 |
+
self.logger.error(f"Failed to extract page {page_number} with PyMuPDF: {repr(e)}")
|
| 618 |
+
raise PDFParseError(str(file_path), original_error = e)
|
| 619 |
+
|
| 620 |
+
finally:
|
| 621 |
+
if doc:
|
| 622 |
+
doc.close()
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def _extract_page_text_pypdf2(self, file_path: Path, page_number: int, clean_text: bool = True) -> str:
|
| 626 |
+
"""
|
| 627 |
+
Extract page text using PyPDF2
|
| 628 |
+
"""
|
| 629 |
+
try:
|
| 630 |
+
with open(file_path, 'rb') as pdf_file:
|
| 631 |
+
reader = PdfReader(pdf_file)
|
| 632 |
+
num_pages = len(reader.pages)
|
| 633 |
+
|
| 634 |
+
if ((page_number < 1) or (page_number > num_pages)):
|
| 635 |
+
raise ValueError(f"Page number {page_number} out of range (1-{num_pages})")
|
| 636 |
+
|
| 637 |
+
page = reader.pages[page_number - 1]
|
| 638 |
+
page_text = page.extract_text()
|
| 639 |
+
|
| 640 |
+
if clean_text:
|
| 641 |
+
page_text = TextCleaner.clean(page_text)
|
| 642 |
+
|
| 643 |
+
return page_text
|
| 644 |
+
|
| 645 |
+
except Exception as e:
|
| 646 |
+
self.logger.error(f"Failed to extract page {page_number} with PyPDF2: {repr(e)}")
|
| 647 |
+
raise PDFParseError(str(file_path), original_error = e)
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def get_page_count(self, file_path: Path) -> int:
|
| 651 |
+
"""
|
| 652 |
+
Get number of pages in PDF
|
| 653 |
+
|
| 654 |
+
Arguments:
|
| 655 |
+
----------
|
| 656 |
+
file_path { Path } : Path to PDF file
|
| 657 |
+
|
| 658 |
+
Returns:
|
| 659 |
+
--------
|
| 660 |
+
{ int } : Number of pages
|
| 661 |
+
"""
|
| 662 |
+
# Try PyMuPDF first if available
|
| 663 |
+
if PYMUPDF_AVAILABLE:
|
| 664 |
+
doc = None
|
| 665 |
+
try:
|
| 666 |
+
doc = fitz.open(str(file_path))
|
| 667 |
+
page_count = len(doc)
|
| 668 |
+
|
| 669 |
+
return page_count
|
| 670 |
+
|
| 671 |
+
except Exception as e:
|
| 672 |
+
self.logger.warning(f"PyMuPDF page count failed, trying PyPDF2: {repr(e)}")
|
| 673 |
+
|
| 674 |
+
finally:
|
| 675 |
+
if doc:
|
| 676 |
+
doc.close()
|
| 677 |
+
|
| 678 |
+
# Fall back to PyPDF2
|
| 679 |
+
if PYPdf2_AVAILABLE:
|
| 680 |
+
try:
|
| 681 |
+
with open(file_path, 'rb') as pdf_file:
|
| 682 |
+
reader = PdfReader(pdf_file)
|
| 683 |
+
|
| 684 |
+
return len(reader.pages)
|
| 685 |
+
|
| 686 |
+
except Exception as e:
|
| 687 |
+
self.logger.error(f"Failed to get page count: {repr(e)}")
|
| 688 |
+
raise PDFParseError(str(file_path), original_error = e)
|
| 689 |
+
else:
|
| 690 |
+
raise PDFParseError(str(file_path), original_error = RuntimeError("No PDF parsing libraries available"))
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def extract_page_range(self, file_path: Path, start_page: int, end_page: int, clean_text: bool = True) -> str:
|
| 694 |
+
"""
|
| 695 |
+
Extract text from a range of pages
|
| 696 |
+
|
| 697 |
+
Arguments:
|
| 698 |
+
----------
|
| 699 |
+
file_path { Path } : Path to PDF file
|
| 700 |
+
|
| 701 |
+
start_page { int } : Starting page (1-indexed, inclusive)
|
| 702 |
+
|
| 703 |
+
end_page { int } : Ending page (1-indexed, inclusive)
|
| 704 |
+
|
| 705 |
+
clean_text { bool } : Clean extracted text
|
| 706 |
+
|
| 707 |
+
Returns:
|
| 708 |
+
--------
|
| 709 |
+
{ str } : Combined text from pages
|
| 710 |
+
"""
|
| 711 |
+
# Try PyMuPDF first if preferred and available
|
| 712 |
+
if self.prefer_pymupdf and PYMUPDF_AVAILABLE:
|
| 713 |
+
try:
|
| 714 |
+
page_range = self._extract_page_range_pymupdf(file_path = file_path,
|
| 715 |
+
start_page = start_page,
|
| 716 |
+
end_page = end_page,
|
| 717 |
+
clean_text = clean_text,
|
| 718 |
+
)
|
| 719 |
+
return page_range
|
| 720 |
+
|
| 721 |
+
except Exception as e:
|
| 722 |
+
self.logger.warning(f"PyMuPDF page range extraction failed, falling back to PyPDF2: {repr(e)}")
|
| 723 |
+
|
| 724 |
+
# Fall back to PyPDF2
|
| 725 |
+
if PYPdf2_AVAILABLE:
|
| 726 |
+
page_range = self._extract_page_range_pypdf2(file_path = file_path,
|
| 727 |
+
start_page = start_page,
|
| 728 |
+
end_page = end_page,
|
| 729 |
+
clean_text = clean_text,
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
return page_range
|
| 733 |
+
|
| 734 |
+
else:
|
| 735 |
+
raise PDFParseError(str(file_path), original_error = RuntimeError("No PDF parsing libraries available"))
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
def _extract_page_range_pymupdf(self, file_path: Path, start_page: int, end_page: int, clean_text: bool = True) -> str:
|
| 739 |
+
"""
|
| 740 |
+
Extract page range using PyMuPDF
|
| 741 |
+
"""
|
| 742 |
+
doc = None
|
| 743 |
+
try:
|
| 744 |
+
doc = fitz.open(str(file_path))
|
| 745 |
+
num_pages = len(doc)
|
| 746 |
+
|
| 747 |
+
if ((start_page < 1) or (end_page > num_pages) or (start_page > end_page)):
|
| 748 |
+
raise ValueError(f"Invalid page range {start_page}-{end_page} for PDF with {num_pages} pages")
|
| 749 |
+
|
| 750 |
+
text_parts = list()
|
| 751 |
+
|
| 752 |
+
for page_num in range(start_page - 1, end_page):
|
| 753 |
+
page = doc[page_num]
|
| 754 |
+
page_text = page.get_text()
|
| 755 |
+
|
| 756 |
+
if page_text:
|
| 757 |
+
text_parts.append(f"\n[PAGE {page_num + 1}]\n{page_text}")
|
| 758 |
+
|
| 759 |
+
combined_text = "\n".join(text_parts)
|
| 760 |
+
|
| 761 |
+
if clean_text:
|
| 762 |
+
combined_text = TextCleaner.clean(combined_text)
|
| 763 |
+
|
| 764 |
+
return combined_text
|
| 765 |
+
|
| 766 |
+
except Exception as e:
|
| 767 |
+
self.logger.error(f"Failed to extract page range with PyMuPDF: {repr(e)}")
|
| 768 |
+
raise PDFParseError(str(file_path), original_error = e)
|
| 769 |
+
|
| 770 |
+
finally:
|
| 771 |
+
if doc:
|
| 772 |
+
doc.close()
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def _extract_page_range_pypdf2(self, file_path: Path, start_page: int, end_page: int, clean_text: bool = True) -> str:
|
| 776 |
+
"""
|
| 777 |
+
Extract page range using PyPDF2
|
| 778 |
+
"""
|
| 779 |
+
try:
|
| 780 |
+
with open(file_path, 'rb') as pdf_file:
|
| 781 |
+
reader = PdfReader(pdf_file)
|
| 782 |
+
num_pages = len(reader.pages)
|
| 783 |
+
|
| 784 |
+
if ((start_page < 1) or (end_page > num_pages) or (start_page > end_page)):
|
| 785 |
+
raise ValueError(f"Invalid page range {start_page}-{end_page} for PDF with {num_pages} pages")
|
| 786 |
+
|
| 787 |
+
text_parts = list()
|
| 788 |
+
|
| 789 |
+
for page_num in range(start_page - 1, end_page):
|
| 790 |
+
page = reader.pages[page_num]
|
| 791 |
+
page_text = page.extract_text()
|
| 792 |
+
|
| 793 |
+
if page_text:
|
| 794 |
+
text_parts.append(f"\n[PAGE {page_num + 1}]\n{page_text}")
|
| 795 |
+
|
| 796 |
+
combined_text = "\n".join(text_parts)
|
| 797 |
+
|
| 798 |
+
if clean_text:
|
| 799 |
+
combined_text = TextCleaner.clean(combined_text)
|
| 800 |
+
|
| 801 |
+
return combined_text
|
| 802 |
+
|
| 803 |
+
except Exception as e:
|
| 804 |
+
self.logger.error(f"Failed to extract page range with PyPDF2: {repr(e)}")
|
| 805 |
+
raise PDFParseError(str(file_path), original_error = e)
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def is_scanned(self, file_path: Path) -> bool:
|
| 809 |
+
"""
|
| 810 |
+
Check if PDF is scanned (image-based): Scanned PDFs have very little or no extractable text
|
| 811 |
+
|
| 812 |
+
Arguments:
|
| 813 |
+
----------
|
| 814 |
+
file_path { Path } : Path to PDF file
|
| 815 |
+
|
| 816 |
+
Returns:
|
| 817 |
+
--------
|
| 818 |
+
{ bool } : True if appears to be scanned
|
| 819 |
+
"""
|
| 820 |
+
# Try PyMuPDF first if available for better detection
|
| 821 |
+
if PYMUPDF_AVAILABLE:
|
| 822 |
+
try:
|
| 823 |
+
return self._is_scanned_pymupdf(file_path = file_path)
|
| 824 |
+
|
| 825 |
+
except Exception as e:
|
| 826 |
+
self.logger.warning(f"PyMuPDF scanned detection failed, trying PyPDF2: {repr(e)}")
|
| 827 |
+
|
| 828 |
+
# Fall back to PyPDF2
|
| 829 |
+
if PYPdf2_AVAILABLE:
|
| 830 |
+
return self._is_scanned_pypdf2(file_path = file_path)
|
| 831 |
+
|
| 832 |
+
else:
|
| 833 |
+
self.logger.warning("No PDF parsing libraries available for scanned detection")
|
| 834 |
+
return False
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
def _is_scanned_pymupdf(self, file_path: Path) -> bool:
|
| 838 |
+
"""
|
| 839 |
+
Check if PDF is scanned using PyMuPDF
|
| 840 |
+
"""
|
| 841 |
+
doc = None
|
| 842 |
+
try:
|
| 843 |
+
doc = fitz.open(str(file_path))
|
| 844 |
+
|
| 845 |
+
# Sample first 3 pages
|
| 846 |
+
pages_to_check = min(3, len(doc))
|
| 847 |
+
total_text_length = 0
|
| 848 |
+
|
| 849 |
+
for i in range(pages_to_check):
|
| 850 |
+
page = doc[i]
|
| 851 |
+
text = page.get_text()
|
| 852 |
+
total_text_length += len(text.strip())
|
| 853 |
+
|
| 854 |
+
# If average text per page is very low, likely scanned
|
| 855 |
+
avg_text_per_page = total_text_length / pages_to_check
|
| 856 |
+
|
| 857 |
+
# Minimum characters per page expected from a text-based (non-scanned) PDF
|
| 858 |
+
threshold = 100
|
| 859 |
+
|
| 860 |
+
is_scanned = (avg_text_per_page < threshold)
|
| 861 |
+
|
| 862 |
+
if is_scanned:
|
| 863 |
+
self.logger.info(f"PDF appears to be scanned (avg {avg_text_per_page:.0f} chars/page)")
|
| 864 |
+
|
| 865 |
+
return is_scanned
|
| 866 |
+
|
| 867 |
+
except Exception as e:
|
| 868 |
+
self.logger.warning(f"Could not determine if PDF is scanned with PyMuPDF: {repr(e)}")
|
| 869 |
+
return False
|
| 870 |
+
|
| 871 |
+
finally:
|
| 872 |
+
if doc:
|
| 873 |
+
doc.close()
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
def _is_scanned_pypdf2(self, file_path: Path) -> bool:
|
| 877 |
+
"""
|
| 878 |
+
Check if PDF is scanned using PyPDF2
|
| 879 |
+
"""
|
| 880 |
+
try:
|
| 881 |
+
with open(file_path, 'rb') as pdf_file:
|
| 882 |
+
reader = PdfReader(pdf_file)
|
| 883 |
+
|
| 884 |
+
# Sample first 3 pages
|
| 885 |
+
pages_to_check = min(3, len(reader.pages))
|
| 886 |
+
total_text_length = 0
|
| 887 |
+
|
| 888 |
+
for i in range(pages_to_check):
|
| 889 |
+
page = reader.pages[i]
|
| 890 |
+
text = page.extract_text()
|
| 891 |
+
total_text_length += len(text.strip())
|
| 892 |
+
|
| 893 |
+
# If average text per page is very low, likely scanned
|
| 894 |
+
avg_text_per_page = total_text_length / pages_to_check
|
| 895 |
+
|
| 896 |
+
# Minimum characters per page expected from a text-based (non-scanned) PDF
|
| 897 |
+
threshold = 100
|
| 898 |
+
|
| 899 |
+
is_scanned = (avg_text_per_page < threshold)
|
| 900 |
+
|
| 901 |
+
if is_scanned:
|
| 902 |
+
self.logger.info(f"PDF appears to be scanned (avg {avg_text_per_page:.0f} chars/page)")
|
| 903 |
+
|
| 904 |
+
return is_scanned
|
| 905 |
+
|
| 906 |
+
except Exception as e:
|
| 907 |
+
self.logger.warning(f"Could not determine if PDF is scanned with PyPDF2: {repr(e)}")
|
| 908 |
+
return False
|
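The page-level helpers above (get_page_count, extract_page_range, is_scanned) are typically combined when routing a PDF: probe whether the file is scanned first, then pull text page by page. A minimal usage sketch follows, assuming the class above is exposed as PDFParser in document_parser.pdf_parser; the import path and the sample file name are illustrative, not confirmed by this diff.

from pathlib import Path

from document_parser.pdf_parser import PDFParser

parser = PDFParser()
pdf_path = Path("sample.pdf")  # hypothetical input file

if parser.is_scanned(pdf_path):
    # Scanned PDFs carry almost no extractable text; OCR is the fallback path
    print("Scanned PDF detected, routing to OCR")
else:
    total_pages = parser.get_page_count(pdf_path)
    # Extract the first three pages (or fewer for short documents)
    text = parser.extract_page_range(pdf_path, start_page=1, end_page=min(3, total_pages))
    print(f"Extracted {len(text)} characters from {total_pages} pages")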
document_parser/txt_parser.py
ADDED
|
@@ -0,0 +1,336 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import chardet
|
| 3 |
+
import hashlib
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from config.models import DocumentType
|
| 8 |
+
from utils.text_cleaner import TextCleaner
|
| 9 |
+
from config.models import DocumentMetadata
|
| 10 |
+
from config.logging_config import get_logger
|
| 11 |
+
from utils.error_handler import handle_errors
|
| 12 |
+
from utils.error_handler import TextEncodingError
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Setup Logging
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TXTParser:
|
| 20 |
+
"""
|
| 21 |
+
Plain text file parser with automatic encoding detection : handles various text encodings and formats
|
| 22 |
+
"""
|
| 23 |
+
# Common encodings to try
|
| 24 |
+
COMMON_ENCODINGS = ['utf-8',
|
| 25 |
+
'utf-16',
|
| 26 |
+
'ascii',
|
| 27 |
+
'latin-1',
|
| 28 |
+
'cp1252',
|
| 29 |
+
'iso-8859-1',
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
def __init__(self):
|
| 33 |
+
self.logger = logger
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@handle_errors(error_type = TextEncodingError, log_error = True, reraise = True)
|
| 37 |
+
def parse(self, file_path: Path, extract_metadata: bool = True, clean_text: bool = True, encoding: Optional[str] = None) -> tuple[str, Optional[DocumentMetadata]]:
|
| 38 |
+
"""
|
| 39 |
+
Parse text file and extract content
|
| 40 |
+
|
| 41 |
+
Arguments:
|
| 42 |
+
-----------
|
| 43 |
+
file_path { Path } : Path to text file
|
| 44 |
+
|
| 45 |
+
extract_metadata { bool } : Extract document metadata
|
| 46 |
+
|
| 47 |
+
clean_text { bool } : Clean extracted text
|
| 48 |
+
|
| 49 |
+
encoding { str } : Force specific encoding (None = auto-detect)
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
--------
|
| 53 |
+
{ tuple } : Tuple of (extracted_text, metadata)
|
| 54 |
+
|
| 55 |
+
Raises:
|
| 56 |
+
-------
|
| 57 |
+
TextEncodingError : If file cannot be decoded
|
| 58 |
+
"""
|
| 59 |
+
file_path = Path(file_path)
|
| 60 |
+
|
| 61 |
+
if not file_path.exists():
|
| 62 |
+
raise TextEncodingError(str(file_path), encoding = "unknown", original_error = FileNotFoundError(f"Text file not found: {file_path}"))
|
| 63 |
+
|
| 64 |
+
self.logger.info(f"Parsing TXT: {file_path}")
|
| 65 |
+
|
| 66 |
+
# Detect encoding if not specified
|
| 67 |
+
if encoding is None:
|
| 68 |
+
encoding = self.detect_encoding(file_path)
|
| 69 |
+
self.logger.info(f"Detected encoding: {encoding}")
|
| 70 |
+
|
| 71 |
+
try:
|
| 72 |
+
# Read file with detected/specified encoding
|
| 73 |
+
with open(file_path, 'r', encoding = encoding, errors = 'replace') as f:
|
| 74 |
+
text_content = f.read()
|
| 75 |
+
|
| 76 |
+
# Extract metadata
|
| 77 |
+
metadata = None
|
| 78 |
+
|
| 79 |
+
if extract_metadata:
|
| 80 |
+
metadata = self._extract_metadata(file_path = file_path,
|
| 81 |
+
encoding = encoding,
|
| 82 |
+
text_length = len(text_content),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Clean text
|
| 86 |
+
if clean_text:
|
| 87 |
+
text_content = TextCleaner.clean(text_content,
|
| 88 |
+
remove_html = False,
|
| 89 |
+
normalize_whitespace = True,
|
| 90 |
+
preserve_structure = True,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
self.logger.info(f"Successfully parsed TXT: {len(text_content)} characters")
|
| 94 |
+
|
| 95 |
+
return text_content, metadata
|
| 96 |
+
|
| 97 |
+
except Exception as e:
|
| 98 |
+
self.logger.error(f"Failed to parse TXT {file_path}: {repr(e)}")
|
| 99 |
+
raise TextEncodingError(str(file_path), encoding = encoding, original_error = e)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def detect_encoding(self, file_path: Path) -> str:
|
| 103 |
+
"""
|
| 104 |
+
Detect file encoding using chardet
|
| 105 |
+
|
| 106 |
+
Arguments:
|
| 107 |
+
----------
|
| 108 |
+
file_path { Path } : Path to text file
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
--------
|
| 112 |
+
{ str } : Detected encoding name
|
| 113 |
+
"""
|
| 114 |
+
try:
|
| 115 |
+
# Read raw bytes
|
| 116 |
+
with open(file_path, 'rb') as f:
|
| 117 |
+
# Read first 10KB for detection
|
| 118 |
+
raw_data = f.read(10000)
|
| 119 |
+
|
| 120 |
+
# Detect encoding
|
| 121 |
+
result = chardet.detect(raw_data)
|
| 122 |
+
encoding = result['encoding']
|
| 123 |
+
confidence = result['confidence']
|
| 124 |
+
|
| 125 |
+
self.logger.debug(f"Encoding detection: {encoding} (confidence: {confidence:.2%})")
|
| 126 |
+
|
| 127 |
+
# If confidence is low, try common encodings
|
| 128 |
+
if (confidence < 0.7):
|
| 129 |
+
self.logger.warning(f"Low confidence ({confidence:.2%}) for detected encoding {encoding}")
|
| 130 |
+
encoding = self._try_common_encodings(file_path = file_path)
|
| 131 |
+
|
| 132 |
+
# Fallback to UTF-8
|
| 133 |
+
return encoding or 'utf-8'
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
self.logger.warning(f"Encoding detection failed: {repr(e)}, using UTF-8")
|
| 137 |
+
return 'utf-8'
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _try_common_encodings(self, file_path: Path) -> Optional[str]:
|
| 141 |
+
"""
|
| 142 |
+
Try reading file with common encodings
|
| 143 |
+
|
| 144 |
+
Arguments:
|
| 145 |
+
----------
|
| 146 |
+
file_path { Path } : Path to text file
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
--------
|
| 150 |
+
{ str } : Working encoding or None
|
| 151 |
+
"""
|
| 152 |
+
for encoding in self.COMMON_ENCODINGS:
|
| 153 |
+
try:
|
| 154 |
+
with open(file_path, 'r', encoding = encoding) as f:
|
| 155 |
+
# Try reading first 1000 chars
|
| 156 |
+
f.read(1000)
|
| 157 |
+
|
| 158 |
+
self.logger.info(f"Successfully read with encoding: {encoding}")
|
| 159 |
+
return encoding
|
| 160 |
+
|
| 161 |
+
except (UnicodeDecodeError, LookupError):
|
| 162 |
+
continue
|
| 163 |
+
|
| 164 |
+
return None
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _extract_metadata(self, file_path: Path, encoding: str, text_length: int) -> DocumentMetadata:
|
| 168 |
+
"""
|
| 169 |
+
Extract metadata from text file
|
| 170 |
+
|
| 171 |
+
Arguments:
|
| 172 |
+
----------
|
| 173 |
+
file_path { Path } : Path to text file
|
| 174 |
+
|
| 175 |
+
encoding { str } : File encoding
|
| 176 |
+
|
| 177 |
+
text_length { int } : Length of text content
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
--------
|
| 181 |
+
{ DocumentMetadata } : DocumentMetadata object
|
| 182 |
+
"""
|
| 183 |
+
# Get file stats
|
| 184 |
+
stat = file_path.stat()
|
| 185 |
+
file_size = stat.st_size
|
| 186 |
+
created_time = datetime.fromtimestamp(stat.st_ctime)
|
| 187 |
+
modified_time = datetime.fromtimestamp(stat.st_mtime)
|
| 188 |
+
|
| 189 |
+
# Generate document ID
|
| 190 |
+
doc_hash = hashlib.md5(str(file_path).encode()).hexdigest()
|
| 191 |
+
doc_id = f"doc_{int(datetime.now().timestamp())}_{doc_hash}"
|
| 192 |
+
|
| 193 |
+
# Estimate pages (rough: 3000 characters per page)
|
| 194 |
+
estimated_pages = max(1, text_length // 3000)
|
| 195 |
+
|
| 196 |
+
# Count lines
|
| 197 |
+
with open(file_path, 'r', encoding = encoding, errors = 'replace') as f:
|
| 198 |
+
num_lines = sum(1 for _ in f)
|
| 199 |
+
|
| 200 |
+
# Create metadata object
|
| 201 |
+
metadata = DocumentMetadata(document_id = doc_id,
|
| 202 |
+
filename = file_path.name,
|
| 203 |
+
file_path = file_path,
|
| 204 |
+
document_type = DocumentType.TXT,
|
| 205 |
+
title = file_path.stem,
|
| 206 |
+
created_date = created_time,
|
| 207 |
+
modified_date = modified_time,
|
| 208 |
+
file_size_bytes = file_size,
|
| 209 |
+
num_pages = estimated_pages,
|
| 210 |
+
extra = {"encoding" : encoding,
|
| 211 |
+
"num_lines" : num_lines,
|
| 212 |
+
"text_length" : text_length,
|
| 213 |
+
}
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
return metadata
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def read_lines(self, file_path: Path, start_line: int = 0, end_line: Optional[int] = None, encoding: Optional[str] = None) -> list[str]:
|
| 220 |
+
"""
|
| 221 |
+
Read specific lines from file
|
| 222 |
+
|
| 223 |
+
Arguments:
|
| 224 |
+
-----------
|
| 225 |
+
file_path { Path } : Path to text file
|
| 226 |
+
|
| 227 |
+
start_line { int } : Starting line (0-indexed)
|
| 228 |
+
|
| 229 |
+
end_line { int } : Ending line (None = end of file)
|
| 230 |
+
|
| 231 |
+
encoding { str } : File encoding (None = auto-detect)
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
--------
|
| 235 |
+
{ list } : List of lines
|
| 236 |
+
"""
|
| 237 |
+
if encoding is None:
|
| 238 |
+
encoding = self.detect_encoding(file_path)
|
| 239 |
+
|
| 240 |
+
try:
|
| 241 |
+
with open(file_path, 'r', encoding = encoding, errors = 'replace') as f:
|
| 242 |
+
lines = f.readlines()
|
| 243 |
+
|
| 244 |
+
if end_line is None:
|
| 245 |
+
return lines[start_line:]
|
| 246 |
+
|
| 247 |
+
else:
|
| 248 |
+
return lines[start_line:end_line]
|
| 249 |
+
|
| 250 |
+
except Exception as e:
|
| 251 |
+
self.logger.error(f"Failed to read lines: {repr(e)}")
|
| 252 |
+
raise TextEncodingError(str(file_path), encoding = encoding, original_error = e)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def count_lines(self, file_path: Path, encoding: Optional[str] = None) -> int:
|
| 256 |
+
"""
|
| 257 |
+
Count number of lines in file
|
| 258 |
+
|
| 259 |
+
Arguments:
|
| 260 |
+
----------
|
| 261 |
+
file_path { Path } : Path to text file
|
| 262 |
+
|
| 263 |
+
encoding { str } : File encoding (None = auto-detect)
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
--------
|
| 267 |
+
{ int } : Number of lines
|
| 268 |
+
"""
|
| 269 |
+
if encoding is None:
|
| 270 |
+
encoding = self.detect_encoding(file_path)
|
| 271 |
+
|
| 272 |
+
try:
|
| 273 |
+
with open(file_path, 'r', encoding = encoding, errors = 'replace') as f:
|
| 274 |
+
return sum(1 for _ in f)
|
| 275 |
+
|
| 276 |
+
except Exception as e:
|
| 277 |
+
self.logger.error(f"Failed to count lines: {repr(e)}")
|
| 278 |
+
raise TextEncodingError(str(file_path), encoding = encoding, original_error = e)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def get_file_info(self, file_path: Path) -> dict:
|
| 282 |
+
"""
|
| 283 |
+
Get comprehensive file information
|
| 284 |
+
|
| 285 |
+
Arguments:
|
| 286 |
+
----------
|
| 287 |
+
file_path { Path } : Path to text file
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
--------
|
| 291 |
+
{ dict } : Dictionary with file info
|
| 292 |
+
"""
|
| 293 |
+
encoding = self.detect_encoding(file_path)
|
| 294 |
+
|
| 295 |
+
with open(file_path, 'r', encoding = encoding, errors = 'replace') as f:
|
| 296 |
+
content = f.read()
|
| 297 |
+
|
| 298 |
+
lines = content.split('\n')
|
| 299 |
+
|
| 300 |
+
return {"encoding" : encoding,
|
| 301 |
+
"size_bytes" : file_path.stat().st_size,
|
| 302 |
+
"num_lines" : len(lines),
|
| 303 |
+
"num_characters" : len(content),
|
| 304 |
+
"num_words" : len(content.split()),
|
| 305 |
+
"avg_line_length" : sum(len(line) for line in lines) / len(lines) if lines else 0,
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def is_empty(self, file_path: Path) -> bool:
|
| 310 |
+
"""
|
| 311 |
+
Check if file is empty or contains only whitespace
|
| 312 |
+
|
| 313 |
+
Arguments:
|
| 314 |
+
----------
|
| 315 |
+
file_path { Path } : Path to text file
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
--------
|
| 319 |
+
{ bool } : True if empty
|
| 320 |
+
"""
|
| 321 |
+
try:
|
| 322 |
+
# Check file size first
|
| 323 |
+
if file_path.stat().st_size == 0:
|
| 324 |
+
return True
|
| 325 |
+
|
| 326 |
+
# Read and check content
|
| 327 |
+
encoding = self.detect_encoding(file_path)
|
| 328 |
+
|
| 329 |
+
with open(file_path, 'r', encoding = encoding, errors = 'replace') as f:
|
| 330 |
+
content = f.read().strip()
|
| 331 |
+
|
| 332 |
+
return len(content) == 0
|
| 333 |
+
|
| 334 |
+
except Exception as e:
|
| 335 |
+
self.logger.warning(f"Error checking if file is empty: {repr(e)}")
|
| 336 |
+
return True
|
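A minimal usage sketch for the parser above, assuming it is importable as TXTParser from document_parser.txt_parser (the file name below is illustrative). It shows the typical flow: skip empty files, let the parser auto-detect the encoding, and read the detected encoding back from the returned metadata.

from pathlib import Path

from document_parser.txt_parser import TXTParser

parser = TXTParser()
txt_path = Path("notes.txt")  # hypothetical input file

if not parser.is_empty(txt_path):
    # encoding=None triggers chardet-based auto-detection inside parse()
    text, metadata = parser.parse(txt_path, extract_metadata=True, clean_text=True)
    print(metadata.extra["encoding"], metadata.extra["num_lines"], len(text))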
document_parser/zip_handler.py
ADDED
|
@@ -0,0 +1,716 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import py7zr
|
| 3 |
+
import zipfile
|
| 4 |
+
import tarfile
|
| 5 |
+
import rarfile
|
| 6 |
+
from typing import List
|
| 7 |
+
from typing import Dict
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from tempfile import TemporaryDirectory
|
| 11 |
+
from config.settings import get_settings
|
| 12 |
+
from utils.file_handler import FileHandler
|
| 13 |
+
from config.logging_config import get_logger
|
| 14 |
+
from utils.error_handler import handle_errors
|
| 15 |
+
from utils.error_handler import ArchiveException
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Setup Settings and Logging
|
| 19 |
+
settings = get_settings()
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ArchiveHandler:
|
| 24 |
+
"""
|
| 25 |
+
Comprehensive archive file handler supporting multiple formats
|
| 26 |
+
(ZIP, TAR, RAR, 7Z) with recursive extraction and validation
|
| 27 |
+
"""
|
| 28 |
+
# Supported archive formats and their handlers
|
| 29 |
+
SUPPORTED_FORMATS = {'.zip' : 'zip',
|
| 30 |
+
'.tar' : 'tar',
|
| 31 |
+
'.gz' : 'tar',
|
| 32 |
+
'.tgz' : 'tar',
|
| 33 |
+
'.rar' : 'rar',
|
| 34 |
+
'.7z' : '7z',
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def __init__(self, max_size_mb: int = 2048, max_files: int = 10000, allow_recursive: bool = True):
|
| 39 |
+
"""
|
| 40 |
+
Initialize archive handler
|
| 41 |
+
|
| 42 |
+
Arguments:
|
| 43 |
+
----------
|
| 44 |
+
max_size_mb { int } : Maximum uncompressed size in MB
|
| 45 |
+
|
| 46 |
+
max_files { int } : Maximum number of files to extract
|
| 47 |
+
|
| 48 |
+
allow_recursive { bool } : Allow nested archives
|
| 49 |
+
"""
|
| 50 |
+
self.logger = logger
|
| 51 |
+
self.max_size_bytes = max_size_mb * 1024 * 1024
|
| 52 |
+
self.max_files = max_files
|
| 53 |
+
self.allow_recursive = allow_recursive
|
| 54 |
+
self.extracted_files_count = 0
|
| 55 |
+
self.total_extracted_size = 0
|
| 56 |
+
self._temp_dirs = list()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@handle_errors(error_type = ArchiveException, log_error = True, reraise = True)
|
| 60 |
+
def extract_archive(self, archive_path: Path, extract_dir: Optional[Path] = None, flatten_structure: bool = False, filter_extensions: Optional[List[str]] = None) -> List[Path]:
|
| 61 |
+
"""
|
| 62 |
+
Extract archive and return list of extracted file paths
|
| 63 |
+
|
| 64 |
+
Arguments:
|
| 65 |
+
----------
|
| 66 |
+
archive_path { Path } : Path to archive file
|
| 67 |
+
|
| 68 |
+
extract_dir { Path } : Directory to extract to (None = temp directory)
|
| 69 |
+
|
| 70 |
+
flatten_structure { bool } : Ignore directory structure, extract all files to root
|
| 71 |
+
|
| 72 |
+
filter_extensions { list } : Only extract files with these extensions
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
--------
|
| 76 |
+
{ list } : List of paths to extracted files
|
| 77 |
+
"""
|
| 78 |
+
archive_path = Path(archive_path)
|
| 79 |
+
|
| 80 |
+
if not archive_path.exists():
|
| 81 |
+
raise ArchiveException(f"Archive file not found: {archive_path}")
|
| 82 |
+
|
| 83 |
+
# Validate archive size
|
| 84 |
+
self._validate_archive_size(archive_path = archive_path)
|
| 85 |
+
|
| 86 |
+
# Determine extraction directory
|
| 87 |
+
if extract_dir is None:
|
| 88 |
+
temp_dir = TemporaryDirectory()
|
| 89 |
+
extract_dir = Path(temp_dir.name)
|
| 90 |
+
# Keep reference to prevent cleanup
|
| 91 |
+
self._temp_dirs.append(temp_dir)
|
| 92 |
+
|
| 93 |
+
else:
|
| 94 |
+
extract_dir = Path(extract_dir)
|
| 95 |
+
extract_dir.mkdir(parents = True, exist_ok = True)
|
| 96 |
+
|
| 97 |
+
self.logger.info(f"Extracting archive: {archive_path} to {extract_dir}")
|
| 98 |
+
|
| 99 |
+
# Reset counters
|
| 100 |
+
self.extracted_files_count = 0
|
| 101 |
+
self.total_extracted_size = 0
|
| 102 |
+
|
| 103 |
+
# Extract based on format
|
| 104 |
+
archive_format = self._detect_archive_format(archive_path = archive_path)
|
| 105 |
+
|
| 106 |
+
extracted_files = self._extract_by_format(archive_path = archive_path,
|
| 107 |
+
extract_dir = extract_dir,
|
| 108 |
+
format = archive_format,
|
| 109 |
+
flatten_structure = flatten_structure,
|
| 110 |
+
filter_extensions = filter_extensions,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.logger.info(f"Extracted {len(extracted_files)} files from {archive_path} ({self.total_extracted_size} bytes)")
|
| 114 |
+
|
| 115 |
+
return extracted_files
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _extract_by_format(self, archive_path: Path, extract_dir: Path, format: str, flatten_structure: bool, filter_extensions: Optional[List[str]]) -> List[Path]:
|
| 119 |
+
"""
|
| 120 |
+
Extract archive based on format
|
| 121 |
+
"""
|
| 122 |
+
try:
|
| 123 |
+
if (format == 'zip'):
|
| 124 |
+
return self._extract_zip(archive_path = archive_path,
|
| 125 |
+
extract_dir = extract_dir,
|
| 126 |
+
flatten_structure = flatten_structure,
|
| 127 |
+
filter_extensions = filter_extensions,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
elif (format == 'tar'):
|
| 131 |
+
return self._extract_tar(archive_path = archive_path,
|
| 132 |
+
extract_dir = extract_dir,
|
| 133 |
+
flatten_structure = flatten_structure,
|
| 134 |
+
filter_extensions = filter_extensions,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
elif (format == 'rar'):
|
| 138 |
+
return self._extract_rar(archive_path = archive_path,
|
| 139 |
+
extract_dir = extract_dir,
|
| 140 |
+
flatten_structure = flatten_structure,
|
| 141 |
+
filter_extensions = filter_extensions,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
elif (format == '7z'):
|
| 145 |
+
return self._extract_7z(archive_path = archive_path,
|
| 146 |
+
extract_dir = extract_dir,
|
| 147 |
+
flatten_structure = flatten_structure,
|
| 148 |
+
filter_extensions = filter_extensions,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
else:
|
| 152 |
+
raise ArchiveException(f"Unsupported archive format: {format}")
|
| 153 |
+
|
| 154 |
+
except Exception as e:
|
| 155 |
+
raise ArchiveException(f"Failed to extract {format} archive: {repr(e)}")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _extract_zip(self, archive_path: Path, extract_dir: Path, flatten_structure: bool, filter_extensions: Optional[List[str]]) -> List[Path]:
|
| 159 |
+
"""
|
| 160 |
+
Extract ZIP archive
|
| 161 |
+
"""
|
| 162 |
+
extracted_files = list()
|
| 163 |
+
|
| 164 |
+
with zipfile.ZipFile(archive_path, 'r') as zip_ref:
|
| 165 |
+
# Validate files before extraction
|
| 166 |
+
file_list = zip_ref.namelist()
|
| 167 |
+
|
| 168 |
+
self._validate_file_count(file_count = len(file_list))
|
| 169 |
+
|
| 170 |
+
for file_info in zip_ref.infolist():
|
| 171 |
+
# Use enhanced filtering
|
| 172 |
+
if self._should_extract_file(file_info.filename, filter_extensions):
|
| 173 |
+
try:
|
| 174 |
+
extracted_path = self._extract_zip_file(zip_ref = zip_ref,
|
| 175 |
+
file_info = file_info,
|
| 176 |
+
extract_dir = extract_dir,
|
| 177 |
+
flatten_structure = flatten_structure,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if extracted_path:
|
| 181 |
+
extracted_files.append(extracted_path)
|
| 182 |
+
|
| 183 |
+
except Exception as e:
|
| 184 |
+
self.logger.warning(f"Failed to extract {file_info.filename}: {repr(e)}")
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
return extracted_files
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _extract_zip_file(self, zip_ref, file_info, extract_dir: Path, flatten_structure: bool) -> Optional[Path]:
|
| 191 |
+
"""
|
| 192 |
+
Extract single file from ZIP
|
| 193 |
+
"""
|
| 194 |
+
# Skip directories
|
| 195 |
+
if file_info.filename.endswith('/'):
|
| 196 |
+
return None
|
| 197 |
+
|
| 198 |
+
# Determine extraction path
|
| 199 |
+
if flatten_structure:
|
| 200 |
+
target_filename = Path(file_info.filename).name
|
| 201 |
+
extract_path = extract_dir / self._safe_filename(filename = target_filename)
|
| 202 |
+
|
| 203 |
+
else:
|
| 204 |
+
extract_path = extract_dir / self._safe_archive_path(archive_path = file_info.filename)
|
| 205 |
+
|
| 206 |
+
# Ensure parent directory exists
|
| 207 |
+
extract_path.parent.mkdir(parents = True, exist_ok = True)
|
| 208 |
+
|
| 209 |
+
# Check limits
|
| 210 |
+
self._check_extraction_limits(file_size = file_info.file_size)
|
| 211 |
+
|
| 212 |
+
# Extract file
|
| 213 |
+
zip_ref.extract(file_info, extract_dir)
|
| 214 |
+
|
| 215 |
+
# Rename if flattening structure
|
| 216 |
+
if flatten_structure:
|
| 217 |
+
original_path = extract_dir / file_info.filename
|
| 218 |
+
if (original_path != extract_path):
|
| 219 |
+
original_path.rename(extract_path)
|
| 220 |
+
|
| 221 |
+
self.extracted_files_count += 1
|
| 222 |
+
self.total_extracted_size += file_info.file_size
|
| 223 |
+
|
| 224 |
+
return extract_path
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _extract_tar(self, archive_path: Path, extract_dir: Path, flatten_structure: bool, filter_extensions: Optional[List[str]]) -> List[Path]:
|
| 228 |
+
"""
|
| 229 |
+
Extract TAR archive
|
| 230 |
+
"""
|
| 231 |
+
extracted_files = list()
|
| 232 |
+
|
| 233 |
+
# Determine compression
|
| 234 |
+
mode = 'r'
|
| 235 |
+
|
| 236 |
+
if (archive_path.suffix.lower() in ['.gz', '.tgz']):
|
| 237 |
+
mode = 'r:gz'
|
| 238 |
+
|
| 239 |
+
elif (archive_path.suffix.lower() == '.bz2'):
|
| 240 |
+
mode = 'r:bz2'
|
| 241 |
+
|
| 242 |
+
elif (archive_path.suffix.lower() == '.xz'):
|
| 243 |
+
mode = 'r:xz'
|
| 244 |
+
|
| 245 |
+
with tarfile.open(archive_path, mode) as tar_ref:
|
| 246 |
+
# Validate files before extraction
|
| 247 |
+
file_list = tar_ref.getnames()
|
| 248 |
+
self._validate_file_count(file_count = len(file_list))
|
| 249 |
+
|
| 250 |
+
for member in tar_ref.getmembers():
|
| 251 |
+
if self._should_extract_file(member.name, filter_extensions) and member.isfile():
|
| 252 |
+
try:
|
| 253 |
+
extracted_path = self._extract_tar_file(tar_ref = tar_ref,
|
| 254 |
+
member = member,
|
| 255 |
+
extract_dir = extract_dir,
|
| 256 |
+
flatten_structure = flatten_structure,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if extracted_path:
|
| 260 |
+
extracted_files.append(extracted_path)
|
| 261 |
+
|
| 262 |
+
except Exception as e:
|
| 263 |
+
self.logger.warning(f"Failed to extract {member.name}: {repr(e)}")
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
return extracted_files
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _extract_tar_file(self, tar_ref, member, extract_dir: Path, flatten_structure: bool) -> Optional[Path]:
|
| 270 |
+
"""
|
| 271 |
+
Extract single file from TAR
|
| 272 |
+
"""
|
| 273 |
+
# Determine extraction path
|
| 274 |
+
if flatten_structure:
|
| 275 |
+
target_filename = Path(member.name).name
|
| 276 |
+
extract_path = extract_dir / self._safe_filename(filename = target_filename)
|
| 277 |
+
|
| 278 |
+
else:
|
| 279 |
+
extract_path = extract_dir / self._safe_archive_path(archive_path = member.name)
|
| 280 |
+
|
| 281 |
+
# Ensure parent directory exists
|
| 282 |
+
extract_path.parent.mkdir(parents = True, exist_ok = True)
|
| 283 |
+
|
| 284 |
+
# Check limits
|
| 285 |
+
self._check_extraction_limits(file_size = member.size)
|
| 286 |
+
|
| 287 |
+
# Extract file
|
| 288 |
+
tar_ref.extract(member, extract_dir)
|
| 289 |
+
|
| 290 |
+
# Rename if flattening structure
|
| 291 |
+
if flatten_structure:
|
| 292 |
+
original_path = extract_dir / member.name
|
| 293 |
+
|
| 294 |
+
if (original_path != extract_path):
|
| 295 |
+
original_path.rename(extract_path)
|
| 296 |
+
|
| 297 |
+
self.extracted_files_count += 1
|
| 298 |
+
self.total_extracted_size += member.size
|
| 299 |
+
|
| 300 |
+
return extract_path
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _extract_rar(self, archive_path: Path, extract_dir: Path, flatten_structure: bool, filter_extensions: Optional[List[str]]) -> List[Path]:
|
| 304 |
+
"""
|
| 305 |
+
Extract RAR archive
|
| 306 |
+
"""
|
| 307 |
+
extracted_files = list()
|
| 308 |
+
|
| 309 |
+
try:
|
| 310 |
+
with rarfile.RarFile(archive_path) as rar_ref:
|
| 311 |
+
# Validate files before extraction
|
| 312 |
+
file_list = rar_ref.namelist()
|
| 313 |
+
|
| 314 |
+
self._validate_file_count(file_count = len(file_list))
|
| 315 |
+
|
| 316 |
+
for file_info in rar_ref.infolist():
|
| 317 |
+
if (self._should_extract_file(filename = file_info.filename, filter_extensions = filter_extensions) and not file_info.isdir()):
|
| 318 |
+
try:
|
| 319 |
+
extracted_path = self._extract_rar_file(rar_ref = rar_ref,
|
| 320 |
+
file_info = file_info,
|
| 321 |
+
extract_dir = extract_dir,
|
| 322 |
+
flatten_structure = flatten_structure,
|
| 323 |
+
)
|
| 324 |
+
if extracted_path:
|
| 325 |
+
extracted_files.append(extracted_path)
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
self.logger.warning(f"Failed to extract {file_info.filename}: {repr(e)}")
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
except rarfile.NotRarFile:
|
| 332 |
+
raise ArchiveException(f"Not a valid RAR file: {archive_path}")
|
| 333 |
+
|
| 334 |
+
except rarfile.BadRarFile:
|
| 335 |
+
raise ArchiveException(f"Corrupted RAR file: {archive_path}")
|
| 336 |
+
|
| 337 |
+
return extracted_files
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _extract_rar_file(self, rar_ref, file_info, extract_dir: Path, flatten_structure: bool) -> Optional[Path]:
|
| 341 |
+
"""
|
| 342 |
+
Extract single file from RAR
|
| 343 |
+
"""
|
| 344 |
+
# Determine extraction path
|
| 345 |
+
if flatten_structure:
|
| 346 |
+
target_filename = Path(file_info.filename).name
|
| 347 |
+
extract_path = extract_dir / self._safe_filename(filename = target_filename)
|
| 348 |
+
|
| 349 |
+
else:
|
| 350 |
+
extract_path = extract_dir / self._safe_archive_path(archive_path = file_info.filename)
|
| 351 |
+
|
| 352 |
+
# Ensure parent directory exists
|
| 353 |
+
extract_path.parent.mkdir(parents = True, exist_ok = True)
|
| 354 |
+
|
| 355 |
+
# Check limits
|
| 356 |
+
self._check_extraction_limits(file_size = file_info.file_size)
|
| 357 |
+
|
| 358 |
+
# Extract file
|
| 359 |
+
rar_ref.extract(file_info.filename, extract_dir)
|
| 360 |
+
|
| 361 |
+
# Rename if flattening structure
|
| 362 |
+
if flatten_structure:
|
| 363 |
+
original_path = extract_dir / file_info.filename
|
| 364 |
+
|
| 365 |
+
if (original_path != extract_path):
|
| 366 |
+
original_path.rename(extract_path)
|
| 367 |
+
|
| 368 |
+
self.extracted_files_count += 1
|
| 369 |
+
self.total_extracted_size += file_info.file_size
|
| 370 |
+
|
| 371 |
+
return extract_path
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _extract_7z(self, archive_path: Path, extract_dir: Path, flatten_structure: bool, filter_extensions: Optional[List[str]]) -> List[Path]:
|
| 375 |
+
"""
|
| 376 |
+
Extract 7Z archive
|
| 377 |
+
"""
|
| 378 |
+
extracted_files = list()
|
| 379 |
+
|
| 380 |
+
with py7zr.SevenZipFile(archive_path, 'r') as sevenz_ref:
|
| 381 |
+
# Get file list
|
| 382 |
+
file_list = sevenz_ref.getnames()
|
| 383 |
+
self._validate_file_count(file_count = len(file_list))
|
| 384 |
+
|
| 385 |
+
# Extract all files
|
| 386 |
+
sevenz_ref.extractall(extract_dir)
|
| 387 |
+
|
| 388 |
+
# Process extracted files
|
| 389 |
+
for filename in file_list:
|
| 390 |
+
if self._should_extract_file(filename = filename, filter_extensions = filter_extensions):
|
| 391 |
+
original_path = extract_dir / filename
|
| 392 |
+
|
| 393 |
+
if original_path.is_file():
|
| 394 |
+
if flatten_structure:
|
| 395 |
+
target_path = extract_dir / self._safe_filename(filename = Path(filename).name)
|
| 396 |
+
|
| 397 |
+
if (original_path != target_path):
|
| 398 |
+
original_path.rename(target_path)
|
| 399 |
+
extracted_files.append(target_path)
|
| 400 |
+
|
| 401 |
+
else:
|
| 402 |
+
extracted_files.append(original_path)
|
| 403 |
+
|
| 404 |
+
else:
|
| 405 |
+
extracted_files.append(original_path)
|
| 406 |
+
|
| 407 |
+
# Update counters
|
| 408 |
+
self.extracted_files_count += 1
|
| 409 |
+
self.total_extracted_size += original_path.stat().st_size
|
| 410 |
+
|
| 411 |
+
return extracted_files
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _is_system_file(self, filename: str) -> bool:
|
| 415 |
+
"""
|
| 416 |
+
Check if file is a system/metadata file that should be skipped
|
| 417 |
+
"""
|
| 418 |
+
system_patterns = ['__MACOSX',
|
| 419 |
+
'.DS_Store',
|
| 420 |
+
'Thumbs.db',
|
| 421 |
+
'desktop.ini',
|
| 422 |
+
'~$', # Temporary office files
|
| 423 |
+
'._', # macOS resource fork
|
| 424 |
+
'#recycle', # Recycle bin
|
| 425 |
+
'@eaDir', # Synology index
|
| 426 |
+
]
|
| 427 |
+
|
| 428 |
+
filename_str = str(filename).lower()
|
| 429 |
+
path_parts = Path(filename).parts
|
| 430 |
+
|
| 431 |
+
# Check for system patterns in filename or path
|
| 432 |
+
for pattern in system_patterns:
|
| 433 |
+
if pattern.lower() in filename_str:
|
| 434 |
+
return True
|
| 435 |
+
|
| 436 |
+
# Skip hidden files and directories (except current/parent dir references)
|
| 437 |
+
for part in path_parts:
|
| 438 |
+
if part.startswith(('.', '_')) and part not in ['.', '..']:
|
| 439 |
+
return True
|
| 440 |
+
|
| 441 |
+
# Skip common backup and temporary files
|
| 442 |
+
temp_extensions = ['.tmp', '.temp', '.bak', '.backup']
|
| 443 |
+
|
| 444 |
+
if any(Path(filename).suffix.lower() == ext for ext in temp_extensions):
|
| 445 |
+
return True
|
| 446 |
+
|
| 447 |
+
return False
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def _should_extract_file(self, filename: str, filter_extensions: Optional[List[str]]) -> bool:
|
| 451 |
+
"""
|
| 452 |
+
Check if file should be extracted based on filters and system files
|
| 453 |
+
"""
|
| 454 |
+
# Skip system files and metadata
|
| 455 |
+
if self._is_system_file(filename):
|
| 456 |
+
self.logger.debug(f"Skipping system file: {filename}")
|
| 457 |
+
return False
|
| 458 |
+
|
| 459 |
+
# Apply extension filters if provided
|
| 460 |
+
if filter_extensions is not None:
|
| 461 |
+
file_ext = Path(filename).suffix.lower()
|
| 462 |
+
return file_ext in [ext.lower() for ext in filter_extensions]
|
| 463 |
+
|
| 464 |
+
return True
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def _safe_archive_path(self, archive_path: str) -> Path:
|
| 468 |
+
"""
|
| 469 |
+
Convert archive path to safe filesystem path
|
| 470 |
+
"""
|
| 471 |
+
safe_parts = list()
|
| 472 |
+
|
| 473 |
+
for part in Path(archive_path).parts:
|
| 474 |
+
safe_parts.append(self._safe_filename(filename = part))
|
| 475 |
+
|
| 476 |
+
return Path(*safe_parts)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def _safe_filename(self, filename: str) -> str:
|
| 480 |
+
"""
|
| 481 |
+
Ensure filename is safe for filesystem
|
| 482 |
+
"""
|
| 483 |
+
return FileHandler.safe_filename(filename)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def _detect_archive_format(self, archive_path: Path) -> str:
|
| 487 |
+
"""
|
| 488 |
+
Detect archive format from file extension
|
| 489 |
+
"""
|
| 490 |
+
suffix = archive_path.suffix.lower()
|
| 491 |
+
|
| 492 |
+
for ext, format_type in self.SUPPORTED_FORMATS.items():
|
| 493 |
+
if ((suffix == ext) or (ext == '.tar' and suffix in ['.gz', '.tgz', '.bz2', '.xz'])):
|
| 494 |
+
return format_type
|
| 495 |
+
|
| 496 |
+
raise ArchiveException(f"Unsupported archive format: {suffix}")
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _validate_archive_size(self, archive_path: Path):
|
| 500 |
+
"""
|
| 501 |
+
Validate archive size against limits
|
| 502 |
+
"""
|
| 503 |
+
file_size = archive_path.stat().st_size
|
| 504 |
+
|
| 505 |
+
if (file_size > self.max_size_bytes):
|
| 506 |
+
raise ArchiveException(f"Archive size {file_size} exceeds maximum {self.max_size_bytes}")
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _validate_file_count(self, file_count: int):
|
| 510 |
+
"""
|
| 511 |
+
Validate number of files in archive
|
| 512 |
+
"""
|
| 513 |
+
if (file_count > self.max_files):
|
| 514 |
+
raise ArchiveException(f"Archive contains {file_count} files, exceeds maximum {self.max_files}")
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def _check_extraction_limits(self, file_size: int):
|
| 518 |
+
"""
|
| 519 |
+
Check if extraction limits are exceeded
|
| 520 |
+
"""
|
| 521 |
+
if (self.extracted_files_count >= self.max_files):
|
| 522 |
+
raise ArchiveException(f"Maximum file count ({self.max_files}) exceeded")
|
| 523 |
+
|
| 524 |
+
if (self.total_extracted_size + file_size > self.max_size_bytes):
|
| 525 |
+
raise ArchiveException(f"Maximum extraction size ({self.max_size_bytes}) exceeded")
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def list_contents(self, archive_path: Path) -> List[Dict]:
|
| 529 |
+
"""
|
| 530 |
+
List contents of archive without extraction
|
| 531 |
+
|
| 532 |
+
Arguments:
|
| 533 |
+
----------
|
| 534 |
+
archive_path { Path } : Path to archive file
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
--------
|
| 538 |
+
{ list } : List of file information dictionaries
|
| 539 |
+
"""
|
| 540 |
+
archive_path = Path(archive_path)
|
| 541 |
+
format_type = self._detect_archive_format(archive_path)
|
| 542 |
+
|
| 543 |
+
try:
|
| 544 |
+
if (format_type == 'zip'):
|
| 545 |
+
return self._list_zip_contents(archive_path)
|
| 546 |
+
|
| 547 |
+
elif (format_type == 'tar'):
|
| 548 |
+
return self._list_tar_contents(archive_path)
|
| 549 |
+
|
| 550 |
+
elif (format_type == 'rar'):
|
| 551 |
+
return self._list_rar_contents(archive_path)
|
| 552 |
+
|
| 553 |
+
elif (format_type == '7z'):
|
| 554 |
+
return self._list_7z_contents(archive_path)
|
| 555 |
+
|
| 556 |
+
else:
|
| 557 |
+
return []
|
| 558 |
+
|
| 559 |
+
except Exception as e:
|
| 560 |
+
self.logger.error(f"Failed to list archive contents: {repr(e)}")
|
| 561 |
+
return []
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _list_zip_contents(self, archive_path: Path) -> List[Dict]:
|
| 565 |
+
"""
|
| 566 |
+
List ZIP archive contents
|
| 567 |
+
"""
|
| 568 |
+
contents = list()
|
| 569 |
+
|
| 570 |
+
with zipfile.ZipFile(archive_path, 'r') as zip_ref:
|
| 571 |
+
for file_info in zip_ref.infolist():
|
| 572 |
+
contents.append({'filename' : file_info.filename,
|
| 573 |
+
'file_size' : file_info.file_size,
|
| 574 |
+
'compress_size' : file_info.compress_size,
|
| 575 |
+
'is_dir' : file_info.filename.endswith('/'),
|
| 576 |
+
})
|
| 577 |
+
return contents
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def _list_tar_contents(self, archive_path: Path) -> List[Dict]:
|
| 581 |
+
"""
|
| 582 |
+
List TAR archive contents
|
| 583 |
+
"""
|
| 584 |
+
contents = list()
|
| 585 |
+
mode = 'r'
|
| 586 |
+
|
| 587 |
+
if archive_path.suffix.lower() in ['.gz', '.tgz']:
|
| 588 |
+
mode = 'r:gz'
|
| 589 |
+
|
| 590 |
+
with tarfile.open(archive_path, mode) as tar_ref:
|
| 591 |
+
for member in tar_ref.getmembers():
|
| 592 |
+
contents.append({'filename' : member.name,
|
| 593 |
+
'file_size' : member.size,
|
| 594 |
+
'is_dir' : member.isdir(),
|
| 595 |
+
'is_file' : member.isfile(),
|
| 596 |
+
})
|
| 597 |
+
return contents
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def _list_rar_contents(self, archive_path: Path) -> List[Dict]:
|
| 601 |
+
"""
|
| 602 |
+
List RAR archive contents
|
| 603 |
+
"""
|
| 604 |
+
contents = list()
|
| 605 |
+
|
| 606 |
+
with rarfile.RarFile(archive_path) as rar_ref:
|
| 607 |
+
for file_info in rar_ref.infolist():
|
| 608 |
+
contents.append({'filename' : file_info.filename,
|
| 609 |
+
'file_size' : file_info.file_size,
|
| 610 |
+
'compress_size' : file_info.compress_size,
|
| 611 |
+
'is_dir' : file_info.isdir(),
|
| 612 |
+
})
|
| 613 |
+
return contents
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def _list_7z_contents(self, archive_path: Path) -> List[Dict]:
|
| 617 |
+
"""
|
| 618 |
+
List 7Z archive contents
|
| 619 |
+
"""
|
| 620 |
+
contents = list()
|
| 621 |
+
|
| 622 |
+
with py7zr.SevenZipFile(archive_path, 'r') as sevenz_ref:
|
| 623 |
+
for filename in sevenz_ref.getnames():
|
| 624 |
+
# 7z doesn't provide detailed file info in listing
|
| 625 |
+
contents.append({'filename' : filename,
|
| 626 |
+
'file_size' : 0,
|
| 627 |
+
'is_dir' : filename.endswith('/'),
|
| 628 |
+
})
|
| 629 |
+
return contents
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def is_supported_archive(self, file_path: Path) -> bool:
|
| 633 |
+
"""
|
| 634 |
+
Check if file is a supported archive format
|
| 635 |
+
|
| 636 |
+
Arguments:
|
| 637 |
+
----------
|
| 638 |
+
file_path { Path } : Path to file
|
| 639 |
+
|
| 640 |
+
Returns:
|
| 641 |
+
--------
|
| 642 |
+
{ bool } : True if supported archive format
|
| 643 |
+
"""
|
| 644 |
+
try:
|
| 645 |
+
self._detect_archive_format(file_path)
|
| 646 |
+
return True
|
| 647 |
+
|
| 648 |
+
except ArchiveException:
|
| 649 |
+
return False
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def get_supported_formats(self) -> List[str]:
|
| 653 |
+
"""
|
| 654 |
+
Get list of supported archive formats
|
| 655 |
+
|
| 656 |
+
Returns:
|
| 657 |
+
--------
|
| 658 |
+
{ list } : List of supported file extensions
|
| 659 |
+
"""
|
| 660 |
+
return list(self.SUPPORTED_FORMATS.keys())
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
# Global archive handler instance
|
| 664 |
+
_global_archive_handler = None
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def get_archive_handler() -> ArchiveHandler:
|
| 668 |
+
"""
|
| 669 |
+
Get global archive handler instance (singleton)
|
| 670 |
+
|
| 671 |
+
Returns:
|
| 672 |
+
--------
|
| 673 |
+
{ ArchiveHandler } : ArchiveHandler instance
|
| 674 |
+
"""
|
| 675 |
+
global _global_archive_handler
|
| 676 |
+
|
| 677 |
+
if _global_archive_handler is None:
|
| 678 |
+
_global_archive_handler = ArchiveHandler()
|
| 679 |
+
|
| 680 |
+
return _global_archive_handler
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def extract_archive(archive_path: Path, **kwargs) -> List[Path]:
|
| 684 |
+
"""
|
| 685 |
+
Convenience function for archive extraction
|
| 686 |
+
|
| 687 |
+
Arguments:
|
| 688 |
+
----------
|
| 689 |
+
archive_path { Path } : Path to archive file
|
| 690 |
+
|
| 691 |
+
**kwargs : Additional arguments for ArchiveHandler
|
| 692 |
+
|
| 693 |
+
Returns:
|
| 694 |
+
--------
|
| 695 |
+
{ list } : List of extracted file paths
|
| 696 |
+
"""
|
| 697 |
+
handler = get_archive_handler()
|
| 698 |
+
|
| 699 |
+
return handler.extract_archive(archive_path, **kwargs)
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def is_archive_file(file_path: Path) -> bool:
|
| 703 |
+
"""
|
| 704 |
+
Check if file is a supported archive
|
| 705 |
+
|
| 706 |
+
Arguments:
|
| 707 |
+
----------
|
| 708 |
+
file_path { Path } : Path to file
|
| 709 |
+
|
| 710 |
+
Returns:
|
| 711 |
+
--------
|
| 712 |
+
{ bool } : True if supported archive
|
| 713 |
+
"""
|
| 714 |
+
handler = get_archive_handler()
|
| 715 |
+
|
| 716 |
+
return handler.is_supported_archive(file_path)
|
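A minimal usage sketch tying the archive handler into the ingestion flow, assuming the module layout above (the archive name and size limits are illustrative). The extracted paths would normally be handed on to the per-format document parsers.

from pathlib import Path

from document_parser.zip_handler import ArchiveHandler

handler = ArchiveHandler(max_size_mb=512, max_files=1000)
archive = Path("reports.zip")  # hypothetical upload

if handler.is_supported_archive(archive):
    # Only pull out document types the pipeline can parse
    files = handler.extract_archive(archive, filter_extensions=[".pdf", ".docx", ".txt"])
    print(f"Extracted {len(files)} files, {handler.total_extracted_size} bytes total")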
embeddings/__init__.py
ADDED
|
File without changes
|
embeddings/batch_processor.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
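Orientation for the file body below: BatchProcessor wraps `model.encode` with batch sizing, fallback, and statistics. As a quick sketch of the intended call pattern (the model name mirrors the EMBEDDING_MODEL default and is only an example):

# Minimal sketch of driving the BatchProcessor added below; inputs are illustrative.
from sentence_transformers import SentenceTransformer

from embeddings.batch_processor import get_batch_processor

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # assumed default
processor = get_batch_processor()

texts = ["first chunk of text", "second chunk of text"]
vectors = processor.process_embeddings_batch(model=model, texts=texts, batch_size=16)

print(len(vectors), vectors[0].shape)        # e.g. 2 (384,)
print(processor.get_processing_stats())      # batches, texts, success rate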
# DEPENDENCIES
import numpy as np
from typing import List
from typing import Optional
from numpy.typing import NDArray
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import EmbeddingError
from chunking.token_counter import get_token_counter
from sentence_transformers import SentenceTransformer
from utils.helpers import BatchProcessor as BaseBatchProcessor


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class BatchProcessor:
    """
    Efficient batch processing for embeddings: Handles large batches with memory optimization and progress tracking
    """
    def __init__(self):
        self.logger = logger
        self.base_processor = BaseBatchProcessor()

        # Batch processing statistics
        self.total_batches = 0
        self.total_texts = 0
        self.failed_batches = 0


    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def process_embeddings_batch(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True, **kwargs) -> List[NDArray]:
        """
        Process embeddings in optimized batches

        Arguments:
        ----------
        model { SentenceTransformer } : Embedding model

        texts { list } : List of texts to embed

        batch_size { int } : Batch size (default from settings)

        normalize { bool } : Normalize embeddings

        **kwargs : Additional model.encode parameters

        Returns:
        --------
        { list } : List of embedding vectors
        """
        if not texts:
            return []

        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE

        self.logger.debug(f"Processing {len(texts)} texts in batches of {batch_size}")

        try:
            # Use model's built-in batching with optimization
            embeddings = model.encode(texts,
                                      batch_size = batch_size,
                                      normalize_embeddings = normalize,
                                      show_progress_bar = False,
                                      convert_to_numpy = True,
                                      **kwargs
                                      )

            # Update statistics
            self.total_batches += ((len(texts) + batch_size - 1) // batch_size)
            self.total_texts += len(texts)

            self.logger.debug(f"Successfully generated {len(embeddings)} embeddings")

            # Convert to list of arrays
            return list(embeddings)

        except Exception as e:
            self.failed_batches += 1
            self.logger.error(f"Batch embedding failed: {repr(e)}")
            raise EmbeddingError(f"Batch processing failed: {repr(e)}")


    def process_embeddings_with_fallback(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True) -> List[NDArray]:
        """
        Process embeddings with automatic batch size reduction on failure

        Arguments:
        ----------
        model { SentenceTransformer } : Embedding model

        texts { list } : List of texts

        batch_size { int } : Initial batch size

        normalize { bool } : Normalize embeddings

        Returns:
        --------
        { list } : List of embeddings
        """
        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE

        try:
            return self.process_embeddings_batch(model = model,
                                                 texts = texts,
                                                 batch_size = batch_size,
                                                 normalize = normalize,
                                                 )

        except (MemoryError, RuntimeError) as e:
            self.logger.warning(f"Batch size {batch_size} failed, reducing to {batch_size // 2}")

            # Reduce batch size and retry
            return self.process_embeddings_batch(model = model,
                                                 texts = texts,
                                                 batch_size = batch_size // 2,
                                                 normalize = normalize,
                                                 )


    def split_into_optimal_batches(self, texts: List[str], target_batch_size: int, max_batch_size: int = 1000) -> List[List[str]]:
        """
        Split texts into optimal batches considering token counts

        Arguments:
        ----------
        texts { list } : List of texts

        target_batch_size { int } : Target batch size in texts

        max_batch_size { int } : Maximum batch size to allow

        Returns:
        --------
        { list } : List of text batches
        """
        if not texts:
            return []

        token_counter = get_token_counter()
        batches = list()
        current_batch = list()
        current_tokens = 0

        # Estimate tokens per text (average of first 10 or all if less)
        sample_size = min(10, len(texts))
        sample_tokens = [token_counter.count_tokens(text) for text in texts[:sample_size]]
        avg_tokens = sum(sample_tokens) / len(sample_tokens) if sample_tokens else 100

        # Target tokens per batch (approximate)
        target_tokens = target_batch_size * avg_tokens

        for text in texts:
            text_tokens = token_counter.count_tokens(text)

            # If single text is too large, put it in its own batch
            if (text_tokens > (target_tokens * 0.8)):
                if current_batch:
                    batches.append(current_batch)
                    current_batch = list()
                    current_tokens = 0

                batches.append([text])
                continue

            # Check if adding this text would exceed limits
            if (((current_tokens + text_tokens) > target_tokens) and current_batch) or (len(current_batch) >= max_batch_size):
                batches.append(current_batch)
                current_batch = list()
                current_tokens = 0

            current_batch.append(text)
            current_tokens += text_tokens

        # Add final batch
        if current_batch:
            batches.append(current_batch)

        self.logger.debug(f"Split {len(texts)} texts into {len(batches)} optimal batches")

        return batches


    def process_batches_with_progress(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, progress_callback: Optional[callable] = None, **kwargs) -> List[NDArray]:
        """
        Process batches with progress reporting

        Arguments:
        ----------
        model { SentenceTransformer } : Embedding model

        texts { list } : List of texts

        batch_size { int } : Batch size

        progress_callback { callable } : Callback for progress updates

        **kwargs : Additional parameters

        Returns:
        --------
        { list } : List of embeddings
        """
        if not texts:
            return []

        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE

        # Split into batches
        batches = self.split_into_optimal_batches(texts = texts,
                                                  target_batch_size = batch_size,
                                                  )

        all_embeddings = list()

        for i, batch_texts in enumerate(batches):
            if progress_callback:
                progress = (i / len(batches)) * 100
                progress_callback(progress, f"Processing batch {i + 1}/{len(batches)}")

            try:
                batch_embeddings = self.process_embeddings_batch(model = model,
                                                                 texts = batch_texts,
                                                                 batch_size = len(batch_texts),
                                                                 **kwargs
                                                                 )

                all_embeddings.extend(batch_embeddings)

                self.logger.debug(f"Processed batch {i + 1}/{len(batches)}: {len(batch_texts)} texts")

            except Exception as e:
                self.logger.error(f"Failed to process batch {i + 1}: {repr(e)}")

                # Add None placeholders for failed batch
                all_embeddings.extend([None] * len(batch_texts))

        if progress_callback:
            progress_callback(100, "Embedding complete")

        return all_embeddings


    def validate_embeddings_batch(self, embeddings: List[NDArray], expected_count: int) -> bool:
        """
        Validate a batch of embeddings

        Arguments:
        ----------
        embeddings { list } : List of embedding vectors

        expected_count { int } : Expected number of embeddings

        Returns:
        --------
        { bool } : True if valid
        """
        if (len(embeddings) != expected_count):
            self.logger.error(f"Embedding count mismatch: expected {expected_count}, got {len(embeddings)}")
            return False

        valid_count = 0

        for i, emb in enumerate(embeddings):
            if emb is None:
                self.logger.warning(f"None embedding at index {i}")
                continue

            if not isinstance(emb, np.ndarray):
                self.logger.warning(f"Invalid embedding type at index {i}: {type(emb)}")
                continue

            if (emb.ndim != 1):
                self.logger.warning(f"Invalid embedding dimension at index {i}: {emb.ndim}")
                continue

            if np.any(np.isnan(emb)):
                self.logger.warning(f"NaN values in embedding at index {i}")
                continue

            valid_count += 1

        validity_ratio = valid_count / expected_count

        if (validity_ratio < 0.9):
            self.logger.warning(f"Low embedding validity: {valid_count}/{expected_count} ({validity_ratio:.1%})")
            return False

        return True


    def get_processing_stats(self) -> dict:
        """
        Get batch processing statistics

        Returns:
        --------
        { dict } : Statistics dictionary
        """
        success_rate = ((self.total_batches - self.failed_batches) / self.total_batches * 100) if (self.total_batches > 0) else 100

        stats = {"total_batches"  : self.total_batches,
                 "total_texts"    : self.total_texts,
                 "failed_batches" : self.failed_batches,
                 "success_rate"   : success_rate,
                 "avg_batch_size" : self.total_texts / self.total_batches if (self.total_batches > 0) else 0,
                 }

        return stats


    def reset_stats(self):
        """
        Reset processing statistics
        """
        self.total_batches = 0
        self.total_texts = 0
        self.failed_batches = 0

        self.logger.debug("Reset batch processing statistics")


# Global batch processor instance
_batch_processor = None


def get_batch_processor() -> BatchProcessor:
    """
    Get global batch processor instance

    Returns:
    --------
    { BatchProcessor } : BatchProcessor instance
    """
    global _batch_processor

    if _batch_processor is None:
        _batch_processor = BatchProcessor()

    return _batch_processor


def process_embeddings_batch(model: SentenceTransformer, texts: List[str], **kwargs) -> List[NDArray]:
    """
    Convenience function for batch embedding

    Arguments:
    ----------
    model { SentenceTransformer } : Embedding model

    texts { list } : List of texts

    **kwargs : Additional arguments

    Returns:
    --------
    { list } : List of embeddings
    """
    processor = get_batch_processor()

    return processor.process_embeddings_batch(model, texts, **kwargs)

embeddings/bge_embedder.py
ADDED
@@ -0,0 +1,351 @@
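Orientation for the file body below: BGEEmbedder is a singleton facade over the model loader and batch processor. A short sketch of how it is meant to be exercised, with illustrative inputs (the loaded model is whatever EMBEDDING_MODEL configures):

# Sketch of the BGEEmbedder facade added below (singleton via get_embedder); inputs are examples.
from embeddings.bge_embedder import get_embedder

embedder = get_embedder()

query_vec, doc_vec = embedder.embed_texts(["what is chunk overlap?",
                                           "Overlap keeps neighbouring chunks sharing a few tokens."])

print(embedder.cosine_similarity(query_vec, doc_vec))   # float in [-1, 1]
print(embedder.get_model_info())                         # model name, dim, device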
# DEPENDENCIES
import time
import numpy as np
from typing import List
from typing import Optional
from numpy.typing import NDArray
from config.models import DocumentChunk
from config.settings import get_settings
from config.models import EmbeddingRequest
from config.models import EmbeddingResponse
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import EmbeddingError
from embeddings.model_loader import get_model_loader
from embeddings.batch_processor import BatchProcessor


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class BGEEmbedder:
    """
    BGE (BAAI General Embedding) model wrapper: Optimized for BAAI/bge models with proper normalization and batching
    """
    def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
        """
        Initialize BGE embedder

        Arguments:
        ----------
        model_name { str } : BGE model name (default from settings)

        device { str } : Device to run on
        """
        self.logger = logger
        self.model_name = model_name or settings.EMBEDDING_MODEL
        self.device = device

        # Initialize components
        self.model_loader = get_model_loader()
        self.batch_processor = BatchProcessor()

        # Load model
        self.model = self.model_loader.load_model(model_name = self.model_name,
                                                  device = self.device,
                                                  )

        # Get model info
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        self.supports_batch = True

        self.logger.info(f"Initialized BGEEmbedder: model={self.model_name}, dim={self.embedding_dim}, device={self.model.device}")


    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def embed_text(self, text: str, normalize: bool = True) -> NDArray:
        """
        Embed single text string

        Arguments:
        ----------
        text { str } : Input text

        normalize { bool } : Normalize embeddings to unit length

        Returns:
        --------
        { NDArray } : Embedding vector
        """
        if not text or not text.strip():
            raise EmbeddingError("Cannot embed empty text")

        try:
            # Encode single text
            embedding = self.model.encode([text],
                                          normalize_embeddings = normalize,
                                          show_progress_bar = False,
                                          )

            # Return single vector
            return embedding[0]

        except Exception as e:
            self.logger.error(f"Failed to embed text: {repr(e)}")
            raise EmbeddingError(f"Text embedding failed: {repr(e)}")


    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def embed_texts(self, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True) -> List[NDArray]:
        """
        Embed multiple texts with batching

        Arguments:
        ----------
        texts { list } : List of text strings

        batch_size { int } : Batch size (default from settings)

        normalize { bool } : Normalize embeddings

        Returns:
        --------
        { list } : List of embedding vectors
        """
        if not texts:
            return []

        # Filter empty texts
        valid_texts = [t for t in texts if t and t.strip()]

        if (len(valid_texts) != len(texts)):
            self.logger.warning(f"Filtered {len(texts) - len(valid_texts)} empty texts")

        if not valid_texts:
            return []

        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE

        try:
            # Use batch processing for efficiency
            embeddings = self.batch_processor.process_embeddings_batch(model = self.model,
                                                                       texts = valid_texts,
                                                                       batch_size = batch_size,
                                                                       normalize = normalize,
                                                                       )

            self.logger.debug(f"Generated {len(embeddings)} embeddings for {len(texts)} texts")

            return embeddings

        except Exception as e:
            self.logger.error(f"Batch embedding failed: {repr(e)}")
            raise EmbeddingError(f"Batch embedding failed: {repr(e)}")


    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def embed_chunks(self, chunks: List[DocumentChunk], batch_size: Optional[int] = None, normalize: bool = True) -> List[DocumentChunk]:
        """
        Embed document chunks and update them with embeddings

        Arguments:
        ----------
        chunks { list } : List of DocumentChunk objects

        batch_size { int } : Batch size

        normalize { bool } : Normalize embeddings

        Returns:
        --------
        { list } : Chunks with embeddings added
        """
        if not chunks:
            return []

        # Extract texts from chunks
        texts = [chunk.text for chunk in chunks]

        # Generate embeddings
        embeddings = self.embed_texts(texts = texts,
                                      batch_size = batch_size,
                                      normalize = normalize,
                                      )

        # Update chunks with embeddings
        for chunk, embedding in zip(chunks, embeddings):
            # Convert numpy to list for serialization
            chunk.embedding = embedding.tolist()

        self.logger.info(f"Embedded {len(chunks)} document chunks")

        return chunks


    def process_embedding_request(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """
        Process embedding request from API

        Arguments:
        ----------
        request { EmbeddingRequest } : Embedding request

        Returns:
        --------
        { EmbeddingResponse } : Embedding response
        """
        start_time = time.time()

        # Generate embeddings
        embeddings = self.embed_texts(texts = request.texts,
                                      batch_size = request.batch_size,
                                      normalize = request.normalize,
                                      )

        # Convert to milliseconds
        processing_time = (time.time() - start_time) * 1000

        # Convert to list for serialization
        embedding_list = [emb.tolist() for emb in embeddings]

        response = EmbeddingResponse(embeddings = embedding_list,
                                     dimension = self.embedding_dim,
                                     num_embeddings = len(embeddings),
                                     processing_time_ms = processing_time,
                                     )

        return response


    def get_embedding_dimension(self) -> int:
        """
        Get embedding dimension

        Returns:
        --------
        { int } : Embedding vector dimension
        """
        return self.embedding_dim


    def cosine_similarity(self, emb1: NDArray, emb2: NDArray) -> float:
        """
        Calculate cosine similarity between two embeddings

        Arguments:
        ----------
        emb1 { NDArray } : First embedding

        emb2 { NDArray } : Second embedding

        Returns:
        --------
        { float } : Cosine similarity (-1 to 1)
        """
        # Ensure embeddings are normalized
        emb1_norm = emb1 / np.linalg.norm(emb1)
        emb2_norm = emb2 / np.linalg.norm(emb2)

        return float(np.dot(emb1_norm, emb2_norm))


    def validate_embedding(self, embedding: NDArray) -> bool:
        """
        Validate embedding vector

        Arguments:
        ----------
        embedding { NDArray } : Embedding vector

        Returns:
        --------
        { bool } : True if valid
        """
        if (embedding is None):
            return False

        if (not isinstance(embedding, np.ndarray)):
            return False

        if (embedding.shape != (self.embedding_dim,)):
            return False

        if (np.all(embedding == 0)):
            return False

        if (np.any(np.isnan(embedding))):
            return False

        return True


    def get_model_info(self) -> dict:
        """
        Get embedder information

        Returns:
        --------
        { dict } : Embedder information
        """
        return {"model_name" : self.model_name,
                "embedding_dim" : self.embedding_dim,
                "device" : str(self.model.device),
                "supports_batch" : self.supports_batch,
                "normalize_default" : True,
                }


# Global embedder instance
_embedder = None


def get_embedder(model_name: Optional[str] = None, device: Optional[str] = None) -> BGEEmbedder:
    """
    Get global embedder instance

    Arguments:
    ----------
    model_name { str } : Model name

    device { str } : Device

    Returns:
    --------
    { BGEEmbedder } : BGEEmbedder instance
    """
    global _embedder

    if _embedder is None:
        _embedder = BGEEmbedder(model_name, device)

    return _embedder


def embed_texts(texts: List[str], **kwargs) -> List[NDArray]:
    """
    Convenience function to embed texts

    Arguments:
    ----------
    texts { list } : List of texts

    **kwargs : Additional arguments for BGEEmbedder

    Returns:
    --------
    { list } : List of embeddings
    """
    embedder = get_embedder()

    return embedder.embed_texts(texts, **kwargs)


def embed_chunks(chunks: List[DocumentChunk], **kwargs) -> List[DocumentChunk]:
    """
    Convenience function to embed document chunks

    Arguments:
    ----------
    chunks { list } : List of DocumentChunk objects

    **kwargs : Additional arguments

    Returns:
    --------
    { list } : Chunks with embeddings
    """
    embedder = get_embedder()

    return embedder.embed_chunks(chunks, **kwargs)

embeddings/embedding_cache.py
ADDED
@@ -0,0 +1,331 @@
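Orientation for the file body below: EmbeddingCache implements a cache-aside pattern over the base cache; hits come from memory, misses are delegated to a caller-supplied embed function and written back. A hedged sketch of the intended use, with illustrative inputs:

# Cache-aside sketch for the EmbeddingCache added below; texts are examples only.
from embeddings.bge_embedder import get_embedder
from embeddings.embedding_cache import get_embedding_cache

cache = get_embedding_cache()
embedder = get_embedder()

texts = ["repeated question", "brand new question"]
vectors = cache.get_cached_embeddings(texts, embed_function=embedder.embed_texts)

print(cache.get_stats()["hit_rate_percentage"])   # grows as texts repeat across requests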
# DEPENDENCIES
import numpy as np
from typing import List
from typing import Optional
from numpy.typing import NDArray
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.cache_manager import EmbeddingCache as BaseEmbeddingCache


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class EmbeddingCache:
    """
    Embedding cache with numpy array support and statistics: Wraps the base cache with embedding-specific features
    """
    def __init__(self, max_size: int = None, ttl: int = None):
        """
        Initialize embedding cache

        Arguments:
        ----------
        max_size { int } : Maximum cache size

        ttl { int } : Time to live in seconds
        """
        self.logger = logger
        self.max_size = max_size or settings.CACHE_MAX_SIZE
        self.ttl = ttl or settings.CACHE_TTL

        # Initialize base cache
        self.base_cache = BaseEmbeddingCache(max_size = self.max_size,
                                             ttl = self.ttl,
                                             )

        # Enhanced statistics
        self.hits = 0
        self.misses = 0
        self.embeddings_generated = 0

        self.logger.info(f"Initialized EmbeddingCache: max_size={self.max_size}, ttl={self.ttl}")


    def get_embedding(self, text: str) -> Optional[NDArray]:
        """
        Get embedding from cache

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { NDArray } : Cached embedding or None
        """
        cached = self.base_cache.get_embedding(text)

        if cached is not None:
            self.hits += 1

            # Convert list back to numpy array
            return np.array(cached)

        else:
            self.misses += 1
            return None


    def set_embedding(self, text: str, embedding: NDArray):
        """
        Store embedding in cache

        Arguments:
        ----------
        text { str } : Input text

        embedding { NDArray } : Embedding vector
        """
        # Convert numpy array to list for serialization
        embedding_list = embedding.tolist()

        self.base_cache.set_embedding(text, embedding_list)

        self.embeddings_generated += 1


    def batch_get_embeddings(self, texts: List[str]) -> tuple[List[Optional[NDArray]], List[str]]:
        """
        Get multiple embeddings from cache

        Arguments:
        ----------
        texts { list } : List of texts

        Returns:
        --------
        { tuple } : Tuple of (cached_embeddings, missing_texts)
        """
        cached_embeddings = list()
        missing_texts = list()

        for text in texts:
            embedding = self.get_embedding(text)

            if embedding is not None:
                cached_embeddings.append(embedding)

            else:
                missing_texts.append(text)
                cached_embeddings.append(None)

        return cached_embeddings, missing_texts


    def batch_set_embeddings(self, texts: List[str], embeddings: List[NDArray]):
        """
        Store multiple embeddings in cache

        Arguments:
        ----------
        texts { list } : List of texts

        embeddings { list } : List of embedding vectors
        """
        if (len(texts) != len(embeddings)):
            raise ValueError("Texts and embeddings must have same length")

        for text, embedding in zip(texts, embeddings):
            self.set_embedding(text, embedding)


    def get_cached_embeddings(self, texts: List[str], embed_function: callable, batch_size: Optional[int] = None) -> List[NDArray]:
        """
        Smart embedding getter: uses cache for existing, generates for missing

        Arguments:
        ----------
        texts { list } : List of texts

        embed_function { callable } : Function to generate embeddings for missing texts

        batch_size { int } : Batch size for generation

        Returns:
        --------
        { list } : List of embeddings
        """
        # Get cached embeddings
        cached_embeddings, missing_texts = self.batch_get_embeddings(texts = texts)

        if not missing_texts:
            self.logger.debug(f"All {len(texts)} embeddings found in cache")
            return cached_embeddings

        # Generate missing embeddings
        self.logger.info(f"Generating {len(missing_texts)} embeddings ({(len(missing_texts)/len(texts))*100:.1f}% cache miss)")

        missing_embeddings = embed_function(missing_texts, batch_size = batch_size)

        # Store new embeddings in cache
        self.batch_set_embeddings(missing_texts, missing_embeddings)

        # Combine results
        result_embeddings = list()
        missing_idx = 0

        for emb in cached_embeddings:
            if emb is not None:
                result_embeddings.append(emb)

            else:
                result_embeddings.append(missing_embeddings[missing_idx])
                missing_idx += 1

        return result_embeddings


    def clear(self):
        """
        Clear entire cache
        """
        self.base_cache.clear()
        self.hits = 0
        self.misses = 0
        self.embeddings_generated = 0

        self.logger.info("Cleared embedding cache")


    def get_stats(self) -> dict:
        """
        Get cache statistics

        Returns:
        --------
        { dict } : Statistics dictionary
        """
        base_stats = self.base_cache.get_stats()

        total_requests = self.hits + self.misses
        hit_rate = (self.hits / total_requests * 100) if (total_requests > 0) else 0

        stats = {**base_stats,
                 "hits" : self.hits,
                 "misses" : self.misses,
                 "hit_rate_percentage" : hit_rate,
                 "embeddings_generated" : self.embeddings_generated,
                 "cache_size" : self.base_cache.cache.size(),
                 "max_size" : self.max_size,
                 }

        return stats


    def save_to_file(self, file_path: str) -> bool:
        """
        Save cache to file

        Arguments:
        ----------
        file_path { str } : Path to save file

        Returns:
        --------
        { bool } : True if successful
        """
        return self.base_cache.save_to_file(file_path)


    def load_from_file(self, file_path: str) -> bool:
        """
        Load cache from file

        Arguments:
        ----------
        file_path { str } : Path to load file

        Returns:
        --------
        { bool } : True if successful
        """
        return self.base_cache.load_from_file(file_path)


    def warm_cache(self, texts: List[str], embed_function: callable, batch_size: Optional[int] = None):
        """
        Pre-populate cache with embeddings

        Arguments:
        ----------
        texts { list } : List of texts to warm cache with

        embed_function { callable } : Embedding generation function

        batch_size { int } : Batch size
        """
        # Check which texts are not in cache
        _, missing_texts = self.batch_get_embeddings(texts = texts)

        if not missing_texts:
            self.logger.info("Cache already warm for all texts")
            return

        self.logger.info(f"Warming cache with {len(missing_texts)} embeddings")

        # Generate and cache embeddings
        embeddings = embed_function(missing_texts, batch_size = batch_size)

        self.batch_set_embeddings(missing_texts, embeddings)

        self.logger.info(f"Cache warming complete: added {len(missing_texts)} embeddings")


# Global embedding cache instance
_embedding_cache = None


def get_embedding_cache() -> EmbeddingCache:
    """
    Get global embedding cache instance

    Returns:
    --------
    { EmbeddingCache } : EmbeddingCache instance
    """
    global _embedding_cache

    if _embedding_cache is None:
        _embedding_cache = EmbeddingCache()

    return _embedding_cache


def cache_embeddings(texts: List[str], embeddings: List[NDArray]):
    """
    Convenience function to cache embeddings

    Arguments:
    ----------
    texts { list } : List of texts

    embeddings { list } : List of embeddings
    """
    cache = get_embedding_cache()

    cache.batch_set_embeddings(texts, embeddings)


def get_cached_embeddings(texts: List[str], embed_function: callable, **kwargs) -> List[NDArray]:
    """
    Convenience function to get cached embeddings

    Arguments:
    ----------
    texts { list } : List of texts

    embed_function { callable } : Embedding function

    **kwargs : Additional arguments

    Returns:
    --------
    { list } : List of embeddings
    """
    cache = get_embedding_cache()

    return cache.get_cached_embeddings(texts, embed_function, **kwargs)

embeddings/model_loader.py
ADDED
@@ -0,0 +1,327 @@
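Orientation for the file body below: EmbeddingModelLoader caches SentenceTransformer instances per (model, device) pair and resolves devices automatically (cuda, then mps, then cpu). A short, hedged sketch of the intended call pattern:

# Sketch of the loader singleton added below; model name falls back to settings.
from embeddings.model_loader import get_model_loader, load_embedding_model

model = load_embedding_model()                  # default model, auto device
print(get_model_loader().get_model_info())      # {"loaded": True, "device": ..., ...}

# Forcing CPU is just another call; the instance is cached per (model, device) key.
cpu_model = get_model_loader().load_model(device="cpu")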
# DEPENDENCIES
import os
import gc
import torch
from typing import Optional
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import EmbeddingError
from sentence_transformers import SentenceTransformer


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class EmbeddingModelLoader:
    """
    Manages loading and caching of embedding models: Supports multiple models with efficient resource management
    """
    def __init__(self):
        self.logger = logger
        self._loaded_model = None
        self._model_name = None
        self._device = None

        # Model cache for multiple models
        self._model_cache = dict()


    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def load_model(self, model_name: Optional[str] = None, device: Optional[str] = None, force_reload: bool = False) -> SentenceTransformer:
        """
        Load embedding model with caching and device optimization

        Arguments:
        ----------
        model_name { str } : Name of model to load (default from settings)

        device { str } : Device to load on ('cpu', 'cuda', 'mps', 'auto')

        force_reload { bool } : Force reload even if model is cached

        Returns:
        --------
        { SentenceTransformer } : Loaded model instance

        Raises:
        -------
        EmbeddingError : If model loading fails
        """
        model_name = model_name or settings.EMBEDDING_MODEL
        device = self._resolve_device(device)

        # Check cache first
        cache_key = f"{model_name}_{device}"

        if ((not force_reload) and (cache_key in self._model_cache)):
            self.logger.debug(f"Using cached model: {cache_key}")

            self._loaded_model = self._model_cache[cache_key]
            self._model_name = model_name
            self._device = device

            return self._loaded_model

        try:
            self.logger.info(f"Loading embedding model: {model_name} on device: {device}")

            # Load model with optimized settings
            model = SentenceTransformer(model_name,
                                        device = device,
                                        cache_folder = os.path.expanduser("~/.cache/sentence_transformers"),
                                        )

            # Model-specific optimizations
            model = self._optimize_model(model = model,
                                         device = device,
                                         )

            # Cache the model
            self._model_cache[cache_key] = model
            self._loaded_model = model
            self._model_name = model_name
            self._device = device

            # Log model info
            self._log_model_info(model = model,
                                 device = device,
                                 )

            self.logger.info(f"Successfully loaded model: {model_name}")

            return model

        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {repr(e)}")
            raise EmbeddingError(f"Model loading failed: {repr(e)}")


    def _resolve_device(self, device: Optional[str] = None) -> str:
        """
        Resolve the best available device

        Arguments:
        ----------
        device { str } : Requested device

        Returns:
        --------
        { str } : Actual device to use
        """
        if (device and (device != "auto")):
            return device

        # Auto device selection
        if (settings.EMBEDDING_DEVICE != "auto"):
            return settings.EMBEDDING_DEVICE

        # Automatic detection
        if torch.cuda.is_available():
            return "cuda"

        elif torch.backends.mps.is_available():
            return "mps"

        else:
            return "cpu"


    def _optimize_model(self, model: SentenceTransformer, device: str) -> SentenceTransformer:
        """
        Apply optimizations to the model

        Arguments:
        ----------
        model { SentenceTransformer } : Model to optimize

        device { str } : Device model is on

        Returns:
        --------
        { SentenceTransformer } : Optimized model
        """
        # Enable eval mode for inference
        model.eval()

        # GPU optimizations
        if (device == "cuda"):
            # Use half precision for GPU if supported
            try:
                model = model.half()
                self.logger.debug("Enabled half precision for GPU")

            except Exception as e:
                self.logger.warning(f"Could not enable half precision: {repr(e)}")

        # Disable gradient computation
        for param in model.parameters():
            param.requires_grad = False

        return model


    def _log_model_info(self, model: SentenceTransformer, device: str):
        """
        Log detailed model information

        Arguments:
        ----------
        model { SentenceTransformer } : Model to log info for

        device { str } : Device model is on
        """
        try:
            # Get model architecture info
            if hasattr(model, '_modules'):
                modules = list(model._modules.keys())

            else:
                modules = ["unknown"]

            # Get embedding dimension
            if hasattr(model, 'get_sentence_embedding_dimension'):
                dimension = model.get_sentence_embedding_dimension()

            else:
                dimension = "unknown"

            # Count parameters
            total_params = sum(p.numel() for p in model.parameters())

            self.logger.info(f"Model Info: {len(modules)} modules, dimension={dimension}, parameters={total_params:,}, device={device}")

        except Exception as e:
            self.logger.debug(f"Could not get detailed model info: {repr(e)}")


    def get_loaded_model(self) -> Optional[SentenceTransformer]:
        """
        Get currently loaded model

        Returns:
        --------
        { SentenceTransformer } : Currently loaded model or None
        """
        return self._loaded_model


    def get_model_info(self) -> dict:
        """
        Get information about loaded model

        Returns:
        --------
        { dict } : Model information dictionary
        """
        if self._loaded_model is None:
            return {"loaded": False}

        info = {"loaded" : True,
                "model_name" : self._model_name,
                "device" : self._device,
                "cache_size" : len(self._model_cache),
                }

        try:
            if hasattr(self._loaded_model, 'get_sentence_embedding_dimension'):
                info["embedding_dimension"] = self._loaded_model.get_sentence_embedding_dimension()

            info["model_class"] = type(self._loaded_model).__name__

        except Exception as e:
            self.logger.warning(f"Could not get detailed model info: {e}")

        return info


    def clear_cache(self, model_name: Optional[str] = None):
        """
        Clear model cache

        Arguments:
        ----------
        model_name { str } : Specific model to clear (None = all)
        """
        if model_name:
            # Clear specific model from all devices
            keys_to_remove = [k for k in self._model_cache.keys() if k.startswith(model_name)]

            for key in keys_to_remove:
                del self._model_cache[key]

            self.logger.info(f"Cleared cache for model: {model_name}")

        else:
            # Clear all cache
            cache_size = len(self._model_cache)
            self._model_cache.clear()

            self.logger.info(f"Cleared all model cache ({cache_size} models)")


    def unload_model(self):
        """
        Unload current model and free memory
        """
        if self._loaded_model:
            model_name = self._model_name

            # Clear from cache
            if self._model_name and self._device:
                cache_key = f"{self._model_name}_{self._device}"
                self._model_cache.pop(cache_key, None)

            # Clear references
            self._loaded_model = None
            self._model_name = None
            self._device = None

            # Force garbage collection
            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.logger.info(f"Unloaded model: {model_name}")


# Global model loader instance
_model_loader = None


def get_model_loader() -> EmbeddingModelLoader:
    """
    Get global model loader instance (singleton)

    Returns:
    --------
    { EmbeddingModelLoader } : Model loader instance
    """
    global _model_loader

    if _model_loader is None:
        _model_loader = EmbeddingModelLoader()

    return _model_loader


def load_embedding_model(model_name: Optional[str] = None, device: Optional[str] = None) -> SentenceTransformer:
    """
    Convenience function to load embedding model

    Arguments:
    ----------
    model_name { str } : Model name

    device { str } : Device

    Returns:
    --------
    { SentenceTransformer } : Loaded model
    """
    loader = get_model_loader()

    return loader.load_model(model_name, device)

evaluation/ragas_evaluator.py
ADDED
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# DEPENDENCIES
import os
import math
import logging
import statistics
from typing import Any
from typing import List
from typing import Dict
from ragas import evaluate
from typing import Optional
from datasets import Dataset
from datetime import datetime
from ragas.metrics import faithfulness
from config.settings import get_settings
from ragas.metrics import context_recall
from config.models import RAGASStatistics
from config.models import RAGASExportData
from ragas.metrics import answer_relevancy
from ragas.metrics import context_precision
from ragas.metrics import context_relevancy
from ragas.metrics import answer_similarity
from ragas.metrics import answer_correctness
from config.logging_config import get_logger
from ragas.metrics import context_utilization
from config.models import RAGASEvaluationResult


# Setup Logging
settings = get_settings()
logger = get_logger(__name__)


# Set OpenAI API key from settings
if (hasattr(settings, 'OPENAI_API_KEY') and settings.OPENAI_API_KEY):
    os.environ["OPENAI_API_KEY"] = settings.OPENAI_API_KEY
    logger.info("OpenAI API key loaded from settings")

else:
    logger.warning("OPENAI_API_KEY not found in settings. Please add it to your .env file.")

# Suppress tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def sanitize_ragas_score(value: Any, metric_name: str = "unknown") -> float:
    """
    Sanitize a single RAGAS score to handle NaN, None, and invalid values

    Arguments:
    ----------
    value { Any } : Raw score value

    metric_name { str } : Name of the metric (for logging)

    Returns:
    --------
    { float } : Valid float between 0.0 and 1.0
    """
    # Handle None
    if value is None:
        return 0.0

    # Handle NaN and infinity
    try:
        float_val = float(value)

        if math.isnan(float_val) or math.isinf(float_val):
            logger.warning(f"Invalid RAGAS score for {metric_name}: {value}, defaulting to 0.0")
            return 0.0

        # Clamp between 0 and 1
        return max(0.0, min(1.0, float_val))

    except (ValueError, TypeError):
        logger.warning(f"Could not convert RAGAS score for {metric_name}: {value}, defaulting to 0.0")
        return 0.0


class RAGASEvaluator:
    """
    RAGAS evaluation module for RAG system quality assessment
    """
    def __init__(self, enable_ground_truth_metrics: bool = False):
        """
        Initialize RAGAS evaluator

        Arguments:
        ----------
        enable_ground_truth_metrics { bool } : Whether to compute metrics requiring ground truth
        """
        self.enable_ground_truth = enable_ground_truth_metrics

        # Metrics that don't require ground truth (UPDATED)
        self.base_metrics = [answer_relevancy,
                             faithfulness,
                             context_utilization,
                             context_relevancy,
                             ]

        # Metrics requiring ground truth
        self.ground_truth_metrics = [context_precision,
                                     context_recall,
                                     answer_similarity,
                                     answer_correctness,
                                     ]

        # Store evaluation history
        self.evaluation_history : List[RAGASEvaluationResult] = list()
        self.session_start = datetime.now()

        logger.info(f"RAGAS Evaluator initialized (ground_truth_metrics: {enable_ground_truth_metrics})")


    def evaluate_single(self, query: str, answer: str, contexts: List[str], ground_truth: Optional[str] = None, retrieval_time_ms: int = 0,
                        generation_time_ms: int = 0, total_time_ms: int = 0, chunks_retrieved: int = 0, query_type: str = "rag") -> RAGASEvaluationResult:
        """
        Evaluate a single query-answer pair using RAGAS metrics

        Arguments:
        ----------
        query { str } : User query

        answer { str } : Generated answer

        contexts { list } : Retrieved context chunks

        ground_truth { str } : Reference answer (optional)

        retrieval_time_ms { int } : Retrieval time in milliseconds

        generation_time_ms { int } : Generation time in milliseconds

        total_time_ms { int } : Total time in milliseconds

        chunks_retrieved { int } : Number of chunks retrieved

        query_type { str } : Type of the query : RAG or non-RAG

        Returns:
        --------
        { RAGASEvaluationResult } : RAGASEvaluationResult object
        """
        try:
            logger.info(f"Evaluating {query_type.upper()}, query: {query[:100]}...")

            if ((query_type == "general") or (query_type == "non-rag")):
                logger.info(f"Skipping detailed RAGAS evaluation for {query_type} query")

                return RAGASEvaluationResult(query = query,
                                             answer = answer,
                                             contexts = contexts,
                                             ground_truth = ground_truth,
                                             timestamp = datetime.now().isoformat(),
                                             answer_relevancy = 0.0,    # N/A for non-RAG
                                             faithfulness = 0.0,        # N/A for non-RAG
                                             context_utilization = None,
                                             context_precision = None,
                                             context_relevancy = 0.0,   # N/A for non-RAG
                                             context_recall = None,
                                             answer_similarity = None,
                                             answer_correctness = None,
                                             retrieval_time_ms = retrieval_time_ms,
                                             generation_time_ms = generation_time_ms,
                                             total_time_ms = total_time_ms,
                                             chunks_retrieved = chunks_retrieved,
                                             query_type = query_type,
                                             )

            # Only for RAG queries : Validate inputs
            if not contexts or not any(c.strip() for c in contexts):
                logger.warning("No valid contexts provided for RAGAS evaluation")
                raise ValueError("No valid contexts for evaluation")

            # Prepare dataset for RAGAS
            eval_data = {"question" : [query],
                         "answer" : [answer],
                         "contexts" : [contexts],
                         }

            # Add ground truth if available
            if ground_truth and self.enable_ground_truth:
                eval_data["ground_truth"] = [ground_truth]

            # Create dataset
            dataset = Dataset.from_dict(eval_data)

            # Select metrics based on ground truth availability
            if (ground_truth and self.enable_ground_truth):
                metrics_to_use = self.base_metrics + self.ground_truth_metrics

            else:
                metrics_to_use = self.base_metrics

            # Run evaluation
            logger.info(f"Running RAGAS evaluation with {len(metrics_to_use)} metrics...")

            results = evaluate(dataset, metrics = metrics_to_use)

            # Extract scores
            scores = results.to_pandas().iloc[0].to_dict()

            # Sanitize all scores to handle NaN values
            answer_relevancy = sanitize_ragas_score(scores.get('answer_relevancy'), 'answer_relevancy')

            faithfulness = sanitize_ragas_score(scores.get('faithfulness'), 'faithfulness')

            context_utilization_val = sanitize_ragas_score(scores.get('context_utilization'), 'context_utilization') if not ground_truth else None

            context_relevancy_val = sanitize_ragas_score(scores.get('context_relevancy'), 'context_relevancy')

            # Ground truth metrics (sanitized)
            context_precision_val = None
            context_recall_val = None
            answer_similarity_val = None
            answer_correctness_val = None

            if (ground_truth and ('context_precision' in scores)):
                context_precision_val = sanitize_ragas_score(scores.get('context_precision'), 'context_precision')

            if (ground_truth and ('context_recall' in scores)):
                context_recall_val = sanitize_ragas_score(scores.get('context_recall'), 'context_recall')

            if ground_truth and 'answer_similarity' in scores:
                answer_similarity_val = sanitize_ragas_score(scores.get('answer_similarity'), 'answer_similarity')

            if ground_truth and 'answer_correctness' in scores:
                answer_correctness_val = sanitize_ragas_score(scores.get('answer_correctness'), 'answer_correctness')

            # Create result object with sanitized values
            result = RAGASEvaluationResult(query = query,
                                           answer = answer,
                                           contexts = contexts,
                                           ground_truth = ground_truth,
                                           timestamp = datetime.now().isoformat(),
                                           answer_relevancy = answer_relevancy,
                                           faithfulness = faithfulness,
                                           context_utilization = context_utilization_val,
                                           context_precision = context_precision_val,
                                           context_relevancy = context_relevancy_val,
                                           context_recall = context_recall_val,
                                           answer_similarity = answer_similarity_val,
                                           answer_correctness = answer_correctness_val,
                                           retrieval_time_ms = retrieval_time_ms,
                                           generation_time_ms = generation_time_ms,
                                           total_time_ms = total_time_ms,
                                           chunks_retrieved = chunks_retrieved,
                                           query_type = query_type,
                                           )

            # Store in history
            self.evaluation_history.append(result)

            # Log results
            if ground_truth:
                logger.info(f"Evaluation complete: relevancy={result.answer_relevancy:.3f}, faithfulness={result.faithfulness:.3f}, precision={result.context_precision:.3f}, overall={result.overall_score:.3f}")

            else:
                logger.info(f"Evaluation complete: relevancy={result.answer_relevancy:.3f}, faithfulness={result.faithfulness:.3f}, utilization={result.context_utilization:.3f}, overall={result.overall_score:.3f}")

            return result

        except Exception as e:
            logger.error(f"RAGAS evaluation failed for {query_type} query: {e}", exc_info = True)

            # Return zero metrics on failure (all sanitized)
            return RAGASEvaluationResult(query = query,
                                         answer = answer,
                                         contexts = contexts,
                                         ground_truth = ground_truth,
                                         timestamp = datetime.now().isoformat(),
                                         answer_relevancy = 0.0,
                                         faithfulness = 0.0,
                                         context_utilization = 0.0 if not ground_truth else None,
                                         context_precision = None if not ground_truth else 0.0,
                                         context_relevancy = 0.0,
                                         context_recall = None,
                                         answer_similarity = None,
                                         answer_correctness = None,
                                         retrieval_time_ms = retrieval_time_ms,
                                         generation_time_ms = generation_time_ms,
                                         total_time_ms = total_time_ms,
                                         chunks_retrieved = chunks_retrieved,
                                         query_type = query_type
                                         )


    def evaluate_query_response(self, query_response: Any) -> Dict:
        """
        Evaluate based on actual response characteristics, not predictions

        Arguments:
        ----------
        query_response { Any } : QueryResponse object with metadata

        Returns:
        --------
        { dict } : RAGAS evaluation results
        """
        try:
            # Extract necessary data from response object: Check if it has the attributes we need
            if (hasattr(query_response, 'sources')):
                sources = query_response.sources

            elif hasattr(query_response, 'contexts'):
                sources = query_response.contexts

            else:
                sources = []

            # Extract context from sources
            contexts = list()

            if (sources and len(sources) > 0):
                if (hasattr(sources[0], 'content')):
                    contexts = [s.content for s in sources]

                elif ((isinstance(sources[0], dict)) and ('content' in sources[0])):
                    contexts = [s['content'] for s in sources]

                elif (isinstance(sources[0], str)):
                    contexts = sources

            # Check if this is actually a RAG response
            is_actual_rag = ((sources and len(sources) > 0) or (contexts and len(contexts) > 0) or (hasattr(query_response, 'metrics') and query_response.metrics and query_response.metrics.get("execution_path") == "rag_pipeline"))

            if not is_actual_rag:
                logger.info(f"Non-RAG response, skipping RAGAS evaluation")
                return {"evaluated" : False,
                        "reason" : "Not a RAG response",
                        "is_rag" : False,
                        }

            # Get query and answer
            query = getattr(query_response, 'query', '')
            answer = getattr(query_response, 'answer', '')

            if not query or not answer:
                logger.warning("Missing query or answer for evaluation")
                return {"evaluated" : False,
                        "reason" : "Missing query or answer",
                        "is_rag" : True,
                        }

            # Check if context exists in metrics
            if (hasattr(query_response, 'metrics') and query_response.metrics):
                if (query_response.metrics.get("context_for_evaluation")):
                    contexts = [query_response.metrics["context_for_evaluation"]]

            if ((not contexts) or (not any(c.strip() for c in contexts))):
                logger.warning("No context available for RAGAS evaluation")
                return {"evaluated" : False,
                        "reason" : "No context available",
                        "is_rag" : True,
                        }

            # Try to get query_type from query_response
            if (hasattr(query_response, 'query_type')):
                detected_query_type = query_response.query_type

            elif (hasattr(query_response, 'metrics') and query_response.metrics):
                detected_query_type = query_response.metrics.get("query_type", "rag")

            else:
                # Determine based on contexts
                detected_query_type = "rag" if (contexts and (len(contexts) > 0)) else "general"

            # Now use the existing evaluate_single method
            result = self.evaluate_single(query = query,
                                          answer = answer,
                                          contexts = contexts,
                                          ground_truth = None,
                                          retrieval_time_ms = getattr(query_response, 'retrieval_time_ms', 0),
                                          generation_time_ms = getattr(query_response, 'generation_time_ms', 0),
                                          total_time_ms = getattr(query_response, 'total_time_ms', 0),
                                          chunks_retrieved = len(sources) if sources else len(contexts),
                                          query_type = detected_query_type,
                                          )

            # Convert to dict and add metadata
            result_dict = result.to_dict() if hasattr(result, 'to_dict') else vars(result)

            # Add evaluation metadata
            result_dict["evaluated"] = True
            result_dict["is_rag"] = True
            result_dict["context_count"] = len(contexts)

            # Add prediction vs reality info if available
            if ((hasattr(query_response, 'metrics')) and query_response.metrics):
                result_dict["predicted_type"] = query_response.metrics.get("predicted_type", "unknown")
                result_dict["actual_type"] = query_response.metrics.get("actual_type", "unknown")
                result_dict["confidence_mismatch"] = (query_response.metrics.get("predicted_type") != query_response.metrics.get("actual_type"))

            logger.info(f"RAGAS evaluation completed for RAG response")
            return result_dict

        except Exception as e:
            logger.error(f"Query response evaluation failed: {repr(e)}", exc_info = True)

            return {"evaluated" : False,
                    "error" : str(e),
                    "is_rag" : True,
                    }


    def evaluate_batch(self, queries: List[str], answers: List[str], contexts_list: List[List[str]], ground_truths: Optional[List[str]] = None,
                       query_types: Optional[List[str]] = None) -> List[RAGASEvaluationResult]:
        """
        Evaluate multiple query-answer pairs in batch

        Arguments:
        ----------
        queries { list } : List of user queries

        answers { list } : List of generated answers

        contexts_list { list } : List of context lists

        ground_truths { list } : List of reference answers (optional)

        query_types { list } : List of query types RAG / non-RAG

        Returns:
        --------
        { list } : List of RAGASEvaluationResult objects
        """
        try:
            logger.info(f"Batch evaluating {len(queries)} queries...")

            # Prepare dataset
            eval_data = {"question" : queries,
                         "answer" : answers,
                         "contexts" : contexts_list,
                         }

            if ground_truths and self.enable_ground_truth:
                eval_data["ground_truth"] = ground_truths

            # Create dataset
            dataset = Dataset.from_dict(eval_data)

            # Select metrics
            if (ground_truths and self.enable_ground_truth):
                metrics_to_use = self.base_metrics + self.ground_truth_metrics

            else:
                metrics_to_use = self.base_metrics

            # Run evaluation
            results = evaluate(dataset, metrics = metrics_to_use)
            results_df = results.to_pandas()

            # Create result objects
            evaluation_results = list()

            for idx, row in results_df.iterrows():
                # Determine query_type for this item
                if query_types and idx < len(query_types):
                    current_query_type = query_types[idx]

                else:
                    # Default based on whether contexts are available
                    current_query_type = "rag" if contexts_list[idx] and len(contexts_list[idx]) > 0 else "general"

                # Sanitize all scores
                answer_relevancy_val = sanitize_ragas_score(row.get('answer_relevancy', 0.0), f'answer_relevancy_{idx}')

                faithfulness_val = sanitize_ragas_score(row.get('faithfulness', 0.0), f'faithfulness_{idx}')

                context_relevancy_val = sanitize_ragas_score(row.get('context_relevancy', 0.0), f'context_relevancy_{idx}')

                # Handle context_utilization vs context_precision
                context_utilization_val = sanitize_ragas_score(row.get('context_utilization'), f'context_utilization_{idx}') if not ground_truths else None

                context_precision_val = sanitize_ragas_score(row.get('context_precision'), f'context_precision_{idx}') if (ground_truths and 'context_precision' in row) else None

                # Ground truth metrics
                context_recall_val = sanitize_ragas_score(row.get('context_recall'), f'context_recall_{idx}') if (ground_truths and 'context_recall' in row) else None

                answer_similarity_val = sanitize_ragas_score(row.get('answer_similarity'), f'answer_similarity_{idx}') if (ground_truths and 'answer_similarity' in row) else None

                answer_correctness_val = sanitize_ragas_score(row.get('answer_correctness'), f'answer_correctness_{idx}') if (ground_truths and 'answer_correctness' in row) else None

                # For non-RAG queries, set appropriate scores
                if ((current_query_type == "general") or (current_query_type == "non-rag")):
                    # Non-RAG queries shouldn't have RAGAS metrics
                    answer_relevancy_val = 0.0
                    faithfulness_val = 0.0
                    context_relevancy_val = 0.0
                    context_utilization_val = None
                    context_precision_val = None

                result = RAGASEvaluationResult(query = queries[idx],
                                               answer = answers[idx],
                                               contexts = contexts_list[idx],
                                               ground_truth = ground_truths[idx] if ground_truths else None,
                                               timestamp = datetime.now().isoformat(),
                                               answer_relevancy = answer_relevancy_val,
                                               faithfulness = faithfulness_val,
                                               context_precision = context_precision_val,
                                               context_utilization = context_utilization_val,
                                               context_relevancy = context_relevancy_val,
                                               context_recall = context_recall_val,
                                               answer_similarity = answer_similarity_val,
                                               answer_correctness = answer_correctness_val,
                                               retrieval_time_ms = 0,
                                               generation_time_ms = 0,
                                               total_time_ms = 0,
                                               chunks_retrieved = len(contexts_list[idx]),
                                               query_type = current_query_type,
                                               )

                evaluation_results.append(result)

                self.evaluation_history.append(result)

            logger.info(f"Batch evaluation complete for {len(evaluation_results)} queries")
            return evaluation_results

        except Exception as e:
            logger.error(f"Batch evaluation failed: {e}", exc_info = True)
            return []


    def get_session_statistics(self) -> RAGASStatistics:
        """
        Get aggregate statistics for the current evaluation session

        Returns:
        --------
        { RAGASStatistics } : RAGASStatistics object with aggregate metrics
        """
        if not self.evaluation_history:
            # Return empty statistics
            return RAGASStatistics(total_evaluations = 0,
                                   avg_answer_relevancy = 0.0,
                                   avg_faithfulness = 0.0,
                                   avg_context_precision = 0.0,
                                   avg_context_utilization = 0.0,
                                   avg_context_relevancy = 0.0,
                                   avg_overall_score = 0.0,
                                   avg_retrieval_time_ms = 0.0,
                                   avg_generation_time_ms = 0.0,
                                   avg_total_time_ms = 0.0,
                                   min_score = 0.0,
                                   max_score = 0.0,
                                   std_dev = 0.0,
                                   session_start = self.session_start,
                                   last_updated = datetime.now(),
                                   )

        n = len(self.evaluation_history)

        # Calculate averages
        avg_relevancy = sum(r.answer_relevancy for r in self.evaluation_history) / n
        avg_faithfulness = sum(r.faithfulness for r in self.evaluation_history) / n

        # Calculate context_precision and context_utilization separately
        precision_values = [r.context_precision for r in self.evaluation_history if r.context_precision is not None]
        utilization_values = [r.context_utilization for r in self.evaluation_history if r.context_utilization is not None]

        avg_precision = sum(precision_values) / len(precision_values) if precision_values else 0.0
        avg_utilization = sum(utilization_values) / len(utilization_values) if utilization_values else 0.0

        avg_relevancy_ctx = sum(r.context_relevancy for r in self.evaluation_history) / n

        # Overall scores
        overall_scores = [r.overall_score for r in self.evaluation_history]
        avg_overall = sum(overall_scores) / n
        min_score = min(overall_scores)
        max_score = max(overall_scores)
        std_dev = statistics.stdev(overall_scores) if n > 1 else 0.0

        # Performance averages
        avg_retrieval = sum(r.retrieval_time_ms for r in self.evaluation_history) / n
        avg_generation = sum(r.generation_time_ms for r in self.evaluation_history) / n
        avg_total = sum(r.total_time_ms for r in self.evaluation_history) / n

        # Ground truth metrics (if available)
        recall_values = [r.context_recall for r in self.evaluation_history if r.context_recall is not None]
        similarity_values = [r.answer_similarity for r in self.evaluation_history if r.answer_similarity is not None]
        correctness_values = [r.answer_correctness for r in self.evaluation_history if r.answer_correctness is not None]

        return RAGASStatistics(total_evaluations = n,
                               avg_answer_relevancy = round(avg_relevancy, 3),
                               avg_faithfulness = round(avg_faithfulness, 3),
                               avg_context_precision = round(avg_precision, 3) if precision_values else None,
                               avg_context_utilization = round(avg_utilization, 3) if utilization_values else None,
                               avg_context_relevancy = round(avg_relevancy_ctx, 3),
                               avg_overall_score = round(avg_overall, 3),
                               avg_context_recall = round(sum(recall_values) / len(recall_values), 3) if recall_values else None,
                               avg_answer_similarity = round(sum(similarity_values) / len(similarity_values), 3) if similarity_values else None,
                               avg_answer_correctness = round(sum(correctness_values) / len(correctness_values), 3) if correctness_values else None,
                               avg_retrieval_time_ms = round(avg_retrieval, 2),
                               avg_generation_time_ms = round(avg_generation, 2),
                               avg_total_time_ms = round(avg_total, 2),
                               min_score = round(min_score, 3),
                               max_score = round(max_score, 3),
                               std_dev = round(std_dev, 3),
                               session_start = self.session_start,
                               last_updated = datetime.now(),
                               )


    def get_evaluation_history(self) -> List[Dict]:
        """
        Get full evaluation history as list of dictionaries

        Returns:
        --------
        { list } : List of evaluation results as dictionaries
        """
        return [result.to_dict() for result in self.evaluation_history]


    def clear_history(self):
        """
        Clear evaluation history and reset session
        """
        self.evaluation_history.clear()
        self.session_start = datetime.now()

        logger.info("Evaluation history cleared, new session started")


    def export_to_dict(self) -> RAGASExportData:
        """
        Export all evaluations to structured format

        Returns:
        --------
        { RAGASExportData } : RAGASExportData object with complete evaluation data
        """
        return RAGASExportData(export_timestamp = datetime.now().isoformat(),
                               total_evaluations = len(self.evaluation_history),
                               statistics = self.get_session_statistics(),
                               evaluations = self.evaluation_history,
                               ground_truth_enabled = self.enable_ground_truth,
                               )


# Global evaluator instance
_ragas_evaluator : Optional[RAGASEvaluator] = None


def get_ragas_evaluator(enable_ground_truth_metrics: bool = False) -> RAGASEvaluator:
    """
    Get or create global RAGAS evaluator instance

    Arguments:
    ----------
    enable_ground_truth_metrics { bool } : Whether to enable ground truth metrics

    Returns:
    --------
    { RAGASEvaluator } : RAGASEvaluator instance
    """
    global _ragas_evaluator

    if _ragas_evaluator is None:
        _ragas_evaluator = RAGASEvaluator(enable_ground_truth_metrics)

    return _ragas_evaluator
|
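For reviewers, a minimal sketch of how this evaluator is driven (illustrative only, not part of the commit; RAGAS metrics call the configured OpenAI model, so an API key must be set):

    from evaluation.ragas_evaluator import get_ragas_evaluator

    evaluator = get_ragas_evaluator()
    result = evaluator.evaluate_single(query = "What is the refund window?",
                                       answer = "Refunds are available within 30 days of purchase.",
                                       contexts = ["Customers may request a refund within 30 days of purchase."],
                                       chunks_retrieved = 1)
    print(result.answer_relevancy, result.faithfulness)
    print(evaluator.get_session_statistics().avg_overall_score)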
frontend/index.html
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generation/__init__.py
ADDED
|
File without changes
|
generation/citation_formatter.py
ADDED
|
@@ -0,0 +1,378 @@
# DEPENDENCIES
import re
from enum import Enum
from typing import List
from typing import Dict
from typing import Optional
from collections import defaultdict
from config.models import CitationStyle
from config.models import ChunkWithScore
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import CitationFormattingError


# Setup Logging
logger = get_logger(__name__)


class CitationFormatter:
    """
    Citation formatting and management: Formats citations in generated text according to different styles and ensures citation consistency and validity
    """
    def __init__(self, style: CitationStyle = CitationStyle.NUMERIC):
        """
        Initialize citation formatter

        Arguments:
        ----------
        style { CitationStyle } : Citation style to use
        """
        self.logger = logger
        self.style = style
        self.citation_pattern = re.compile(r'\[(\d+)\]')

        # Style configurations
        self.style_configs = {CitationStyle.NUMERIC : {"inline_format" : "[{number}]", "reference_format" : "[{number}] {source_info}", "separator" : " ",},
                              CitationStyle.VERBOSE : {"inline_format" : "[{number}]", "reference_format" : "Citation {number}: {source_info}", "separator" : "\n",},
                              CitationStyle.MINIMAL : {"inline_format" : "[{number}]", "reference_format" : "[{number}]", "separator" : " ",},
                              CitationStyle.ACADEMIC : {"inline_format" : "({number})", "reference_format" : "{number}. {source_info}", "separator" : "\n",},
                              CitationStyle.LEGAL : {"inline_format" : "[{number}]", "reference_format" : "[{number}] {source_info}", "separator" : "\n",}
                              }


    def format_citations_in_text(self, text: str, sources: List[ChunkWithScore]) -> str:
        """
        Format citations in generated text

        Arguments:
        ----------
        text { str } : Text containing citation markers

        sources { list } : List of sources for citation mapping

        Returns:
        --------
        { str } : Text with formatted citations
        """
        if not text or not sources:
            return text

        try:
            # Extract citation numbers from text
            citation_numbers = self._extract_citation_numbers(text = text)

            if not citation_numbers:
                return text

            # Create citation mapping
            citation_map = self._create_citation_map(sources = sources)

            # Replace citation markers with formatted citations
            formatted_text = self._replace_citation_markers(text = text,
                                                            citation_map = citation_map,
                                                            )

            self.logger.debug(f"Formatted {len(citation_numbers)} citations in text")

            return formatted_text

        except Exception as e:
            self.logger.error(f"Citation formatting failed: {repr(e)}")
            raise CitationFormattingError(f"Citation formatting failed: {repr(e)}")


    def generate_reference_section(self, sources: List[ChunkWithScore], cited_numbers: List[int]) -> str:
        """
        Generate reference section for cited sources

        Arguments:
        ----------
        sources { list } : All available sources

        cited_numbers { list } : Numbers of actually cited sources

        Returns:
        --------
        { str } : Formatted reference section
        """
        if not sources or not cited_numbers:
            return ""

        try:
            style_config = self.style_configs[self.style]
            references = list()

            # Get only cited sources
            cited_sources = [sources[num-1] for num in cited_numbers if (1 <= num <= len(sources))]

            for i, source in enumerate(cited_sources, 1):
                source_info = self._format_source_info(source, i)
                reference = style_config["reference_format"].format(number = i, source_info = source_info)

                references.append(reference)

            separator = style_config["separator"]
            reference_section = separator.join(references)

            # Add section header if appropriate
            if (self.style in [CitationStyle.VERBOSE, CitationStyle.ACADEMIC]):
                reference_section = "References:\n" + reference_section

            self.logger.debug(f"Generated reference section with {len(references)} entries")

            return reference_section

        except Exception as e:
            self.logger.error(f"Reference section generation failed: {repr(e)}")
            return ""


    def _extract_citation_numbers(self, text: str) -> List[int]:
        """
        Extract citation numbers from text
        """
        matches = self.citation_pattern.findall(text)
        citation_numbers = [int(match) for match in matches]

        # Unique and sorted
        return sorted(set(citation_numbers))


    def _create_citation_map(self, sources: List[ChunkWithScore]) -> Dict[int, str]:
        """
        Create mapping from citation numbers to formatted citations
        """
        citation_map = dict()
        style_config = self.style_configs[self.style]

        for i, source in enumerate(sources, 1):
            formatted_citation = style_config["inline_format"].format(number=i)
            citation_map[i] = formatted_citation

        return citation_map


    def _replace_citation_markers(self, text: str, citation_map: Dict[int, str]) -> str:
        """
        Replace citation markers in text
        """
        def replacement(match):
            citation_num = int(match.group(1))

            return citation_map.get(citation_num, match.group(0))

        return self.citation_pattern.sub(replacement, text)


    def _format_source_info(self, source: ChunkWithScore, citation_number: int) -> str:
        """
        Format source information based on style
        """
        chunk = source.chunk

        if (self.style == CitationStyle.MINIMAL):
            return f"Source {citation_number}"

        # Build source components
        components = list()

        # Document information
        if hasattr(chunk, 'metadata'):
            meta = chunk.metadata

            if ('filename' in meta):
                components.append(f"Document: {meta['filename']}")

            if (('title' in meta) and meta['title']):
                components.append(f"\"{meta['title']}\"")

            if (('author' in meta) and meta['author']):
                components.append(f"by {meta['author']}")

        # Location information
        location_parts = list()

        if chunk.page_number:
            location_parts.append(f"p. {chunk.page_number}")

        if chunk.section_title:
            location_parts.append(f"Section: {chunk.section_title}")

        if location_parts:
            components.append("(" + ", ".join(location_parts) + ")")

        # Relevance score (for verbose styles)
        if ((self.style in [CitationStyle.VERBOSE, CitationStyle.ACADEMIC]) and (source.score > 0)):
            components.append(f"[relevance: {source.score:.3f}]")

        return " ".join(components)


    def validate_citations(self, text: str, sources: List[ChunkWithScore]) -> tuple[bool, List[int]]:
        """
        Validate citations in text

        Arguments:
        ----------
        text { str } : Text to validate

        sources { list } : Available sources

        Returns:
        --------
        { tuple } : (is_valid, invalid_citations)
        """
        citation_numbers = self._extract_citation_numbers(text = text)
        invalid_citations = list()

        for number in citation_numbers:
            if ((number < 1) or (number > len(sources))):
                invalid_citations.append(number)

        is_valid = (len(invalid_citations) == 0)

        if not is_valid:
            self.logger.warning(f"Invalid citations found: {invalid_citations}")

        return is_valid, invalid_citations


    def normalize_citations(self, text: str, sources: List[ChunkWithScore]) -> str:
        """
        Normalize citations to ensure sequential numbering

        Arguments:
        ----------
        text { str } : Text with citations

        sources { list } : Available sources

        Returns:
        --------
        { str } : Text with normalized citations
        """
        citation_numbers = self._extract_citation_numbers(text = text)

        if not citation_numbers:
            return text

        # Create mapping from old to new numbers
        citation_mapping = dict()

        for i, old_num in enumerate(sorted(set(citation_numbers)), 1):
            if (1 <= old_num <= len(sources)):
                citation_mapping[old_num] = i

        # Replace citations
        def normalize_replacement(match):
            old_num = int(match.group(1))
            new_num = citation_mapping.get(old_num, old_num)
            style_config = self.style_configs[self.style]

            return style_config["inline_format"].format(number = new_num)

        normalized_text = self.citation_pattern.sub(normalize_replacement, text)

        if citation_mapping:
            self.logger.info(f"Normalized citations: {citation_numbers} -> {list(citation_mapping.values())}")

        return normalized_text


    def get_citation_statistics(self, text: str, sources: List[ChunkWithScore]) -> Dict:
        """
        Get citation statistics

        Arguments:
        ----------
        text { str } : Text with citations

        sources { list } : Available sources

        Returns:
        --------
        { dict } : Citation statistics
        """
        citation_numbers = self._extract_citation_numbers(text = text)

        if not citation_numbers:
            return {"total_citations": 0}

        # Calculate distribution
        source_usage = defaultdict(int)

        for number in citation_numbers:
            if (1 <= number <= len(sources)):
                source = sources[number-1]
                doc_id = source.chunk.document_id
                source_usage[doc_id] += 1

        return {"total_citations" : len(citation_numbers),
                "unique_citations" : len(set(citation_numbers)),
                "sources_used" : len(source_usage),
                "citations_per_source" : dict(source_usage),
                "citation_density" : len(citation_numbers) / max(1, len(text.split())),
                }


    def set_style(self, style: CitationStyle):
        """
        Set citation style

        Arguments:
        ----------
        style { CitationStyle } : New citation style
        """
        if (style not in self.style_configs):
            raise CitationFormattingError(f"Unsupported citation style: {style}")

        old_style = self.style
        self.style = style

        self.logger.info(f"Citation style changed: {old_style} -> {style}")


# Global citation formatter instance
_citation_formatter = None


def get_citation_formatter() -> CitationFormatter:
    """
    Get global citation formatter instance (singleton)

    Returns:
    --------
    { CitationFormatter } : CitationFormatter instance
    """
    global _citation_formatter

    if _citation_formatter is None:
        _citation_formatter = CitationFormatter()

    return _citation_formatter


@handle_errors(error_type = CitationFormattingError, log_error = True, reraise = False)
def format_citations(text: str, sources: List[ChunkWithScore], style: CitationStyle = None) -> str:
    """
    Convenience function for citation formatting

    Arguments:
    ----------
    text { str } : Text containing citations

    sources { list } : List of sources

    style { CitationStyle } : Citation style to use

    Returns:
    --------
    { str } : Formatted text
    """
    formatter = get_citation_formatter()

    if style is not None:
        formatter.set_style(style)

    return formatter.format_citations_in_text(text, sources)
|
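A short usage sketch of the formatter above (illustrative, not part of the commit; retrieved_chunks stands in for a List[ChunkWithScore] produced by retrieval):

    from config.models import CitationStyle
    from generation.citation_formatter import get_citation_formatter

    formatter = get_citation_formatter()
    formatter.set_style(CitationStyle.ACADEMIC)

    draft = "Refunds are allowed within 30 days [1], except for digital goods [2]."
    is_valid, bad_refs = formatter.validate_citations(draft, retrieved_chunks)   # retrieved_chunks: List[ChunkWithScore]
    answer = formatter.format_citations_in_text(draft, retrieved_chunks)         # [1] -> (1), [2] -> (2) in ACADEMIC style
    references = formatter.generate_reference_section(retrieved_chunks, [1, 2])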
generation/general_responder.py
ADDED
|
@@ -0,0 +1,155 @@
# DEPENDENCIES
import random
import datetime
from typing import List
from typing import Dict
from config.models import LLMProvider
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from generation.llm_client import get_llm_client


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class GeneralResponder:
    """
    Handles general/conversational queries using LLM without RAG
    """
    def __init__(self, provider: LLMProvider = None, model_name: str = None):
        self.logger = logger
        self.provider = provider or LLMProvider.OLLAMA
        self.model_name = model_name or settings.OLLAMA_MODEL

        # Initialize LLM client for general responses
        self.llm_client = get_llm_client(provider = self.provider,
                                         model_name = self.model_name,
                                         )

        # System prompt for general conversation
        self.system_prompt = """
        You are a helpful, friendly AI assistant. You're part of a larger system called the "AI Universal Knowledge Ingestion System" which specializes in document analysis and retrieval.

        When users ask general questions, answer helpfully and conversationally. When they ask about your capabilities, explain that you can:
        1. Answer general knowledge questions
        2. Help with document analysis (when they upload documents)
        3. Provide explanations on various topics
        4. Engage in friendly conversation

        If a question is better answered by searching through documents, politely suggest uploading documents first.

        Be concise but thorough. Use a friendly, professional tone.
        """

        # Fallback responses (if LLM fails)
        self.fallback_responses = {"greeting" : ["Hello! 👋 I'm your AI assistant. How can I help you today?", "Hi there! I'm here to help with questions or document analysis. What's on your mind?", "Greetings! I can answer questions or help analyze documents. What would you like to know?"],
                                   "farewell" : ["Goodbye! Feel free to come back if you have more questions.", "See you later! Don't hesitate to ask if you need help with documents.", "Take care! Remember I'm here for document analysis too."],
                                   "thanks" : ["You're welcome! Happy to help.", "My pleasure! Let me know if you need anything else.", "Glad I could help! Don't hesitate to ask more questions."],
                                   "default" : ["I'm here to help! You can ask me general questions or upload documents for analysis.", "That's an interesting question! I'd be happy to discuss it with you.", "I can help with that. What specific aspect would you like to know about?"],
                                   }


    async def respond(self, query: str, conversation_history: List[Dict] = None) -> Dict:
        """
        Generate a response to a general query

        Arguments:
        ----------
        query { str } : User query

        conversation_history { list } : Previous messages in conversation

        Returns:
        --------
        { dict } : Response dictionary
        """
        try:
            # Prepare messages for LLM
            messages = list()

            # Add system prompt
            messages.append({"role" : "system",
                             "content" : self.system_prompt,
                             })

            # Add conversation history if available
            if conversation_history:
                # Last 5 messages for context
                messages.extend(conversation_history[-5:])

            # Add current query
            messages.append({"role" : "user",
                             "content" : query,
                             })

            # Generate response: Slightly higher temp for conversational
            llm_response = await self.llm_client.generate(messages = messages,
                                                          temperature = 0.7,
                                                          max_tokens = 500,
                                                          )

            response_text = llm_response.get("content", "").strip()

            if not response_text:
                response_text = self._get_fallback_response(query)

            return {"answer" : response_text,
                    "is_general" : True,
                    "requires_rag" : False,
                    "tokens_used" : llm_response.get("usage", {}),
                    "model" : self.model_name,
                    }

        except Exception as e:
            logger.error(f"General response generation failed: {e}")
            return {"answer" : self._get_fallback_response(query),
                    "is_general" : True,
                    "requires_rag" : False,
                    "error" : str(e),
                    "model" : self.model_name,
                    }


    def _get_fallback_response(self, query: str) -> str:
        """
        Get a fallback response if LLM fails
        """
        query_lower = query.lower()

        if (any(word in query_lower for word in ["hello", "hi", "hey", "greetings"])):
            return random.choice(self.fallback_responses["greeting"])

        elif (any(word in query_lower for word in ["thank", "thanks", "appreciate"])):
            return random.choice(self.fallback_responses["thanks"])

        elif (any(word in query_lower for word in ["bye", "goodbye", "see you"])):
            return random.choice(self.fallback_responses["farewell"])

        else:
            return random.choice(self.fallback_responses["default"])


# Global responder instance
_general_responder = None


def get_general_responder(provider: LLMProvider = None, model_name: str = None) -> GeneralResponder:
    """
    Get global general responder instance

    Returns:
    --------
    { GeneralResponder } : GeneralResponder instance
    """
    global _general_responder

    if _general_responder is None:
        _general_responder = GeneralResponder(provider = provider,
                                              model_name = model_name,
                                              )

    return _general_responder
|
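An illustrative sketch of calling the responder above from an async context (not part of the commit; it assumes the configured provider, Ollama or OpenAI, is reachable):

    import asyncio
    from generation.general_responder import get_general_responder

    async def demo():
        responder = get_general_responder()
        reply = await responder.respond("Hi! What can you help me with?")
        print(reply["answer"])
        print(reply["model"], reply["is_general"])

    asyncio.run(demo())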
generation/llm_client.py
ADDED
|
@@ -0,0 +1,396 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import aiohttp
|
| 6 |
+
import requests
|
| 7 |
+
from typing import List
|
| 8 |
+
from typing import Dict
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from typing import AsyncGenerator
|
| 11 |
+
from config.models import LLMProvider
|
| 12 |
+
from config.settings import get_settings
|
| 13 |
+
from config.logging_config import get_logger
|
| 14 |
+
from utils.error_handler import handle_errors
|
| 15 |
+
from utils.error_handler import LLMClientError
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Setup Settings and Logging
|
| 19 |
+
settings = get_settings()
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LLMClient:
|
| 24 |
+
"""
|
| 25 |
+
Unified LLM client supporting multiple providers (Ollama, OpenAI). Provides a consistent interface for text generation across different LLM services
|
| 26 |
+
"""
|
| 27 |
+
def __init__(self, provider: LLMProvider = None, model_name: str = None, api_key: str = None, base_url: str = None):
|
| 28 |
+
"""
|
| 29 |
+
Initialize LLM client
|
| 30 |
+
|
| 31 |
+
Arguments:
|
| 32 |
+
----------
|
| 33 |
+
provider { LLMProvider } : LLM provider to use
|
| 34 |
+
|
| 35 |
+
model_name { str } : Model name to use
|
| 36 |
+
|
| 37 |
+
api_key { str } : API key (for OpenAI)
|
| 38 |
+
|
| 39 |
+
base_url { str } : Base URL for API (for Ollama)
|
| 40 |
+
"""
|
| 41 |
+
self.logger = logger
|
| 42 |
+
self.settings = get_settings()
|
| 43 |
+
self.provider = provider or LLMProvider.OLLAMA
|
| 44 |
+
self.model_name = model_name or self.settings.OLLAMA_MODEL
|
| 45 |
+
self.api_key = api_key
|
| 46 |
+
self.base_url = base_url or self.settings.OLLAMA_BASE_URL
|
| 47 |
+
self.timeout = self.settings.OLLAMA_TIMEOUT
|
| 48 |
+
|
| 49 |
+
# Initialize provider-specific configurations
|
| 50 |
+
self._initialize_provider()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _initialize_provider(self):
|
| 54 |
+
"""
|
| 55 |
+
Initialize provider-specific configurations
|
| 56 |
+
"""
|
| 57 |
+
# Auto-detect provider if not explicitly set
|
| 58 |
+
if (self.settings.IS_HF_SPACE and not self.settings.OLLAMA_ENABLED):
|
| 59 |
+
if (self.settings.USE_OPENAI and self.settings.OPENAI_API_KEY):
|
| 60 |
+
self.provider = LLMProvider.OPENAI
|
| 61 |
+
logger.info("HF Space detected: Using OpenAI API")
|
| 62 |
+
|
| 63 |
+
else:
|
| 64 |
+
raise LLMClientError("Running in HF Space without Ollama. Set OPENAI_API_KEY in Space secrets.")
|
| 65 |
+
|
| 66 |
+
# Provider initialization
|
| 67 |
+
if (self.provider == LLMProvider.OLLAMA):
|
| 68 |
+
if not self.base_url:
|
| 69 |
+
raise LLMClientError("Ollama base URL is required")
|
| 70 |
+
|
| 71 |
+
self.logger.info(f"Initialized Ollama client: {self.base_url}, model: {self.model_name}")
|
| 72 |
+
|
| 73 |
+
elif (self.provider == LLMProvider.OPENAI):
|
| 74 |
+
if not self.api_key:
|
| 75 |
+
# Try to get from environment
|
| 76 |
+
self.api_key = os.getenv('OPENAI_API_KEY')
|
| 77 |
+
if not self.api_key:
|
| 78 |
+
raise LLMClientError("OpenAI API key is required")
|
| 79 |
+
|
| 80 |
+
self.base_url = "https://api.openai.com/v1"
|
| 81 |
+
|
| 82 |
+
self.logger.info(f"Initialized OpenAI client, model: {self.model_name}")
|
| 83 |
+
|
| 84 |
+
else:
|
| 85 |
+
raise LLMClientError(f"Unsupported provider: {self.provider}")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
async def generate(self, messages: List[Dict], **generation_params) -> Dict:
|
| 89 |
+
"""
|
| 90 |
+
Generate text completion (async)
|
| 91 |
+
|
| 92 |
+
Arguments:
|
| 93 |
+
----------
|
| 94 |
+
messages { list } : List of message dictionaries
|
| 95 |
+
|
| 96 |
+
**generation_params : Generation parameters (temperature, max_tokens, etc.)
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
--------
|
| 100 |
+
{ dict } : Generation response
|
| 101 |
+
"""
|
| 102 |
+
try:
|
| 103 |
+
if (self.provider == LLMProvider.OLLAMA):
|
| 104 |
+
return await self._generate_ollama(messages, **generation_params)
|
| 105 |
+
|
| 106 |
+
elif (self.provider == LLMProvider.OPENAI):
|
| 107 |
+
return await self._generate_openai(messages, **generation_params)
|
| 108 |
+
|
| 109 |
+
else:
|
| 110 |
+
raise LLMClientError(f"Unsupported provider: {self.provider}")
|
| 111 |
+
|
| 112 |
+
except Exception as e:
|
| 113 |
+
self.logger.error(f"Generation failed: {repr(e)}")
|
| 114 |
+
raise LLMClientError(f"Generation failed: {repr(e)}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
async def generate_stream(self, messages: List[Dict], **generation_params) -> AsyncGenerator[str, None]:
|
| 118 |
+
"""
|
| 119 |
+
Generate text completion with streaming (async)
|
| 120 |
+
|
| 121 |
+
Arguments:
|
| 122 |
+
----------
|
| 123 |
+
messages { list } : List of message dictionaries
|
| 124 |
+
|
| 125 |
+
**generation_params : Generation parameters
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
--------
|
| 129 |
+
{ AsyncGenerator } : Async generator yielding response chunks
|
| 130 |
+
"""
|
| 131 |
+
try:
|
| 132 |
+
if (self.provider == LLMProvider.OLLAMA):
|
| 133 |
+
async for chunk in self._generate_ollama_stream(messages, **generation_params):
|
| 134 |
+
yield chunk
|
| 135 |
+
|
| 136 |
+
elif (self.provider == LLMProvider.OPENAI):
|
| 137 |
+
async for chunk in self._generate_openai_stream(messages, **generation_params):
|
| 138 |
+
yield chunk
|
| 139 |
+
|
| 140 |
+
else:
|
| 141 |
+
raise LLMClientError(f"Unsupported provider: {self.provider}")
|
| 142 |
+
|
| 143 |
+
except Exception as e:
|
| 144 |
+
self.logger.error(f"Stream generation failed: {repr(e)}")
|
| 145 |
+
raise LLMClientError(f"Stream generation failed: {repr(e)}")
|
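A minimal consumption sketch for the streaming interface above (not part of the commit; the message content is illustrative and an Ollama or OpenAI backend is assumed to be reachable):

# Hedged sketch: print chunks as generate_stream yields them.
import asyncio
from generation.llm_client import get_llm_client

async def stream_demo():
    client = get_llm_client()
    messages = [{"role": "user", "content": "Stream a short greeting."}]  # illustrative
    async for chunk in client.generate_stream(messages, max_tokens=50):
        print(chunk, end="", flush=True)
    print()

asyncio.run(stream_demo())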
| 146 |
+
|
| 147 |
+
|
| 148 |
+
async def _generate_ollama(self, messages: List[Dict], **generation_params) -> Dict:
|
| 149 |
+
"""
|
| 150 |
+
Generate using Ollama API
|
| 151 |
+
"""
|
| 152 |
+
url = f"{self.base_url}/api/chat"
|
| 153 |
+
|
| 154 |
+
# Prepare request payload
|
| 155 |
+
payload = {"model" : self.model_name,
|
| 156 |
+
"messages" : messages,
|
| 157 |
+
"stream" : False,
|
| 158 |
+
"options" : {"temperature" : generation_params.get("temperature", 0.1),
|
| 159 |
+
"top_p" : generation_params.get("top_p", 0.9),
|
| 160 |
+
"num_predict" : generation_params.get("max_tokens", 1000),
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
async with aiohttp.ClientSession() as session:
|
| 165 |
+
async with session.post(url, json = payload, timeout = self.timeout) as response:
|
| 166 |
+
if (response.status != 200):
|
| 167 |
+
error_text = await response.text()
|
| 168 |
+
raise LLMClientError(f"Ollama API error: {response.status} - {error_text}")
|
| 169 |
+
|
| 170 |
+
result = await response.json()
|
| 171 |
+
|
| 172 |
+
return {"content" : result["message"]["content"],
|
| 173 |
+
"usage" : {"prompt_tokens" : result.get("prompt_eval_count", 0),
|
| 174 |
+
"completion_tokens" : result.get("eval_count", 0),
|
| 175 |
+
"total_tokens" : result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
|
| 176 |
+
},
|
| 177 |
+
"finish_reason" : result.get("done_reason", "stop"),
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
async def _generate_ollama_stream(self, messages: List[Dict], **generation_params) -> AsyncGenerator[str, None]:
|
| 182 |
+
"""
|
| 183 |
+
Generate stream using the Ollama API (parses newline-delimited JSON chunks)
|
| 184 |
+
"""
|
| 185 |
+
url = f"{self.base_url}/api/chat"
|
| 186 |
+
|
| 187 |
+
payload = {"model" : self.model_name,
|
| 188 |
+
"messages" : messages,
|
| 189 |
+
"stream" : True,
|
| 190 |
+
"options" : {"temperature" : generation_params.get("temperature", 0.1),
|
| 191 |
+
"top_p" : generation_params.get("top_p", 0.9),
|
| 192 |
+
"num_predict" : generation_params.get("max_tokens", 1000),
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
async with aiohttp.ClientSession() as session:
|
| 197 |
+
async with session.post(url, json = payload, timeout = self.timeout) as response:
|
| 198 |
+
if (response.status != 200):
|
| 199 |
+
error_text = await response.text()
|
| 200 |
+
raise LLMClientError(f"Ollama API error: {response.status} - {error_text}")
|
| 201 |
+
|
| 202 |
+
async for line in response.content:
|
| 203 |
+
line_str = line.decode('utf-8').strip()
|
| 204 |
+
|
| 205 |
+
# Skip empty lines
|
| 206 |
+
if not line_str:
|
| 207 |
+
continue
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
chunk_data = json.loads(line_str)
|
| 211 |
+
|
| 212 |
+
# Check if this is the final chunk
|
| 213 |
+
if (chunk_data.get("done", False)):
|
| 214 |
+
break
|
| 215 |
+
|
| 216 |
+
# Extract content regardless of whether it's empty: Ollama sends incremental content in each chunk
|
| 217 |
+
if ("message" in chunk_data):
|
| 218 |
+
content = chunk_data["message"].get("content", "")
|
| 219 |
+
|
| 220 |
+
# Only yield non-empty content
|
| 221 |
+
if content:
|
| 222 |
+
yield content
|
| 223 |
+
|
| 224 |
+
except json.JSONDecodeError as e:
|
| 225 |
+
self.logger.warning(f"Failed to parse streaming chunk: {line_str[:100]}")
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
async def _generate_openai(self, messages: List[Dict], **generation_params) -> Dict:
|
| 230 |
+
"""
|
| 231 |
+
Generate using OpenAI API
|
| 232 |
+
"""
|
| 233 |
+
url = f"{self.base_url}/chat/completions"
|
| 234 |
+
|
| 235 |
+
headers = {"Authorization" : f"Bearer {self.api_key}",
|
| 236 |
+
"Content-Type" : "application/json",
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
payload = {"model" : self.model_name,
|
| 240 |
+
"messages" : messages,
|
| 241 |
+
"temperature" : generation_params.get("temperature", 0.1),
|
| 242 |
+
"top_p" : generation_params.get("top_p", 0.9),
|
| 243 |
+
"max_tokens" : generation_params.get("max_tokens", 1000),
|
| 244 |
+
"stream" : False,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
async with aiohttp.ClientSession() as session:
|
| 248 |
+
async with session.post(url, headers = headers, json = payload, timeout = self.timeout) as response:
|
| 249 |
+
if (response.status != 200):
|
| 250 |
+
error_text = await response.text()
|
| 251 |
+
raise LLMClientError(f"OpenAI API error: {response.status} - {error_text}")
|
| 252 |
+
|
| 253 |
+
result = await response.json()
|
| 254 |
+
|
| 255 |
+
return {"content" : result["choices"][0]["message"]["content"],
|
| 256 |
+
"usage" : result["usage"],
|
| 257 |
+
"finish_reason" : result["choices"][0]["finish_reason"],
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
async def _generate_openai_stream(self, messages: List[Dict], **generation_params) -> AsyncGenerator[str, None]:
|
| 262 |
+
"""
|
| 263 |
+
Generate stream using OpenAI API
|
| 264 |
+
"""
|
| 265 |
+
url = f"{self.base_url}/chat/completions"
|
| 266 |
+
|
| 267 |
+
headers = {"Authorization" : f"Bearer {self.api_key}",
|
| 268 |
+
"Content-Type" : "application/json",
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
payload = {"model" : self.model_name,
|
| 272 |
+
"messages" : messages,
|
| 273 |
+
"temperature" : generation_params.get("temperature", 0.1),
|
| 274 |
+
"top_p" : generation_params.get("top_p", 0.9),
|
| 275 |
+
"max_tokens" : generation_params.get("max_tokens", 1000),
|
| 276 |
+
"stream" : True,
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
async with aiohttp.ClientSession() as session:
|
| 280 |
+
async with session.post(url, headers = headers, json = payload, timeout = self.timeout) as response:
|
| 281 |
+
if (response.status != 200):
|
| 282 |
+
error_text = await response.text()
|
| 283 |
+
|
| 284 |
+
raise LLMClientError(f"OpenAI API error: {response.status} - {error_text}")
|
| 285 |
+
|
| 286 |
+
async for line in response.content:
|
| 287 |
+
line = line.decode('utf-8').strip()
|
| 288 |
+
|
| 289 |
+
if (line.startswith('data: ')):
|
| 290 |
+
# Remove 'data: ' prefix
|
| 291 |
+
data = line[6:]
|
| 292 |
+
|
| 293 |
+
if (data == '[DONE]'):
|
| 294 |
+
break
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
chunk_data = json.loads(data)
|
| 298 |
+
if ("choices" in chunk_data) and (chunk_data["choices"]):
|
| 299 |
+
delta = chunk_data["choices"][0].get("delta", {})
|
| 300 |
+
|
| 301 |
+
if ("content" in delta):
|
| 302 |
+
yield delta["content"]
|
| 303 |
+
|
| 304 |
+
except json.JSONDecodeError:
|
| 305 |
+
continue
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def check_health(self) -> bool:
|
| 309 |
+
"""
|
| 310 |
+
Check if LLM provider is healthy and accessible
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
--------
|
| 314 |
+
{ bool } : True if healthy
|
| 315 |
+
"""
|
| 316 |
+
try:
|
| 317 |
+
if (self.provider == LLMProvider.OLLAMA):
|
| 318 |
+
response = requests.get(f"{self.base_url}/api/tags", timeout = 30)
|
| 319 |
+
return (response.status_code == 200)
|
| 320 |
+
|
| 321 |
+
elif (self.provider == LLMProvider.OPENAI):
|
| 322 |
+
# Simple models list check
|
| 323 |
+
headers = {"Authorization": f"Bearer {self.api_key}"}
|
| 324 |
+
response = requests.get(f"{self.base_url}/models", headers=headers, timeout=10)
|
| 325 |
+
return (response.status_code == 200)
|
| 326 |
+
|
| 327 |
+
return False
|
| 328 |
+
|
| 329 |
+
except Exception as e:
|
| 330 |
+
self.logger.warning(f"Health check failed: {repr(e)}")
|
| 331 |
+
return False
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def get_provider_info(self) -> Dict:
|
| 335 |
+
"""
|
| 336 |
+
Get provider information
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
--------
|
| 340 |
+
{ dict } : Provider information
|
| 341 |
+
"""
|
| 342 |
+
return {"provider" : self.provider.value,
|
| 343 |
+
"model" : self.model_name,
|
| 344 |
+
"base_url" : self.base_url,
|
| 345 |
+
"healthy" : self.check_health(),
|
| 346 |
+
"timeout" : self.timeout,
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# Global LLM client instance
|
| 351 |
+
_llm_client = None
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def get_llm_client(provider: LLMProvider = None, **kwargs) -> LLMClient:
|
| 355 |
+
"""
|
| 356 |
+
Get global LLM client instance (singleton)
|
| 357 |
+
|
| 358 |
+
Arguments:
|
| 359 |
+
----------
|
| 360 |
+
provider { LLMProvider } : LLM provider to use
|
| 361 |
+
|
| 362 |
+
**kwargs : Additional client configuration
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
--------
|
| 366 |
+
{ LLMClient } : LLMClient instance
|
| 367 |
+
"""
|
| 368 |
+
global _llm_client
|
| 369 |
+
|
| 370 |
+
if _llm_client is None or (provider and _llm_client.provider != provider):
|
| 371 |
+
_llm_client = LLMClient(provider, **kwargs)
|
| 372 |
+
|
| 373 |
+
return _llm_client
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@handle_errors(error_type = LLMClientError, log_error = True, reraise = False)
|
| 377 |
+
async def generate_text(messages: List[Dict], provider: LLMProvider = LLMProvider.OLLAMA, **kwargs) -> str:
|
| 378 |
+
"""
|
| 379 |
+
Convenience function for text generation
|
| 380 |
+
|
| 381 |
+
Arguments:
|
| 382 |
+
----------
|
| 383 |
+
messages { list } : List of message dictionaries
|
| 384 |
+
|
| 385 |
+
provider { LLMProvider } : LLM provider to use
|
| 386 |
+
|
| 387 |
+
**kwargs : Generation parameters
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
--------
|
| 391 |
+
{ str } : Generated text
|
| 392 |
+
"""
|
| 393 |
+
client = get_llm_client(provider, **kwargs)
|
| 394 |
+
response = await client.generate(messages, **kwargs)
|
| 395 |
+
|
| 396 |
+
return response["content"]
|
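A minimal non-streaming usage sketch for the client above (not part of the commit; the messages are illustrative and a reachable Ollama server, or an OpenAI key for the OpenAI path, is assumed):

# Hedged sketch: one-shot generation via the singleton client.
import asyncio
from config.models import LLMProvider
from generation.llm_client import get_llm_client

async def main():
    client = get_llm_client(LLMProvider.OLLAMA)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},  # illustrative
    ]
    response = await client.generate(messages, temperature=0.1, max_tokens=100)
    # generate() returns a dict with "content", "usage" and "finish_reason"
    print(response["content"])

asyncio.run(main())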
generation/prompt_builder.py
ADDED
|
@@ -0,0 +1,542 @@
|
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
from typing import List
|
| 3 |
+
from typing import Dict
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from config.models import PromptType
|
| 6 |
+
from config.settings import get_settings
|
| 7 |
+
from config.models import ChunkWithScore
|
| 8 |
+
from config.logging_config import get_logger
|
| 9 |
+
from utils.error_handler import handle_errors
|
| 10 |
+
from generation.token_manager import TokenManager
|
| 11 |
+
from utils.error_handler import PromptBuildingError
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Setup Settings and Logging
|
| 15 |
+
settings = get_settings()
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PromptBuilder:
|
| 20 |
+
"""
|
| 21 |
+
Intelligent prompt builder for LLM generation. Constructs optimized prompts for different task types with proper context management and citation handling
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, model_name: str = None):
|
| 24 |
+
"""
|
| 25 |
+
Initialize prompt builder
|
| 26 |
+
|
| 27 |
+
Arguments:
|
| 28 |
+
----------
|
| 29 |
+
model_name { str } : Model name for token management
|
| 30 |
+
"""
|
| 31 |
+
self.logger = logger
|
| 32 |
+
self.settings = get_settings()
|
| 33 |
+
self.model_name = model_name or self.settings.OLLAMA_MODEL
|
| 34 |
+
self.token_manager = TokenManager(model_name)
|
| 35 |
+
|
| 36 |
+
# Prompt templates for different tasks
|
| 37 |
+
self.prompt_templates = {PromptType.QA : {"system": self._get_qa_system_prompt(), "user": self._get_qa_user_template(), "max_context_ratio": 0.6},
|
| 38 |
+
PromptType.SUMMARY : {"system": self._get_summary_system_prompt(), "user": self._get_summary_user_template(), "max_context_ratio": 0.8,},
|
| 39 |
+
PromptType.ANALYTICAL : {"system": self._get_analytical_system_prompt(), "user": self._get_analytical_user_template(), "max_context_ratio": 0.7},
|
| 40 |
+
PromptType.COMPARISON : {"system": self._get_comparison_system_prompt(), "user": self._get_comparison_user_template(), "max_context_ratio": 0.5},
|
| 41 |
+
PromptType.EXTRACTION : {"system": self._get_extraction_system_prompt(), "user": self._get_extraction_user_template(), "max_context_ratio": 0.6},
|
| 42 |
+
PromptType.CREATIVE : {"system": self._get_creative_system_prompt(), "user": self._get_creative_user_template(), "max_context_ratio": 0.4},
|
| 43 |
+
PromptType.CONVERSATIONAL : {"system": self._get_conversational_system_prompt(), "user": self._get_conversational_user_template(), "max_context_ratio": 0.5}
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_prompt(self, query: str, context: str, sources: List[ChunkWithScore], prompt_type: PromptType = PromptType.QA, include_citations: bool = True,
|
| 48 |
+
max_completion_tokens: int = 1000) -> Dict[str, str]:
|
| 49 |
+
"""
|
| 50 |
+
Build complete prompt for LLM generation
|
| 51 |
+
|
| 52 |
+
Arguments:
|
| 53 |
+
----------
|
| 54 |
+
query { str } : User query
|
| 55 |
+
|
| 56 |
+
context { str } : Retrieved context
|
| 57 |
+
|
| 58 |
+
sources { list } : Source chunks
|
| 59 |
+
|
| 60 |
+
prompt_type { PromptType } : Type of prompt to build
|
| 61 |
+
|
| 62 |
+
include_citations { bool } : Whether to include citation instructions
|
| 63 |
+
|
| 64 |
+
max_completion_tokens { int } : Tokens to reserve for completion
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
--------
|
| 68 |
+
{ dict } : Dictionary with 'system' and 'user' prompts
|
| 69 |
+
"""
|
| 70 |
+
if not query or not context:
|
| 71 |
+
raise PromptBuildingError("Query and context cannot be empty")
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
# Get template for prompt type
|
| 75 |
+
template = self.prompt_templates.get(prompt_type, self.prompt_templates[PromptType.QA])
|
| 76 |
+
|
| 77 |
+
# Optimize context to fit within token limits
|
| 78 |
+
optimized_context = self._optimize_context(context = context,
|
| 79 |
+
template = template,
|
| 80 |
+
max_completion_tokens = max_completion_tokens,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Build system prompt
|
| 84 |
+
system_prompt = self._build_system_prompt(template = template,
|
| 85 |
+
include_citations = include_citations,
|
| 86 |
+
prompt_type = prompt_type,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Build user prompt
|
| 90 |
+
user_prompt = self._build_user_prompt(template = template,
|
| 91 |
+
query = query,
|
| 92 |
+
context = optimized_context,
|
| 93 |
+
sources = sources,
|
| 94 |
+
prompt_type = prompt_type,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Validate token usage
|
| 98 |
+
self._validate_prompt_length(system_prompt = system_prompt,
|
| 99 |
+
user_prompt = user_prompt,
|
| 100 |
+
max_completion_tokens = max_completion_tokens,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
self.logger.debug(f"Built {prompt_type.value} prompt: {self.token_manager.count_tokens(system_prompt + user_prompt)} tokens")
|
| 104 |
+
|
| 105 |
+
return {"system" : system_prompt,
|
| 106 |
+
"user" : user_prompt,
|
| 107 |
+
"metadata" : {"prompt_type" : prompt_type.value,
|
| 108 |
+
"context_tokens" : self.token_manager.count_tokens(optimized_context),
|
| 109 |
+
"total_tokens" : self.token_manager.count_tokens(system_prompt + user_prompt),
|
| 110 |
+
"sources_count" : len(sources),
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
self.logger.error(f"Prompt building failed: {repr(e)}")
|
| 116 |
+
raise PromptBuildingError(f"Prompt building failed: {repr(e)}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _optimize_context(self, context: str, template: Dict, max_completion_tokens: int) -> str:
|
| 120 |
+
"""
|
| 121 |
+
Optimize context to fit within token limits
|
| 122 |
+
|
| 123 |
+
Arguments:
|
| 124 |
+
----------
|
| 125 |
+
context { str } : Context text to optimize
|
| 126 |
+
|
| 127 |
+
template { Dict } : Prompt template
|
| 128 |
+
|
| 129 |
+
max_completion_tokens { int } : Tokens to reserve for completion
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
--------
|
| 133 |
+
{ str } : Optimized context
|
| 134 |
+
"""
|
| 135 |
+
# Calculate tokens for system prompt
|
| 136 |
+
system_tokens = self.token_manager.count_tokens(template["system"])
|
| 137 |
+
|
| 138 |
+
# Estimate user template tokens by removing placeholders
|
| 139 |
+
user_template_clean = template["user"]
|
| 140 |
+
|
| 141 |
+
# Remove all known placeholders
|
| 142 |
+
placeholders_to_remove = ["{query}", "{context}", "{sources_info}", "{focus}"]
|
| 143 |
+
|
| 144 |
+
for placeholder in placeholders_to_remove:
|
| 145 |
+
user_template_clean = user_template_clean.replace(placeholder, "")
|
| 146 |
+
|
| 147 |
+
user_template_tokens = self.token_manager.count_tokens(user_template_clean)
|
| 148 |
+
|
| 149 |
+
# Calculate available tokens for context - Reserve: system + user template + completion + buffer
|
| 150 |
+
reserved_tokens = system_tokens + user_template_tokens + max_completion_tokens + 100
|
| 151 |
+
max_context_tokens = self.token_manager.context_window - reserved_tokens
|
| 152 |
+
|
| 153 |
+
# Ensure we don't exceed zero
|
| 154 |
+
if (max_context_tokens <= 0):
|
| 155 |
+
self.logger.warning(f"No tokens available for context. Reserved: {reserved_tokens}, Window: {self.token_manager.context_window}")
|
| 156 |
+
return ""
|
| 157 |
+
|
| 158 |
+
# Apply max context ratio from template config
|
| 159 |
+
ratio_limit = int(self.token_manager.context_window * template["max_context_ratio"])
|
| 160 |
+
max_context_tokens = min(max_context_tokens, ratio_limit)
|
| 161 |
+
|
| 162 |
+
# Log optimization details
|
| 163 |
+
self.logger.debug(f"Context optimization: max_tokens={max_context_tokens}, ratio={template['max_context_ratio']}")
|
| 164 |
+
|
| 165 |
+
# Truncate context to fit
|
| 166 |
+
optimized_context = self.token_manager.truncate_to_fit(context, max_context_tokens, strategy="end")
|
| 167 |
+
|
| 168 |
+
# Calculate reduction percentage
|
| 169 |
+
original_tokens = self.token_manager.count_tokens(context)
|
| 170 |
+
optimized_tokens = self.token_manager.count_tokens(optimized_context)
|
| 171 |
+
|
| 172 |
+
if (original_tokens > optimized_tokens):
|
| 173 |
+
reduction_pct = ((original_tokens - optimized_tokens) / original_tokens) * 100
|
| 174 |
+
|
| 175 |
+
self.logger.info(f"Context reduced by {reduction_pct:.1f}% ({original_tokens} → {optimized_tokens} tokens)")
|
| 176 |
+
|
| 177 |
+
return optimized_context
|
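A worked sketch of the token budget computed above, with purely illustrative token counts (the 100-token buffer and the per-template max_context_ratio mirror the logic in _optimize_context):

# Hedged arithmetic sketch: illustrative numbers, not measured values.
context_window = 8192          # assumed window size for this example
system_tokens = 350            # tokens in the system prompt
user_template_tokens = 120     # user template with placeholders stripped
max_completion_tokens = 1000   # reserved for the model's answer
buffer = 100                   # safety margin used above

reserved = system_tokens + user_template_tokens + max_completion_tokens + buffer  # 1570
max_context_tokens = context_window - reserved                                    # 6622
# the template ratio (e.g. 0.6 for QA prompts) further caps the context share
max_context_tokens = min(max_context_tokens, int(context_window * 0.6))           # 4915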
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _build_system_prompt(self, template: Dict, include_citations: bool, prompt_type: PromptType) -> str:
|
| 181 |
+
"""
|
| 182 |
+
Build system prompt
|
| 183 |
+
"""
|
| 184 |
+
system_prompt = template["system"]
|
| 185 |
+
|
| 186 |
+
# Add citation instructions if needed
|
| 187 |
+
if (include_citations and (prompt_type != PromptType.CREATIVE)):
|
| 188 |
+
citation_instructions = self._get_citation_instructions()
|
| 189 |
+
system_prompt += "\n\n" + citation_instructions
|
| 190 |
+
|
| 191 |
+
return system_prompt
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _build_user_prompt(self, template: Dict, query: str, context: str, sources: List[ChunkWithScore], prompt_type: PromptType) -> str:
|
| 195 |
+
"""
|
| 196 |
+
Build user prompt
|
| 197 |
+
"""
|
| 198 |
+
# Format sources information
|
| 199 |
+
sources_info = self._format_sources_info(sources) if sources else ""
|
| 200 |
+
|
| 201 |
+
# Build user prompt using template
|
| 202 |
+
user_prompt = template["user"].format(query = query,
|
| 203 |
+
context = context,
|
| 204 |
+
sources_info = sources_info,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Add task-specific formatting
|
| 208 |
+
if (prompt_type == PromptType.COMPARISON):
|
| 209 |
+
user_prompt = self._enhance_comparison_prompt(user_prompt = user_prompt,
|
| 210 |
+
sources = sources,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
elif (prompt_type == PromptType.ANALYTICAL):
|
| 214 |
+
user_prompt = self._enhance_analytical_prompt(user_prompt = user_prompt,
|
| 215 |
+
sources = sources,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
return user_prompt
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _format_sources_info(self, sources: List[ChunkWithScore]) -> str:
|
| 222 |
+
"""
|
| 223 |
+
Format sources information for the prompt
|
| 224 |
+
"""
|
| 225 |
+
if not sources:
|
| 226 |
+
return ""
|
| 227 |
+
|
| 228 |
+
sources_list = list()
|
| 229 |
+
|
| 230 |
+
for i, source in enumerate(sources, 1):
|
| 231 |
+
chunk = source.chunk
|
| 232 |
+
source_info = f"Source [{i}]:"
|
| 233 |
+
|
| 234 |
+
if hasattr(chunk, 'metadata') and 'filename' in chunk.metadata:
|
| 235 |
+
source_info += f" {chunk.metadata['filename']}"
|
| 236 |
+
|
| 237 |
+
if chunk.page_number:
|
| 238 |
+
source_info += f" (page {chunk.page_number})"
|
| 239 |
+
|
| 240 |
+
if chunk.section_title:
|
| 241 |
+
source_info += f" - {chunk.section_title}"
|
| 242 |
+
|
| 243 |
+
sources_list.append(source_info)
|
| 244 |
+
|
| 245 |
+
return "Available sources:\n" + "\n".join(sources_list)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _enhance_comparison_prompt(self, user_prompt: str, sources: List[ChunkWithScore]) -> str:
|
| 249 |
+
"""
|
| 250 |
+
Enhance prompt for comparison tasks
|
| 251 |
+
|
| 252 |
+
Arguments:
|
| 253 |
+
----------
|
| 254 |
+
user_prompt { str } : Base user prompt
|
| 255 |
+
|
| 256 |
+
sources { list } : Source chunks
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
--------
|
| 260 |
+
{ str } : Enhanced prompt
|
| 261 |
+
"""
|
| 262 |
+
if (len(sources) < 2):
|
| 263 |
+
return user_prompt
|
| 264 |
+
|
| 265 |
+
enhancement = "\n\nPlease compare and contrast the information from different sources. "
|
| 266 |
+
enhancement += "Highlight agreements, disagreements, and complementary information. "
|
| 267 |
+
enhancement += "If sources conflict, present both perspectives clearly."
|
| 268 |
+
|
| 269 |
+
return user_prompt + enhancement
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _enhance_analytical_prompt(self, user_prompt: str, sources: List[ChunkWithScore]) -> str:
|
| 273 |
+
"""
|
| 274 |
+
Enhance prompt for analytical tasks
|
| 275 |
+
|
| 276 |
+
Arguments:
|
| 277 |
+
----------
|
| 278 |
+
user_prompt { str } : Base user prompt
|
| 279 |
+
|
| 280 |
+
sources { list } : Source chunks
|
| 281 |
+
|
| 282 |
+
Returns:
|
| 283 |
+
--------
|
| 284 |
+
{ str } : Enhanced prompt
|
| 285 |
+
"""
|
| 286 |
+
enhancement = "\n\nProvide analytical insights by:"
|
| 287 |
+
enhancement += "\n1. Identifying patterns and relationships in the information"
|
| 288 |
+
enhancement += "\n2. Analyzing implications and consequences"
|
| 289 |
+
enhancement += "\n3. Evaluating the strength of evidence from different sources"
|
| 290 |
+
enhancement += "\n4. Drawing well-reasoned conclusions"
|
| 291 |
+
|
| 292 |
+
return user_prompt + enhancement
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _validate_prompt_length(self, system_prompt: str, user_prompt: str, max_completion_tokens: int):
|
| 296 |
+
"""
|
| 297 |
+
Validate that prompt fits within context window
|
| 298 |
+
|
| 299 |
+
Arguments:
|
| 300 |
+
----------
|
| 301 |
+
system_prompt { str } : System prompt
|
| 302 |
+
|
| 303 |
+
user_prompt { str } : User prompt
|
| 304 |
+
|
| 305 |
+
max_completion_tokens { int } : Tokens needed for completion
|
| 306 |
+
|
| 307 |
+
Raises:
|
| 308 |
+
-------
|
| 309 |
+
PromptBuildingError : If prompt exceeds context window
|
| 310 |
+
"""
|
| 311 |
+
system_tokens = self.token_manager.count_tokens(system_prompt)
|
| 312 |
+
user_tokens = self.token_manager.count_tokens(user_prompt)
|
| 313 |
+
total_tokens = system_tokens + user_tokens
|
| 314 |
+
total_required = total_tokens + max_completion_tokens
|
| 315 |
+
|
| 316 |
+
if (total_required > self.token_manager.context_window):
|
| 317 |
+
error_msg = (f"Prompt exceeds context window:\n"
|
| 318 |
+
f"- System prompt: {system_tokens} tokens\n"
|
| 319 |
+
f"- User prompt: {user_tokens} tokens\n"
|
| 320 |
+
f"- Completion reserve: {max_completion_tokens} tokens\n"
|
| 321 |
+
f"- Total required: {total_required} tokens\n"
|
| 322 |
+
f"- Context window: {self.token_manager.context_window} tokens\n"
|
| 323 |
+
f"- Overflow: {total_required - self.token_manager.context_window} tokens"
|
| 324 |
+
)
|
| 325 |
+
self.logger.error(error_msg)
|
| 326 |
+
|
| 327 |
+
raise PromptBuildingError(error_msg)
|
| 328 |
+
|
| 329 |
+
# Log successful validation
|
| 330 |
+
utilization = (total_required / self.token_manager.context_window) * 100
|
| 331 |
+
|
| 332 |
+
self.logger.debug(f"Prompt validation passed: {total_required}/{self.token_manager.context_window} tokens ({utilization:.1f}% utilization)")
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _get_citation_instructions(self) -> str:
|
| 336 |
+
"""
|
| 337 |
+
Get citation instructions for system prompt
|
| 338 |
+
"""
|
| 339 |
+
return ("CITATION INSTRUCTIONS:\n"
|
| 340 |
+
"1. Always cite your sources using [number] notation\n"
|
| 341 |
+
"2. Use the citation number that corresponds to the source in the context\n"
|
| 342 |
+
"3. Cite sources for all factual claims and specific information\n"
|
| 343 |
+
"4. If information appears in multiple sources, cite the most relevant one\n"
|
| 344 |
+
"5. Do not make up information not present in the provided context\n"
|
| 345 |
+
"6. If the context doesn't contain the answer, explicitly state this"
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# Template methods for different prompt types
|
| 350 |
+
def _get_qa_system_prompt(self) -> str:
|
| 351 |
+
return ("You are a precise and helpful AI assistant that answers questions based solely on the provided context.\n\n"
|
| 352 |
+
"Core Principles:\n"
|
| 353 |
+
"1. ONLY use information from the provided context\n"
|
| 354 |
+
"2. Be concise but complete - don't omit important details\n"
|
| 355 |
+
"3. Structure complex answers clearly\n"
|
| 356 |
+
"4. If information is ambiguous or conflicting, acknowledge this\n"
|
| 357 |
+
"5. Never make up or infer information not present in the context"
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def _get_qa_user_template(self) -> str:
|
| 362 |
+
return ("Context Information:\n{context}\n\n"
|
| 363 |
+
"{sources_info}\n\n"
|
| 364 |
+
"Question: {query}\n\n"
|
| 365 |
+
"Instructions: Answer the question using ONLY the information provided in the context above. "
|
| 366 |
+
"Cite your sources using [number] notation. If the context doesn't contain enough information "
|
| 367 |
+
"to answer fully, state this clearly.\n\n"
|
| 368 |
+
"Answer:"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _get_summary_system_prompt(self) -> str:
|
| 373 |
+
return ("You are a thorough AI assistant that provides comprehensive summaries based on the provided context.\n\n"
|
| 374 |
+
"Summary Guidelines:\n"
|
| 375 |
+
"1. Capture all key points and main ideas\n"
|
| 376 |
+
"2. Maintain the original meaning and intent\n"
|
| 377 |
+
"3. Organize information logically\n"
|
| 378 |
+
"4. Be comprehensive but concise\n"
|
| 379 |
+
"5. Highlight important findings and conclusions"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _get_summary_user_template(self) -> str:
|
| 384 |
+
return ("Content to summarize:\n{context}\n\n"
|
| 385 |
+
"Please provide a comprehensive summary that captures all key points and main ideas. "
|
| 386 |
+
"Organize the summary logically and ensure it reflects the original content accurately.\n\n"
|
| 387 |
+
"Summary:"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def _get_analytical_system_prompt(self) -> str:
|
| 392 |
+
return ("You are an analytical AI assistant that provides insights based on the provided context.\n\n"
|
| 393 |
+
"Analytical Guidelines:\n"
|
| 394 |
+
"1. Analyze patterns and connections in the information\n"
|
| 395 |
+
"2. Compare different perspectives if multiple sources exist\n"
|
| 396 |
+
"3. Highlight key findings and implications\n"
|
| 397 |
+
"4. Identify gaps or limitations in the available information\n"
|
| 398 |
+
"5. Provide well-reasoned analysis and conclusions"
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _get_analytical_user_template(self) -> str:
|
| 403 |
+
return ("Context for analysis:\n{context}\n\n"
|
| 404 |
+
"{sources_info}\n\n"
|
| 405 |
+
"Analytical task: {query}\n\n"
|
| 406 |
+
"Please provide a detailed analysis based on the context above. "
|
| 407 |
+
"Identify patterns, relationships, and implications. "
|
| 408 |
+
"Cite sources for all analytical claims.\n\n"
|
| 409 |
+
"Analysis:"
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _get_comparison_system_prompt(self) -> str:
|
| 414 |
+
return ("You are an AI assistant that compares information across multiple sources.\n\n"
|
| 415 |
+
"Comparison Guidelines:\n"
|
| 416 |
+
"1. Identify similarities and differences between sources\n"
|
| 417 |
+
"2. Note if sources agree or disagree on specific points\n"
|
| 418 |
+
"3. Highlight complementary information\n"
|
| 419 |
+
"4. Present conflicting perspectives fairly\n"
|
| 420 |
+
"5. Draw conclusions about the overall consensus or disagreement"
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _get_comparison_user_template(self) -> str:
|
| 425 |
+
return ("Context from multiple sources:\n{context}\n\n"
|
| 426 |
+
"{sources_info}\n\n"
|
| 427 |
+
"Comparison task: {query}\n\n"
|
| 428 |
+
"Please compare how different sources address this topic. "
|
| 429 |
+
"Identify agreements, disagreements, and complementary information. "
|
| 430 |
+
"Cite specific sources for each point of comparison.\n\n"
|
| 431 |
+
"Comparison:"
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _get_extraction_system_prompt(self) -> str:
|
| 436 |
+
return ("You are a precise AI assistant that extracts specific information from context.\n\n"
|
| 437 |
+
"Extraction Guidelines:\n"
|
| 438 |
+
"1. Extract only the requested information\n"
|
| 439 |
+
"2. Be thorough and complete\n"
|
| 440 |
+
"3. Maintain accuracy and precision\n"
|
| 441 |
+
"4. Organize extracted information clearly\n"
|
| 442 |
+
"5. Cite sources for all extracted information"
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _get_extraction_user_template(self) -> str:
|
| 447 |
+
return ("Context:\n{context}\n\n"
|
| 448 |
+
"{sources_info}\n\n"
|
| 449 |
+
"Extraction task: {query}\n\n"
|
| 450 |
+
"Please extract the requested information from the context above. "
|
| 451 |
+
"Be thorough and precise. Cite sources for all extracted information.\n\n"
|
| 452 |
+
"Extracted Information:"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _get_creative_system_prompt(self) -> str:
|
| 457 |
+
return ("You are a creative AI assistant that generates content based on provided context.\n\n"
|
| 458 |
+
"Creative Guidelines:\n"
|
| 459 |
+
"1. Use the context as inspiration and foundation\n"
|
| 460 |
+
"2. Be creative and engaging\n"
|
| 461 |
+
"3. Maintain coherence with the source material\n"
|
| 462 |
+
"4. You may extrapolate and build upon the provided information\n"
|
| 463 |
+
"5. Clearly distinguish between source-based content and creative additions"
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def _get_creative_user_template(self) -> str:
|
| 468 |
+
return ("Context and inspiration:\n{context}\n\n"
|
| 469 |
+
"Creative task: {query}\n\n"
|
| 470 |
+
"Please create content based on the context above. "
|
| 471 |
+
"You may use the context as inspiration and build upon it creatively. "
|
| 472 |
+
"If you add information beyond what's in the context, make this clear.\n\n"
|
| 473 |
+
"Creative Response:"
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _get_conversational_system_prompt(self) -> str:
|
| 478 |
+
return ("You are a helpful and engaging conversational AI assistant.\n\n"
|
| 479 |
+
"Conversational Guidelines:\n"
|
| 480 |
+
"1. Be natural and engaging in conversation\n"
|
| 481 |
+
"2. Use the provided context to inform your responses\n"
|
| 482 |
+
"3. Maintain a friendly and helpful tone\n"
|
| 483 |
+
"4. Ask clarifying questions when needed\n"
|
| 484 |
+
"5. Cite sources when providing specific information"
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def _get_conversational_user_template(self) -> str:
|
| 489 |
+
return ("Context for our conversation:\n{context}\n\n"
|
| 490 |
+
"Current message: {query}\n\n"
|
| 491 |
+
"Please respond naturally and helpfully based on the context above. "
|
| 492 |
+
"If providing specific information, cite your sources.\n\n"
|
| 493 |
+
"Response:"
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
# Global prompt builder instance
|
| 498 |
+
_prompt_builder = None
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def get_prompt_builder(model_name: str = None) -> PromptBuilder:
|
| 502 |
+
"""
|
| 503 |
+
Get global prompt builder instance (singleton)
|
| 504 |
+
|
| 505 |
+
Arguments:
|
| 506 |
+
----------
|
| 507 |
+
model_name { str } : Model name for token management
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
--------
|
| 511 |
+
{ PromptBuilder } : PromptBuilder instance
|
| 512 |
+
"""
|
| 513 |
+
global _prompt_builder
|
| 514 |
+
|
| 515 |
+
if _prompt_builder is None or (model_name and _prompt_builder.model_name != model_name):
|
| 516 |
+
_prompt_builder = PromptBuilder(model_name)
|
| 517 |
+
|
| 518 |
+
return _prompt_builder
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
@handle_errors(error_type = PromptBuildingError, log_error = True, reraise = False)
|
| 522 |
+
def build_qa_prompt(query: str, context: str, sources: List[ChunkWithScore], **kwargs) -> Dict[str, str]:
|
| 523 |
+
"""
|
| 524 |
+
Convenience function for building QA prompts
|
| 525 |
+
|
| 526 |
+
Arguments:
|
| 527 |
+
----------
|
| 528 |
+
query { str } : User query
|
| 529 |
+
|
| 530 |
+
context { str } : Retrieved context
|
| 531 |
+
|
| 532 |
+
sources { list } : Source chunks
|
| 533 |
+
|
| 534 |
+
**kwargs : Additional prompt building arguments
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
--------
|
| 538 |
+
{ dict } : Dictionary with system and user prompts
|
| 539 |
+
"""
|
| 540 |
+
builder = get_prompt_builder()
|
| 541 |
+
|
| 542 |
+
return builder.build_prompt(query, context, sources, PromptType.QA, **kwargs)
|
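A minimal usage sketch for the convenience helper above (not part of the commit; query and context strings are illustrative, and an empty sources list is passed since _format_sources_info simply returns an empty string for it):

# Hedged sketch: build a QA prompt without real ChunkWithScore objects.
from generation.prompt_builder import build_qa_prompt

prompts = build_qa_prompt(
    query="What does the report conclude?",           # illustrative
    context="Source [1]: The report concludes ...",   # illustrative retrieved text
    sources=[],                                        # no retrieved chunks in this sketch
    max_completion_tokens=500,
)
print(prompts["system"])
print(prompts["user"])
print(prompts["metadata"]["total_tokens"])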
generation/query_classifier.py
ADDED
|
@@ -0,0 +1,247 @@
|
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict
|
| 4 |
+
from config.models import LLMProvider
|
| 5 |
+
from config.settings import get_settings
|
| 6 |
+
from config.logging_config import get_logger
|
| 7 |
+
from generation.llm_client import get_llm_client
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Setup Settings and Logging
|
| 11 |
+
settings = get_settings()
|
| 12 |
+
logger = get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class QueryClassifier:
|
| 16 |
+
"""
|
| 17 |
+
LLM-based query classifier that intelligently routes queries to:
|
| 18 |
+
1. General/Conversational (no document context needed)
|
| 19 |
+
2. RAG/Document-based (needs retrieval from documents)
|
| 20 |
+
|
| 21 |
+
Uses the LLM itself for classification instead of hardcoded patterns.
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, provider: LLMProvider = None, model_name: str = None):
|
| 24 |
+
self.logger = logger
|
| 25 |
+
self.provider = provider or LLMProvider.OLLAMA
|
| 26 |
+
self.model_name = model_name or settings.OLLAMA_MODEL
|
| 27 |
+
|
| 28 |
+
# Initialize LLM client for classification
|
| 29 |
+
self.llm_client = get_llm_client(provider = self.provider,
|
| 30 |
+
model_name = self.model_name,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Classification prompt
|
| 34 |
+
self.system_prompt = """
|
| 35 |
+
You are a query classification system for a RAG (Retrieval-Augmented Generation) system.
|
| 36 |
+
Your job is to determine if a user query should be answered using the user's uploaded documents.
|
| 37 |
+
|
| 38 |
+
**IMPORTANT CONTEXT**: The user has uploaded documents to the system. All queries related to the content of those uploaded documents should use RAG.
|
| 39 |
+
|
| 40 |
+
Classify queries into TWO categories:
|
| 41 |
+
|
| 42 |
+
**RAG (Document-based)** - Use when ANY of these are true:
|
| 43 |
+
1. Query asks about ANY content that could be in the uploaded documents
|
| 44 |
+
2. Query asks factual questions that could be answered from document content
|
| 45 |
+
3. Query asks for lists, summaries, or analysis of information
|
| 46 |
+
4. Query mentions specific details, data, statistics, names, dates, or facts
|
| 47 |
+
5. Query asks "what", "how", "list", "explain", "summarize", "compare", "analyze" about any topic
|
| 48 |
+
6. Query could reasonably be answered by searching through documents
|
| 49 |
+
7. **CRITICAL**: When documents are uploaded, DEFAULT TO RAG for most factual/content queries
|
| 50 |
+
|
| 51 |
+
**GENERAL (Conversational)** - Use ONLY when MOST of these are true:
|
| 52 |
+
1. Query is purely conversational (greetings, thanks, casual chat)
|
| 53 |
+
2. Query asks about the RAG system itself or its functionality
|
| 54 |
+
3. Query asks for general knowledge that is NOT specific to uploaded documents
|
| 55 |
+
4. Query is a meta-question about how to use the system
|
| 56 |
+
5. Query contains NO request for factual information from documents
|
| 57 |
+
|
| 58 |
+
**EXAMPLES FOR ANY DOCUMENT TYPE**:
|
| 59 |
+
- For business documents: "What sales channels does the company use?" → RAG
|
| 60 |
+
- For research papers: "What were the study's findings?" → RAG
|
| 61 |
+
- For legal documents: "What are the key clauses?" → RAG
|
| 62 |
+
- For technical manuals: "How do I configure the system?" → RAG
|
| 63 |
+
- For personal documents: "What dates are mentioned?" → RAG
|
| 64 |
+
- "Hi, how are you?" → GENERAL
|
| 65 |
+
- "How do I upload a document?" → GENERAL
|
| 66 |
+
- "What is the capital of France?" → GENERAL (unless geography documents were uploaded)
|
| 67 |
+
|
| 68 |
+
**KEY RULES**:
|
| 69 |
+
1. When documents exist, assume queries are about them unless clearly not
|
| 70 |
+
2. When in doubt, classify as RAG (safer to search than hallucinate)
|
| 71 |
+
3. If query could be answered from document content, use RAG
|
| 72 |
+
4. Only use GENERAL for purely conversational or system-related queries
|
| 73 |
+
|
| 74 |
+
Respond with ONLY a JSON object (no markdown, no extra text):
|
| 75 |
+
{
|
| 76 |
+
"type": "rag" or "general",
|
| 77 |
+
"confidence": 0.0 to 1.0,
|
| 78 |
+
"reason": "brief explanation"
|
| 79 |
+
}
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
async def classify(self, query: str, has_documents: bool = True) -> Dict:
|
| 84 |
+
"""
|
| 85 |
+
Classify a query using LLM
|
| 86 |
+
|
| 87 |
+
Arguments:
|
| 88 |
+
----------
|
| 89 |
+
query { str } : User query
|
| 90 |
+
|
| 91 |
+
has_documents { bool } : Whether documents are available in the system
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
--------
|
| 95 |
+
{ dict } : Classification result
|
| 96 |
+
"""
|
| 97 |
+
try:
|
| 98 |
+
# If no documents are available, everything should be general
|
| 99 |
+
if not has_documents:
|
| 100 |
+
return {"type" : "general",
|
| 101 |
+
"confidence" : 1.0,
|
| 102 |
+
"reason" : "No documents available in system",
|
| 103 |
+
"suggested_action" : "respond_with_general_llm",
|
| 104 |
+
"is_llm_classified" : False,
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
# Build classification prompt
|
| 108 |
+
user_prompt = f"""
|
| 109 |
+
Query: "{query}"
|
| 110 |
+
|
| 111 |
+
System status: {"Documents are available" if has_documents else "No documents uploaded"}
|
| 112 |
+
|
| 113 |
+
Classify this query. Remember: if uncertain, prefer RAG.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
messages = [{"role" : "system",
|
| 117 |
+
"content" : self.system_prompt,
|
| 118 |
+
},
|
| 119 |
+
{"role" : "user",
|
| 120 |
+
"content" : user_prompt,
|
| 121 |
+
}
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
# Get LLM classification (use low temperature for consistency)
|
| 125 |
+
llm_response = await self.llm_client.generate(messages = messages,
|
| 126 |
+
temperature = 0.1, # Low temperature for consistent classification
|
| 127 |
+
max_tokens = 150,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
response_text = llm_response.get("content", "").strip()
|
| 131 |
+
|
| 132 |
+
# Parse JSON response
|
| 133 |
+
classification = self._parse_llm_response(response_text)
|
| 134 |
+
|
| 135 |
+
# Add suggested action based on classification
|
| 136 |
+
if (classification["type"] == "rag"):
|
| 137 |
+
classification["suggested_action"] = "respond_with_rag"
|
| 138 |
+
|
| 139 |
+
elif (classification["type"] == "general"):
|
| 140 |
+
classification["suggested_action"] = "respond_with_general_llm"
|
| 141 |
+
|
| 142 |
+
else:
|
| 143 |
+
# Default to RAG if uncertain
|
| 144 |
+
classification["suggested_action"] = "respond_with_rag"
|
| 145 |
+
|
| 146 |
+
classification["is_llm_classified"] = True
|
| 147 |
+
|
| 148 |
+
logger.info(f"LLM classified query as: {classification['type']} (confidence: {classification['confidence']:.2f})")
|
| 149 |
+
logger.debug(f"Classification reason: {classification['reason']}")
|
| 150 |
+
|
| 151 |
+
return classification
|
| 152 |
+
|
| 153 |
+
except Exception as e:
|
| 154 |
+
logger.error(f"LLM classification failed: {e}, defaulting to RAG")
|
| 155 |
+
# On error, default to RAG (safer to try document search)
|
| 156 |
+
return {"type" : "rag",
|
| 157 |
+
"confidence" : 0.5,
|
| 158 |
+
"reason" : f"Classification failed: {str(e)}, defaulting to RAG",
|
| 159 |
+
"suggested_action" : "respond_with_rag",
|
| 160 |
+
"is_llm_classified" : False,
|
| 161 |
+
"error" : str(e)
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _parse_llm_response(self, response_text: str) -> Dict:
|
| 166 |
+
"""
|
| 167 |
+
Parse LLM JSON response
|
| 168 |
+
|
| 169 |
+
Arguments:
|
| 170 |
+
----------
|
| 171 |
+
response_text { str } : LLM response text
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
--------
|
| 175 |
+
{ dict } : Parsed classification
|
| 176 |
+
"""
|
| 177 |
+
try:
|
| 178 |
+
# Remove markdown code blocks if present
|
| 179 |
+
if ("```json" in response_text):
|
| 180 |
+
response_text = response_text.split("```json")[1].split("```")[0].strip()
|
| 181 |
+
|
| 182 |
+
elif ("```" in response_text):
|
| 183 |
+
response_text = response_text.split("```")[1].split("```")[0].strip()
|
| 184 |
+
|
| 185 |
+
# Parse JSON
|
| 186 |
+
result = json.loads(response_text)
|
| 187 |
+
|
| 188 |
+
# Validate required fields
|
| 189 |
+
if ("type" not in result) or (result["type"] not in ["rag", "general"]):
|
| 190 |
+
raise ValueError(f"Invalid type in response: {result.get('type')}")
|
| 191 |
+
|
| 192 |
+
# Set defaults for missing fields
|
| 193 |
+
result.setdefault("confidence", 0.8)
|
| 194 |
+
result.setdefault("reason", "LLM classification")
|
| 195 |
+
|
| 196 |
+
# Clamp confidence to valid range
|
| 197 |
+
result["confidence"] = max(0.0, min(1.0, float(result["confidence"])))
|
| 198 |
+
|
| 199 |
+
return result
|
| 200 |
+
|
| 201 |
+
except (json.JSONDecodeError, ValueError, KeyError) as e:
|
| 202 |
+
logger.warning(f"Failed to parse LLM response: {e}")
|
| 203 |
+
logger.debug(f"Raw response: {response_text}")
|
| 204 |
+
|
| 205 |
+
# Try to extract type from text if JSON parsing fails
|
| 206 |
+
response_lower = response_text.lower()
|
| 207 |
+
|
| 208 |
+
if (("general" in response_lower) and ("rag" not in response_lower)):
|
| 209 |
+
return {"type" : "general",
|
| 210 |
+
"confidence" : 0.6,
|
| 211 |
+
"reason" : "Parsed from non-JSON response",
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
else:
|
| 215 |
+
# Default to RAG if parsing fails
|
| 216 |
+
return {"type" : "rag",
|
| 217 |
+
"confidence" : 0.6,
|
| 218 |
+
"reason" : "Failed to parse response, defaulting to RAG",
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Global classifier instance
|
| 223 |
+
_query_classifier = None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_query_classifier(provider: LLMProvider = None, model_name: str = None) -> QueryClassifier:
|
| 227 |
+
"""
|
| 228 |
+
Get global query classifier instance
|
| 229 |
+
|
| 230 |
+
Arguments:
|
| 231 |
+
----------
|
| 232 |
+
provider { LLMProvider } : LLM provider
|
| 233 |
+
|
| 234 |
+
model_name { str } : Model name
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
--------
|
| 238 |
+
{ QueryClassifier } : QueryClassifier instance
|
| 239 |
+
"""
|
| 240 |
+
global _query_classifier
|
| 241 |
+
|
| 242 |
+
if _query_classifier is None:
|
| 243 |
+
_query_classifier = QueryClassifier(provider = provider,
|
| 244 |
+
model_name = model_name,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
return _query_classifier
|
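A minimal async usage sketch for the classifier above (not part of the commit; the query is illustrative and an LLM backend is assumed to be reachable, since classification itself is delegated to the LLM client):

# Hedged sketch: route a query through the LLM-based classifier.
import asyncio
from generation.query_classifier import get_query_classifier

async def main():
    classifier = get_query_classifier()
    result = await classifier.classify("Summarize the uploaded contract", has_documents=True)  # illustrative
    # result carries "type" ("rag"/"general"), "confidence", "reason" and "suggested_action"
    print(result["type"], result["confidence"], result["suggested_action"])

asyncio.run(main())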
generation/response_generator.py
ADDED
|
@@ -0,0 +1,880 @@
|
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import time
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Dict
|
| 5 |
+
from typing import List
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import AsyncGenerator
|
| 9 |
+
from config.models import PromptType
|
| 10 |
+
from config.models import LLMProvider
|
| 11 |
+
from config.models import QueryRequest
|
| 12 |
+
from config.models import QueryResponse
|
| 13 |
+
from config.models import ChunkWithScore
|
| 14 |
+
from config.settings import get_settings
|
| 15 |
+
from config.logging_config import get_logger
|
| 16 |
+
from utils.error_handler import handle_errors
|
| 17 |
+
from generation.llm_client import get_llm_client
|
| 18 |
+
from utils.error_handler import ResponseGenerationError
|
| 19 |
+
from generation.prompt_builder import get_prompt_builder
|
| 20 |
+
from retrieval.hybrid_retriever import get_hybrid_retriever
|
| 21 |
+
from generation.query_classifier import get_query_classifier
|
| 22 |
+
from generation.general_responder import get_general_responder
|
| 23 |
+
from generation.citation_formatter import get_citation_formatter
|
| 24 |
+
from generation.temperature_controller import get_temperature_controller
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Setup Settings and Logging
|
| 28 |
+
settings = get_settings()
|
| 29 |
+
logger = get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ResponseGenerator:
|
| 33 |
+
"""
|
| 34 |
+
Main orchestrator for RAG response generation with LLM-based intelligent query routing
|
| 35 |
+
|
| 36 |
+
Handles both:
|
| 37 |
+
1. Generic/conversational queries (greetings, system info, general knowledge)
|
| 38 |
+
2. Document-based RAG queries (retrieval + generation)
|
| 39 |
+
|
| 40 |
+
Pipeline: Query → LLM Classifier → Route to (General LLM | RAG Pipeline) → Response
|
| 41 |
+
"""
|
| 42 |
+
def __init__(self, provider: LLMProvider = None, model_name: str = None):
|
| 43 |
+
"""
|
| 44 |
+
Initialize response generator with LLM-based query routing capabilities
|
| 45 |
+
|
| 46 |
+
Arguments:
|
| 47 |
+
----------
|
| 48 |
+
provider { LLMProvider } : LLM provider (Ollama/OpenAI)
|
| 49 |
+
|
| 50 |
+
model_name { str } : Model name to use
|
| 51 |
+
"""
|
| 52 |
+
self.logger = logger
|
| 53 |
+
self.settings = get_settings()
|
| 54 |
+
|
| 55 |
+
# Auto-detect provider for HF Spaces
|
| 56 |
+
if provider is None:
|
| 57 |
+
if (self.settings.IS_HF_SPACE and not self.settings.OLLAMA_ENABLED):
|
| 58 |
+
if (self.settings.USE_OPENAI and self.settings.OPENAI_API_KEY):
|
| 59 |
+
provider = LLMProvider.OPENAI
|
| 60 |
+
model_name = model_name or self.settings.OPENAI_MODEL
|
| 61 |
+
logger.info("Auto-detected: Using OpenAI")
|
| 62 |
+
|
| 63 |
+
else:
|
| 64 |
+
raise ValueError("No LLM provider configured for HF Space. Set OPENAI_API_KEY in Space secrets.")
|
| 65 |
+
|
| 66 |
+
else:
|
| 67 |
+
# Local development - use Ollama
|
| 68 |
+
provider = LLMProvider.OLLAMA
|
| 69 |
+
|
| 70 |
+
self.provider = provider
|
| 71 |
+
self.model_name = model_name
|
| 72 |
+
|
| 73 |
+
# Initialize components
|
| 74 |
+
self.llm_client = get_llm_client(provider = self.provider,
|
| 75 |
+
model_name = self.model_name,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Query routing components (NOW USES LLM FOR CLASSIFICATION)
|
| 79 |
+
self.query_classifier = get_query_classifier(provider = self.provider,
|
| 80 |
+
model_name = self.model_name,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
self.general_responder = get_general_responder(provider = self.provider,
|
| 84 |
+
model_name = self.model_name,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# RAG components
|
| 88 |
+
self.hybrid_retriever = get_hybrid_retriever()
|
| 89 |
+
self.prompt_builder = get_prompt_builder(model_name = self.model_name)
|
| 90 |
+
self.citation_formatter = get_citation_formatter()
|
| 91 |
+
self.temperature_controller = get_temperature_controller()
|
| 92 |
+
|
| 93 |
+
# Statistics
|
| 94 |
+
self.generation_count = 0
|
| 95 |
+
self.total_generation_time = 0.0
|
| 96 |
+
self.general_query_count = 0
|
| 97 |
+
self.rag_query_count = 0
|
| 98 |
+
|
| 99 |
+
self.logger.info(f"Initialized ResponseGenerator with LLM-Based Query Routing: provider={self.provider.value}, model={self.model_name}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@handle_errors(error_type = ResponseGenerationError, log_error = True, reraise = True)
|
| 103 |
+
async def generate_response(self, request: QueryRequest, conversation_history: List[Dict] = None, has_documents: bool = True) -> QueryResponse:
|
| 104 |
+
"""
|
| 105 |
+
Generate response with LLM-based intelligent query routing
|
| 106 |
+
|
| 107 |
+
Arguments:
|
| 108 |
+
----------
|
| 109 |
+
request { QueryRequest } : Query request object
|
| 110 |
+
|
| 111 |
+
conversation_history { list } : Previous conversation messages
|
| 112 |
+
|
| 113 |
+
has_documents { bool } : Whether documents are available in the system
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
--------
|
| 117 |
+
{ QueryResponse } : Complete query response
|
| 118 |
+
"""
|
| 119 |
+
start_time = time.time()
|
| 120 |
+
|
| 121 |
+
self.logger.info(f"Processing query: '{request.query[:100]}...'")
|
| 122 |
+
|
| 123 |
+
try:
|
| 124 |
+
# Classify query using LLM
|
| 125 |
+
classification = await self.query_classifier.classify(query = request.query,
|
| 126 |
+
has_documents = has_documents,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.logger.info(f"Query classified as: {classification['type']} (confidence: {classification['confidence']:.2f}, LLM-based: {classification.get('is_llm_classified', False)})")
|
| 130 |
+
self.logger.debug(f"Classification reason: {classification['reason']}")
|
| 131 |
+
|
| 132 |
+
# Route based on classification
|
| 133 |
+
if (classification['suggested_action'] == 'respond_with_general_llm'):
|
| 134 |
+
# Handle as general query
|
| 135 |
+
response = await self._handle_general_query(request = request,
|
| 136 |
+
classification = classification,
|
| 137 |
+
start_time = start_time,
|
| 138 |
+
history = conversation_history,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
self.general_query_count += 1
|
| 142 |
+
|
| 143 |
+
return response
|
| 144 |
+
|
| 145 |
+
elif (classification['suggested_action'] == 'respond_with_rag'):
|
| 146 |
+
# Handle as RAG query
|
| 147 |
+
response = await self._handle_rag_query(request = request,
|
| 148 |
+
classification = classification,
|
| 149 |
+
start_time = start_time,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
self.rag_query_count += 1
|
| 153 |
+
|
| 154 |
+
return response
|
| 155 |
+
|
| 156 |
+
else:
|
| 157 |
+
# Default to RAG if unclear
|
| 158 |
+
self.logger.info("Unclear classification - defaulting to RAG...")
|
| 159 |
+
|
| 160 |
+
try:
|
| 161 |
+
response = await self._handle_rag_query(request = request,
|
| 162 |
+
classification = classification,
|
| 163 |
+
start_time = start_time,
|
| 164 |
+
allow_fallback = True,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# If no results from RAG, fall back to general
|
| 168 |
+
if ((not response.sources) or (len(response.sources) == 0)):
|
| 169 |
+
self.logger.info("No RAG results - falling back to general response")
|
| 170 |
+
|
| 171 |
+
response = await self._handle_general_query(request = request,
|
| 172 |
+
classification = classification,
|
| 173 |
+
start_time = start_time,
|
| 174 |
+
history = conversation_history,
|
| 175 |
+
)
|
| 176 |
+
self.general_query_count += 1
|
| 177 |
+
|
| 178 |
+
else:
|
| 179 |
+
self.rag_query_count += 1
|
| 180 |
+
|
| 181 |
+
return response
|
| 182 |
+
|
| 183 |
+
except Exception as e:
|
| 184 |
+
self.logger.warning(f"RAG attempt failed, falling back to general: {e}")
|
| 185 |
+
response = await self._handle_general_query(request = request,
|
| 186 |
+
classification = classification,
|
| 187 |
+
start_time = start_time,
|
| 188 |
+
history = conversation_history,
|
| 189 |
+
)
|
| 190 |
+
self.general_query_count += 1
|
| 191 |
+
|
| 192 |
+
return response
|
| 193 |
+
|
| 194 |
+
except Exception as e:
|
| 195 |
+
self.logger.error(f"Response generation failed: {repr(e)}", exc_info = True)
|
| 196 |
+
raise ResponseGenerationError(f"Response generation failed: {repr(e)}")
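For reference, a sketch of the classification payload the routing above consumes; the keys mirror the lookups in this method, while the concrete values are invented for illustration:

# Illustrative classification dict (values are made up)
example_classification = {"type"              : "rag",                  # or "general"
                          "confidence"        : 0.87,
                          "reason"            : "Question references uploaded documents",
                          "suggested_action"  : "respond_with_rag",     # or "respond_with_general_llm"
                          "is_llm_classified" : True,
                          }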
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
async def _handle_general_query(self, request: QueryRequest, classification: Dict, start_time: float, history: List[Dict] = None) -> QueryResponse:
|
| 200 |
+
"""
|
| 201 |
+
Handle general/conversational queries without RAG
|
| 202 |
+
|
| 203 |
+
Arguments:
|
| 204 |
+
----------
|
| 205 |
+
request { QueryRequest } : Original request
|
| 206 |
+
|
| 207 |
+
classification { dict } : Classification result
|
| 208 |
+
|
| 209 |
+
start_time { float } : Start timestamp
|
| 210 |
+
|
| 211 |
+
history { list } : Conversation history
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
--------
|
| 215 |
+
{ QueryResponse } : Response without RAG
|
| 216 |
+
"""
|
| 217 |
+
self.logger.debug("Handling as general query...")
|
| 218 |
+
|
| 219 |
+
# Use general responder
|
| 220 |
+
general_response = await self.general_responder.respond(query = request.query,
|
| 221 |
+
conversation_history = history,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
answer = general_response.get("answer", "I'm here to help! Please let me know how I can assist you.")
|
| 225 |
+
total_time = (time.time() - start_time) * 1000
|
| 226 |
+
|
| 227 |
+
# Create QueryResponse object
|
| 228 |
+
response = QueryResponse(query = request.query,
|
| 229 |
+
answer = answer,
|
| 230 |
+
sources = [], # No sources for general queries
|
| 231 |
+
retrieval_time_ms = 0.0,
|
| 232 |
+
generation_time_ms = total_time,
|
| 233 |
+
total_time_ms = total_time,
|
| 234 |
+
tokens_used = general_response.get("tokens_used", {"input": 0, "output": 0, "total": 0}),
|
| 235 |
+
model_used = self.model_name,
|
| 236 |
+
timestamp = datetime.now(),
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Add metadata about query type
|
| 240 |
+
if request.include_metrics:
|
| 241 |
+
response.metrics = {"query_type" : "general",
|
| 242 |
+
"classification" : classification['type'],
|
| 243 |
+
"confidence" : classification['confidence'],
|
| 244 |
+
"requires_rag" : False,
|
| 245 |
+
"conversation_mode" : True,
|
| 246 |
+
"llm_classified" : classification.get('is_llm_classified', False),
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
self.logger.info(f"General response generated in {total_time:.0f}ms")
|
| 250 |
+
|
| 251 |
+
return response
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
async def _handle_rag_query(self, request: QueryRequest, classification: Dict, start_time: float, allow_fallback: bool = False) -> QueryResponse:
|
| 255 |
+
"""
|
| 256 |
+
Handle RAG-based queries with document retrieval
|
| 257 |
+
|
| 258 |
+
Arguments:
|
| 259 |
+
----------
|
| 260 |
+
request { QueryRequest } : Original request
|
| 261 |
+
|
| 262 |
+
classification { dict } : Classification result
|
| 263 |
+
|
| 264 |
+
start_time { float } : Start timestamp
|
| 265 |
+
|
| 266 |
+
allow_fallback { bool } : Whether to allow fallback to general
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
--------
|
| 270 |
+
{ QueryResponse } : RAG response
|
| 271 |
+
"""
|
| 272 |
+
self.logger.debug("Handling as RAG query...")
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
# Retrieve relevant context
|
| 276 |
+
self.logger.debug("Retrieving context...")
|
| 277 |
+
retrieval_start = time.time()
|
| 278 |
+
|
| 279 |
+
retrieval_result = self.hybrid_retriever.retrieve_with_context(query = request.query,
|
| 280 |
+
top_k = request.top_k or self.settings.TOP_K_RETRIEVE,
|
| 281 |
+
enable_reranking = request.enable_reranking,
|
| 282 |
+
include_citations = request.include_sources,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
retrieval_time = (time.time() - retrieval_start) * 1000
|
| 286 |
+
|
| 287 |
+
chunks = retrieval_result["chunks"]
|
| 288 |
+
context = retrieval_result["context"]
|
| 289 |
+
|
| 290 |
+
if not chunks:
|
| 291 |
+
self.logger.warning("No relevant context found for query")
|
| 292 |
+
|
| 293 |
+
if allow_fallback:
|
| 294 |
+
# Return empty response to trigger fallback
|
| 295 |
+
return QueryResponse(query = request.query,
|
| 296 |
+
answer = "",
|
| 297 |
+
sources = [],
|
| 298 |
+
retrieval_time_ms = retrieval_time,
|
| 299 |
+
generation_time_ms = 0.0,
|
| 300 |
+
total_time_ms = retrieval_time,
|
| 301 |
+
tokens_used = {"input": 0, "output": 0, "total": 0},
|
| 302 |
+
model_used = self.model_name,
|
| 303 |
+
timestamp = datetime.now(),
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
else:
|
| 307 |
+
return self._create_no_results_response(request = request,
|
| 308 |
+
retrieval_time_ms = retrieval_time,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
self.logger.info(f"Retrieved {len(chunks)} chunks in {retrieval_time:.0f}ms")
|
| 312 |
+
|
| 313 |
+
# Determine prompt type and temperature
|
| 314 |
+
self.logger.debug("Determining prompt strategy...")
|
| 315 |
+
|
| 316 |
+
prompt_type = self._infer_prompt_type(query = request.query)
|
| 317 |
+
|
| 318 |
+
temperature = self._get_adaptive_temperature(request = request,
|
| 319 |
+
query = request.query,
|
| 320 |
+
context = context,
|
| 321 |
+
retrieval_scores = [chunk.score for chunk in chunks],
|
| 322 |
+
prompt_type = prompt_type,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
self.logger.debug(f"Prompt type: {prompt_type.value}, Temperature: {temperature}")
|
| 326 |
+
|
| 327 |
+
# Build optimized prompt
|
| 328 |
+
self.logger.debug("Building prompt...")
|
| 329 |
+
|
| 330 |
+
prompt = self.prompt_builder.build_prompt(query = request.query,
|
| 331 |
+
context = context,
|
| 332 |
+
sources = chunks,
|
| 333 |
+
prompt_type = prompt_type,
|
| 334 |
+
include_citations = request.include_sources,
|
| 335 |
+
max_completion_tokens = request.max_tokens or self.settings.MAX_TOKENS,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Generate LLM response
|
| 339 |
+
self.logger.debug("Generating LLM response...")
|
| 340 |
+
generation_start = time.time()
|
| 341 |
+
|
| 342 |
+
messages = [{"role" : "system",
|
| 343 |
+
"content" : prompt["system"]
|
| 344 |
+
},
|
| 345 |
+
{"role" : "user",
|
| 346 |
+
"content" : prompt["user"],
|
| 347 |
+
}
|
| 348 |
+
]
|
| 349 |
+
|
| 350 |
+
llm_response = await self.llm_client.generate(messages = messages,
|
| 351 |
+
temperature = temperature,
|
| 352 |
+
top_p = request.top_p or self.settings.TOP_P,
|
| 353 |
+
max_tokens = request.max_tokens or self.settings.MAX_TOKENS,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
generation_time = (time.time() - generation_start) * 1000
|
| 357 |
+
|
| 358 |
+
answer = llm_response["content"]
|
| 359 |
+
|
| 360 |
+
self.logger.info(f"Generated response in {generation_time:.0f}ms ({llm_response['usage']['completion_tokens']} tokens)")
|
| 361 |
+
|
| 362 |
+
# Format citations (if enabled)
|
| 363 |
+
if request.include_sources:
|
| 364 |
+
self.logger.debug("Formatting citations...")
|
| 365 |
+
answer = self._post_process_citations(answer = answer,
|
| 366 |
+
sources = chunks,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Create response object
|
| 370 |
+
total_time = (time.time() - start_time) * 1000
|
| 371 |
+
|
| 372 |
+
response = QueryResponse(query = request.query,
|
| 373 |
+
answer = answer,
|
| 374 |
+
sources = chunks if request.include_sources else [],
|
| 375 |
+
retrieval_time_ms = retrieval_time,
|
| 376 |
+
generation_time_ms = generation_time,
|
| 377 |
+
total_time_ms = total_time,
|
| 378 |
+
tokens_used = {"input" : llm_response["usage"]["prompt_tokens"],
|
| 379 |
+
"output" : llm_response["usage"]["completion_tokens"],
|
| 380 |
+
"total" : llm_response["usage"]["total_tokens"],
|
| 381 |
+
},
|
| 382 |
+
model_used = self.model_name,
|
| 383 |
+
timestamp = datetime.now(),
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# Add quality metrics if requested
|
| 387 |
+
if request.include_metrics:
|
| 388 |
+
response.metrics = self._calculate_quality_metrics(query = request.query,
|
| 389 |
+
answer = answer,
|
| 390 |
+
context = context,
|
| 391 |
+
sources = chunks,
|
| 392 |
+
)
|
| 393 |
+
# Track both: prediction & reality
|
| 394 |
+
response.metrics["predicted_type"] = classification.get('type', 'unknown')
|
| 395 |
+
response.metrics["predicted_confidence"] = classification.get('confidence', 0.0)
|
| 396 |
+
response.metrics["actual_type"] = "rag" # Always rag if we're here
|
| 397 |
+
response.metrics["execution_path"] = "rag_pipeline"
|
| 398 |
+
response.metrics["has_context"] = len(chunks) > 0
|
| 399 |
+
response.metrics["context_chunks"] = len(chunks)
|
| 400 |
+
response.metrics["rag_confidence"] = min(1.0, sum(c.score for c in chunks) / len(chunks) if chunks else 0.0)
|
| 401 |
+
response.metrics["is_forced_rag"] = classification.get('is_forced_rag', False)
|
| 402 |
+
response.metrics["llm_classified"] = classification.get('is_llm_classified', False)
|
| 403 |
+
|
| 404 |
+
# Add context for evaluation
|
| 405 |
+
response.metrics["context_for_evaluation"] = context
|
| 406 |
+
|
| 407 |
+
# Update statistics
|
| 408 |
+
self.generation_count += 1
|
| 409 |
+
self.total_generation_time += total_time
|
| 410 |
+
|
| 411 |
+
self.logger.info(f"RAG response generated successfully in {total_time:.0f}ms")
|
| 412 |
+
|
| 413 |
+
return response
|
| 414 |
+
|
| 415 |
+
except Exception as e:
|
| 416 |
+
self.logger.error(f"RAG query handling failed: {repr(e)}", exc_info = True)
|
| 417 |
+
|
| 418 |
+
if allow_fallback:
|
| 419 |
+
# Return empty to trigger fallback
|
| 420 |
+
return QueryResponse(query = request.query,
|
| 421 |
+
answer = "",
|
| 422 |
+
sources = [],
|
| 423 |
+
retrieval_time_ms = 0.0,
|
| 424 |
+
generation_time_ms = 0.0,
|
| 425 |
+
total_time_ms = 0.0,
|
| 426 |
+
tokens_used = {"input": 0, "output": 0, "total": 0},
|
| 427 |
+
model_used = self.model_name,
|
| 428 |
+
timestamp = datetime.now(),
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
raise
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@handle_errors(error_type = ResponseGenerationError, log_error = True, reraise = True)
|
| 435 |
+
async def generate_response_stream(self, request: QueryRequest, has_documents: bool = True) -> AsyncGenerator[str, None]:
|
| 436 |
+
"""
|
| 437 |
+
Generate streaming RAG response
|
| 438 |
+
|
| 439 |
+
Arguments:
|
| 440 |
+
----------
|
| 441 |
+
request { QueryRequest } : Query request object
|
| 442 |
+
|
| 443 |
+
has_documents { bool } : Whether documents are available
|
| 444 |
+
|
| 445 |
+
Yields:
|
| 446 |
+
-------
|
| 447 |
+
{ str } : Response chunks (tokens)
|
| 448 |
+
"""
|
| 449 |
+
self.logger.info(f"Generating streaming response for query: '{request.query[:100]}...'")
|
| 450 |
+
|
| 451 |
+
try:
|
| 452 |
+
# Classify query first
|
| 453 |
+
classification = await self.query_classifier.classify(query = request.query,
|
| 454 |
+
has_documents = has_documents,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
if (classification['suggested_action'] == 'respond_with_general_llm'):
|
| 458 |
+
# Stream general response
|
| 459 |
+
general_response = await self.general_responder.respond(query = request.query)
|
| 460 |
+
yield general_response.get("answer", "")
|
| 461 |
+
|
| 462 |
+
return
|
| 463 |
+
|
| 464 |
+
# Otherwise proceed with RAG streaming - retrieve context first
|
| 465 |
+
retrieval_result = self.hybrid_retriever.retrieve_with_context(query = request.query,
|
| 466 |
+
top_k = request.top_k or self.settings.TOP_K_RETRIEVE,
|
| 467 |
+
enable_reranking = request.enable_reranking,
|
| 468 |
+
include_citations = request.include_sources,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
chunks = retrieval_result["chunks"]
|
| 472 |
+
context = retrieval_result["context"]
|
| 473 |
+
|
| 474 |
+
if not chunks:
|
| 475 |
+
yield "I couldn't find relevant information to answer your question."
|
| 476 |
+
return
|
| 477 |
+
|
| 478 |
+
# Determine strategy
|
| 479 |
+
prompt_type = self._infer_prompt_type(query = request.query)
|
| 480 |
+
temperature = self._get_adaptive_temperature(request = request,
|
| 481 |
+
query = request.query,
|
| 482 |
+
context = context,
|
| 483 |
+
retrieval_scores = [chunk.score for chunk in chunks],
|
| 484 |
+
prompt_type = prompt_type,
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# Build prompt
|
| 488 |
+
prompt = self.prompt_builder.build_prompt(query = request.query,
|
| 489 |
+
context = context,
|
| 490 |
+
sources = chunks,
|
| 491 |
+
prompt_type = prompt_type,
|
| 492 |
+
include_citations = request.include_sources,
|
| 493 |
+
max_completion_tokens = request.max_tokens or self.settings.MAX_TOKENS,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
# Stream LLM response
|
| 497 |
+
messages = [{"role" : "system",
|
| 498 |
+
"content" : prompt["system"],
|
| 499 |
+
},
|
| 500 |
+
{"role" : "user",
|
| 501 |
+
"content" : prompt["user"],
|
| 502 |
+
},
|
| 503 |
+
]
|
| 504 |
+
|
| 505 |
+
async for chunk_text in self.llm_client.generate_stream(messages = messages,
|
| 506 |
+
temperature = temperature,
|
| 507 |
+
top_p = request.top_p or self.settings.TOP_P,
|
| 508 |
+
max_tokens = request.max_tokens or self.settings.MAX_TOKENS,
|
| 509 |
+
):
|
| 510 |
+
yield chunk_text
|
| 511 |
+
|
| 512 |
+
self.logger.info("Streaming response completed")
|
| 513 |
+
|
| 514 |
+
except Exception as e:
|
| 515 |
+
self.logger.error(f"Streaming generation failed: {repr(e)}", exc_info = True)
|
| 516 |
+
|
| 517 |
+
yield f"\n\n[Error: {str(e)}]"
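A sketch of consuming the streaming generator above; only the async-iteration pattern comes from the method itself, while the wrapper function and the sample query are assumptions:

# Hypothetical consumer: collect streamed tokens into one string
import asyncio
from config.models import QueryRequest
from generation.response_generator import get_response_generator

async def stream_demo():
    request = QueryRequest(query = "Summarize the uploaded report")  # remaining fields left at their defaults
    parts = []
    async for token in get_response_generator().generate_response_stream(request, has_documents = True):
        parts.append(token)
    return "".join(parts)

print(asyncio.run(stream_demo()))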
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def _infer_prompt_type(self, query: str) -> PromptType:
|
| 521 |
+
"""
|
| 522 |
+
Infer appropriate prompt type from query
|
| 523 |
+
|
| 524 |
+
Arguments:
|
| 525 |
+
----------
|
| 526 |
+
query { str } : User query
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
--------
|
| 530 |
+
{ PromptType } : Inferred prompt type
|
| 531 |
+
"""
|
| 532 |
+
query_lower = query.lower()
|
| 533 |
+
|
| 534 |
+
# Summary indicators
|
| 535 |
+
if (any(word in query_lower for word in ['summarize', 'summary', 'overview', 'tldr', 'brief'])):
|
| 536 |
+
return PromptType.SUMMARY
|
| 537 |
+
|
| 538 |
+
# Comparison indicators
|
| 539 |
+
if (any(word in query_lower for word in ['compare', 'contrast', 'difference', 'versus', 'vs'])):
|
| 540 |
+
return PromptType.COMPARISON
|
| 541 |
+
|
| 542 |
+
# Analytical indicators
|
| 543 |
+
if (any(word in query_lower for word in ['analyze', 'analysis', 'evaluate', 'assess', 'examine'])):
|
| 544 |
+
return PromptType.ANALYTICAL
|
| 545 |
+
|
| 546 |
+
# Extraction indicators
|
| 547 |
+
if (any(word in query_lower for word in ['extract', 'list', 'find all', 'identify', 'enumerate'])):
|
| 548 |
+
return PromptType.EXTRACTION
|
| 549 |
+
|
| 550 |
+
# Creative indicators
|
| 551 |
+
if (any(word in query_lower for word in ['create', 'write', 'compose', 'generate', 'imagine'])):
|
| 552 |
+
return PromptType.CREATIVE
|
| 553 |
+
|
| 554 |
+
# Conversational indicators
|
| 555 |
+
if (any(word in query_lower for word in ['tell me about', 'explain', 'discuss', 'talk about'])):
|
| 556 |
+
return PromptType.CONVERSATIONAL
|
| 557 |
+
|
| 558 |
+
# Default to QA
|
| 559 |
+
return PromptType.QA
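Hand-worked outcomes of the keyword rules above for a few invented queries:

# Hypothetical spot-checks of _infer_prompt_type
# "Summarize chapter 3"                      -> PromptType.SUMMARY
# "Compare model A versus model B"           -> PromptType.COMPARISON
# "Analyze the revenue trend"                -> PromptType.ANALYTICAL
# "List all vendors mentioned"               -> PromptType.EXTRACTION
# "Write a short abstract"                   -> PromptType.CREATIVE
# "Tell me about the warranty policy"        -> PromptType.CONVERSATIONAL
# "Who signed the contract?"                 -> PromptType.QA (default)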
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def _get_adaptive_temperature(self, request: QueryRequest, query: str, context: str, retrieval_scores: List[float], prompt_type: PromptType) -> float:
|
| 563 |
+
"""
|
| 564 |
+
Get adaptive temperature based on query characteristics
|
| 565 |
+
|
| 566 |
+
Arguments:
|
| 567 |
+
----------
|
| 568 |
+
request { QueryRequest } : Original request
|
| 569 |
+
|
| 570 |
+
query { str } : User query
|
| 571 |
+
|
| 572 |
+
context { str } : Retrieved context
|
| 573 |
+
|
| 574 |
+
retrieval_scores { list } : Retrieval scores
|
| 575 |
+
|
| 576 |
+
prompt_type { PromptType } : Inferred prompt type
|
| 577 |
+
|
| 578 |
+
Returns:
|
| 579 |
+
--------
|
| 580 |
+
{ float } : Temperature value
|
| 581 |
+
"""
|
| 582 |
+
# Use request temperature if explicitly provided
|
| 583 |
+
if (request.temperature is not None):
|
| 584 |
+
self.logger.debug(f"Using request temperature: {request.temperature}")
|
| 585 |
+
|
| 586 |
+
return request.temperature
|
| 587 |
+
|
| 588 |
+
# Otherwise, use adaptive temperature controller
|
| 589 |
+
temperature = self.temperature_controller.get_temperature(query = query,
|
| 590 |
+
context = context,
|
| 591 |
+
retrieval_scores = retrieval_scores,
|
| 592 |
+
query_type = prompt_type.value,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
return temperature
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def _post_process_citations(self, answer: str, sources: List[ChunkWithScore]) -> str:
|
| 599 |
+
"""
|
| 600 |
+
Post-process answer to format citations
|
| 601 |
+
|
| 602 |
+
Arguments:
|
| 603 |
+
----------
|
| 604 |
+
answer { str } : Generated answer with citation markers
|
| 605 |
+
|
| 606 |
+
sources { list } : Source chunks
|
| 607 |
+
|
| 608 |
+
Returns:
|
| 609 |
+
--------
|
| 610 |
+
{ str } : Answer with formatted citations
|
| 611 |
+
"""
|
| 612 |
+
try:
|
| 613 |
+
# Validate citations
|
| 614 |
+
is_valid, invalid = self.citation_formatter.validate_citations(answer, sources)
|
| 615 |
+
|
| 616 |
+
if not is_valid:
|
| 617 |
+
self.logger.warning(f"Invalid citations found: {invalid}")
|
| 618 |
+
# Normalize to fix issues
|
| 619 |
+
answer = self.citation_formatter.normalize_citations(answer, sources)
|
| 620 |
+
|
| 621 |
+
# Format citations according to style
|
| 622 |
+
formatted_answer = self.citation_formatter.format_citations_in_text(answer, sources)
|
| 623 |
+
|
| 624 |
+
return formatted_answer
|
| 625 |
+
|
| 626 |
+
except Exception as e:
|
| 627 |
+
self.logger.error(f"Citation post-processing failed: {repr(e)}")
|
| 628 |
+
# Return original answer if formatting fails
|
| 629 |
+
return answer
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def _create_no_results_response(self, request: QueryRequest, retrieval_time_ms: float) -> QueryResponse:
|
| 633 |
+
"""
|
| 634 |
+
Create response when no results are found
|
| 635 |
+
|
| 636 |
+
Arguments:
|
| 637 |
+
----------
|
| 638 |
+
request { QueryRequest } : Original request
|
| 639 |
+
|
| 640 |
+
retrieval_time_ms { float } : Time spent on retrieval
|
| 641 |
+
|
| 642 |
+
Returns:
|
| 643 |
+
--------
|
| 644 |
+
{ QueryResponse } : Response indicating no results
|
| 645 |
+
"""
|
| 646 |
+
no_results_answer = ("I couldn't find relevant information in the available documents to answer your question. "
|
| 647 |
+
"This could mean:\n"
|
| 648 |
+
"1. The information is not present in the indexed documents\n"
|
| 649 |
+
"2. The question may need to be rephrased for better matching\n"
|
| 650 |
+
"3. The relevant documents haven't been uploaded yet\n\n"
|
| 651 |
+
"Please try:\n"
|
| 652 |
+
"- Rephrasing your question with different keywords\n"
|
| 653 |
+
"- Asking a more specific or general question\n"
|
| 654 |
+
"- Ensuring the relevant documents are uploaded\n"
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
return QueryResponse(query = request.query,
|
| 658 |
+
answer = no_results_answer,
|
| 659 |
+
sources = [],
|
| 660 |
+
retrieval_time_ms = retrieval_time_ms,
|
| 661 |
+
generation_time_ms = 0.0,
|
| 662 |
+
total_time_ms = retrieval_time_ms,
|
| 663 |
+
tokens_used = {"input": 0, "output": 0, "total": 0},
|
| 664 |
+
model_used = self.model_name,
|
| 665 |
+
timestamp = datetime.now(),
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def _calculate_quality_metrics(self, query: str, answer: str, context: str, sources: List[ChunkWithScore]) -> Dict[str, float]:
|
| 670 |
+
"""
|
| 671 |
+
Calculate quality metrics for the response
|
| 672 |
+
|
| 673 |
+
Arguments:
|
| 674 |
+
----------
|
| 675 |
+
query { str } : User query
|
| 676 |
+
|
| 677 |
+
answer { str } : Generated answer
|
| 678 |
+
|
| 679 |
+
context { str } : Retrieved context
|
| 680 |
+
|
| 681 |
+
sources { list } : Source chunks
|
| 682 |
+
|
| 683 |
+
Returns:
|
| 684 |
+
--------
|
| 685 |
+
{ dict } : Quality metrics
|
| 686 |
+
"""
|
| 687 |
+
metrics = dict()
|
| 688 |
+
|
| 689 |
+
try:
|
| 690 |
+
# Answer length metrics
|
| 691 |
+
metrics["answer_length"] = len(answer.split())
|
| 692 |
+
metrics["answer_char_length"] = len(answer)
|
| 693 |
+
|
| 694 |
+
# Citation metrics
|
| 695 |
+
citation_stats = self.citation_formatter.get_citation_statistics(answer, sources)
|
| 696 |
+
metrics["citations_used"] = citation_stats.get("total_citations", 0)
|
| 697 |
+
metrics["unique_citations"] = citation_stats.get("unique_citations", 0)
|
| 698 |
+
metrics["citation_density"] = citation_stats.get("citation_density", 0.0)
|
| 699 |
+
|
| 700 |
+
# Context utilization
|
| 701 |
+
context_length = len(context.split())
|
| 702 |
+
metrics["context_utilization"] = min(1.0, metrics["answer_length"] / max(1, context_length))
|
| 703 |
+
|
| 704 |
+
# Retrieval quality
|
| 705 |
+
if sources:
|
| 706 |
+
avg_score = sum(s.score for s in sources) / len(sources)
|
| 707 |
+
metrics["avg_retrieval_score"] = avg_score
|
| 708 |
+
metrics["top_retrieval_score"] = sources[0].score if sources else 0.0
|
| 709 |
+
|
| 710 |
+
# Query-answer alignment (simple keyword overlap)
|
| 711 |
+
query_words = set(query.lower().split())
|
| 712 |
+
answer_words = set(answer.lower().split())
|
| 713 |
+
overlap = len(query_words & answer_words)
|
| 714 |
+
metrics["query_answer_overlap"] = overlap / max(1, len(query_words))
|
| 715 |
+
|
| 716 |
+
except Exception as e:
|
| 717 |
+
self.logger.warning(f"Failed to calculate some quality metrics: {repr(e)}")
|
| 718 |
+
|
| 719 |
+
return metrics
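A worked instance of the query-answer overlap metric computed above, repeated standalone so the arithmetic can be checked; the query/answer pair is invented:

# Standalone re-derivation of the overlap metric for one toy pair
query  = "what is the total invoice amount"
answer = "the total invoice amount is 1,250 USD"

query_words  = set(query.lower().split())       # 6 unique words
answer_words = set(answer.lower().split())
overlap      = len(query_words & answer_words)  # {'the', 'total', 'invoice', 'amount', 'is'} -> 5
print(overlap / max(1, len(query_words)))       # 5 / 6 ≈ 0.83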
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
async def generate_batch_responses(self, requests: List[QueryRequest], has_documents: bool = True) -> List[QueryResponse]:
|
| 723 |
+
"""
|
| 724 |
+
Generate responses for multiple queries in batch
|
| 725 |
+
|
| 726 |
+
Arguments:
|
| 727 |
+
----------
|
| 728 |
+
requests { list } : List of query requests
|
| 729 |
+
|
| 730 |
+
has_documents { bool } : Whether documents are available
|
| 731 |
+
|
| 732 |
+
Returns:
|
| 733 |
+
--------
|
| 734 |
+
{ list } : List of query responses
|
| 735 |
+
"""
|
| 736 |
+
self.logger.info(f"Generating batch responses for {len(requests)} queries")
|
| 737 |
+
|
| 738 |
+
tasks = [self.generate_response(request = request,
|
| 739 |
+
has_documents = has_documents) for request in requests]
|
| 740 |
+
|
| 741 |
+
responses = await asyncio.gather(*tasks, return_exceptions = True)
|
| 742 |
+
|
| 743 |
+
# Handle exceptions
|
| 744 |
+
results = list()
|
| 745 |
+
|
| 746 |
+
for i, response in enumerate(responses):
|
| 747 |
+
if isinstance(response, Exception):
|
| 748 |
+
self.logger.error(f"Batch query {i} failed: {repr(response)}")
|
| 749 |
+
# Create error response
|
| 750 |
+
error_response = self._create_error_response(requests[i], str(response))
|
| 751 |
+
results.append(error_response)
|
| 752 |
+
|
| 753 |
+
else:
|
| 754 |
+
results.append(response)
|
| 755 |
+
|
| 756 |
+
self.logger.info(f"Completed batch generation: {len(results)} responses")
|
| 757 |
+
|
| 758 |
+
return results
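A sketch of driving the batch path above; the queries are illustrative and request fields beyond query are left at their defaults:

# Hypothetical batch run: per-query exceptions are turned into error responses by the method itself
import asyncio
from config.models import QueryRequest
from generation.response_generator import get_response_generator

async def batch_demo():
    requests  = [QueryRequest(query = q) for q in ("Who is the supplier?", "Summarize section 2")]
    responses = await get_response_generator().generate_batch_responses(requests, has_documents = True)
    for r in responses:
        print(r.query, "->", r.answer[:80])

asyncio.run(batch_demo())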
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def _create_error_response(self, request: QueryRequest, error_message: str) -> QueryResponse:
|
| 762 |
+
"""
|
| 763 |
+
Create error response for failed generation
|
| 764 |
+
|
| 765 |
+
Arguments:
|
| 766 |
+
----------
|
| 767 |
+
request { QueryRequest } : Original request
|
| 768 |
+
|
| 769 |
+
error_message { str } : Error message
|
| 770 |
+
|
| 771 |
+
Returns:
|
| 772 |
+
--------
|
| 773 |
+
{ QueryResponse } : Error response
|
| 774 |
+
"""
|
| 775 |
+
return QueryResponse(query = request.query,
|
| 776 |
+
answer = f"An error occurred while generating the response: {error_message}",
|
| 777 |
+
sources = [],
|
| 778 |
+
retrieval_time_ms = 0.0,
|
| 779 |
+
generation_time_ms = 0.0,
|
| 780 |
+
total_time_ms = 0.0,
|
| 781 |
+
tokens_used = {"input": 0, "output": 0, "total": 0},
|
| 782 |
+
model_used = self.model_name,
|
| 783 |
+
timestamp = datetime.now(),
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def get_generation_stats(self) -> Dict:
|
| 788 |
+
"""
|
| 789 |
+
Get generation statistics including query type breakdown
|
| 790 |
+
|
| 791 |
+
Returns:
|
| 792 |
+
--------
|
| 793 |
+
{ dict } : Generation statistics
|
| 794 |
+
"""
|
| 795 |
+
avg_time = (self.total_generation_time / self.generation_count) if self.generation_count > 0 else 0
|
| 796 |
+
|
| 797 |
+
return {"total_generations" : self.generation_count,
|
| 798 |
+
"general_queries" : self.general_query_count,
|
| 799 |
+
"rag_queries" : self.rag_query_count,
|
| 800 |
+
"total_generation_time" : self.total_generation_time,
|
| 801 |
+
"avg_generation_time_ms" : avg_time,
|
| 802 |
+
"provider" : self.provider.value,
|
| 803 |
+
"model" : self.model_name,
|
| 804 |
+
"llm_health" : self.llm_client.check_health(),
|
| 805 |
+
"query_routing_enabled" : True,
|
| 806 |
+
"llm_based_routing" : True,
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def reset_stats(self):
|
| 811 |
+
"""
|
| 812 |
+
Reset generation statistics
|
| 813 |
+
"""
|
| 814 |
+
self.generation_count = 0
|
| 815 |
+
self.general_query_count = 0
|
| 816 |
+
self.rag_query_count = 0
|
| 817 |
+
self.total_generation_time = 0.0
|
| 818 |
+
|
| 819 |
+
self.logger.info("Generation statistics reset")
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
# Global response generator instance
|
| 823 |
+
_response_generator = None
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def get_response_generator(provider: LLMProvider = None, model_name: str = None) -> ResponseGenerator:
|
| 827 |
+
"""
|
| 828 |
+
Get global response generator instance (singleton)
|
| 829 |
+
|
| 830 |
+
Arguments:
|
| 831 |
+
----------
|
| 832 |
+
provider { LLMProvider } : LLM provider
|
| 833 |
+
|
| 834 |
+
model_name { str } : Model name
|
| 835 |
+
|
| 836 |
+
Returns:
|
| 837 |
+
--------
|
| 838 |
+
{ ResponseGenerator } : ResponseGenerator instance
|
| 839 |
+
"""
|
| 840 |
+
global _response_generator
|
| 841 |
+
|
| 842 |
+
if _response_generator is None or (provider and _response_generator.provider != provider):
|
| 843 |
+
_response_generator = ResponseGenerator(provider, model_name)
|
| 844 |
+
|
| 845 |
+
return _response_generator
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
@handle_errors(error_type = ResponseGenerationError, log_error = True, reraise = False)
|
| 849 |
+
async def generate_answer(query: str, top_k: int = 5, temperature: float = None, has_documents: bool = True, **kwargs) -> str:
|
| 850 |
+
"""
|
| 851 |
+
Convenience function for quick answer generation
|
| 852 |
+
|
| 853 |
+
Arguments:
|
| 854 |
+
----------
|
| 855 |
+
query { str } : User query
|
| 856 |
+
|
| 857 |
+
top_k { int } : Number of chunks to retrieve
|
| 858 |
+
|
| 859 |
+
temperature { float } : Temperature for generation
|
| 860 |
+
|
| 861 |
+
has_documents { bool } : Whether documents are available
|
| 862 |
+
|
| 863 |
+
**kwargs : Additional parameters
|
| 864 |
+
|
| 865 |
+
Returns:
|
| 866 |
+
--------
|
| 867 |
+
{ str } : Generated answer
|
| 868 |
+
"""
|
| 869 |
+
request = QueryRequest(query = query,
|
| 870 |
+
top_k = top_k,
|
| 871 |
+
temperature = temperature,
|
| 872 |
+
**kwargs
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
generator = get_response_generator()
|
| 876 |
+
response = await generator.generate_response(request = request,
|
| 877 |
+
has_documents = has_documents,
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
return response.answer
|
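A compact end-to-end sketch using the convenience wrapper above; it assumes documents have already been ingested and an LLM provider is configured, and the question itself is invented:

# Hypothetical one-shot question against the running pipeline
import asyncio
from generation.response_generator import generate_answer

answer = asyncio.run(generate_answer(query = "What penalties apply for late delivery?",
                                     top_k = 5,
                                     has_documents = True,
                                     ))
print(answer)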
generation/temperature_controller.py
ADDED
|
@@ -0,0 +1,430 @@
|
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import math
|
| 3 |
+
from typing import Any
|
| 4 |
+
from typing import Dict
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from config.settings import get_settings
|
| 7 |
+
from config.logging_config import get_logger
|
| 8 |
+
from utils.error_handler import handle_errors
|
| 9 |
+
from config.models import TemperatureStrategy
|
| 10 |
+
from utils.error_handler import TemperatureControlError
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Setup Settings and Logging
|
| 14 |
+
settings = get_settings()
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TemperatureController:
|
| 19 |
+
"""
|
| 20 |
+
Intelligent temperature control for LLM generation: Implements adaptive temperature strategies based on query type, complexity, and desired output characteristics
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, base_temperature: float = None, strategy: TemperatureStrategy = None):
|
| 23 |
+
"""
|
| 24 |
+
Initialize temperature controller
|
| 25 |
+
|
| 26 |
+
Arguments:
|
| 27 |
+
----------
|
| 28 |
+
base_temperature { float } : Base temperature value (default from settings)
|
| 29 |
+
|
| 30 |
+
strategy { str } : Temperature control strategy
|
| 31 |
+
"""
|
| 32 |
+
self.logger = logger
|
| 33 |
+
self.settings = get_settings()
|
| 34 |
+
self.base_temperature = base_temperature or self.settings.DEFAULT_TEMPERATURE
|
| 35 |
+
self.strategy = strategy or TemperatureStrategy.ADAPTIVE
|
| 36 |
+
|
| 37 |
+
# Validate base temperature
|
| 38 |
+
if not (0.0 <= self.base_temperature <= 1.0):
|
| 39 |
+
raise TemperatureControlError(f"Temperature must be between 0 and 1: {self.base_temperature}")
|
| 40 |
+
|
| 41 |
+
# Strategy configurations
|
| 42 |
+
self.strategy_configs = {TemperatureStrategy.FIXED : {"description" : "Fixed temperature for all queries", "range" : (0.0, 1.0)},
|
| 43 |
+
TemperatureStrategy.ADAPTIVE : {"description" : "Adapt temperature based on query complexity", "range" : (0.1, 0.8), "complexity_threshold" : 0.6},
|
| 44 |
+
TemperatureStrategy.CONFIDENCE : {"description" : "Adjust temperature based on retrieval confidence", "range" : (0.1, 0.9), "high_confidence_temp" : 0.1, "low_confidence_temp" : 0.7},
|
| 45 |
+
TemperatureStrategy.PROGRESSIVE : {"description" : "Progressively increase temperature for creative tasks", "range" : (0.1, 0.9), "creative_threshold" : 0.7}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
self.logger.info(f"Initialized TemperatureController: base={self.base_temperature}, strategy={self.strategy}")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_temperature(self, query: str = "", context: str = "", retrieval_scores: Optional[list] = None, query_type: str = "qa") -> float:
|
| 52 |
+
"""
|
| 53 |
+
Get appropriate temperature for generation
|
| 54 |
+
|
| 55 |
+
Arguments:
|
| 56 |
+
----------
|
| 57 |
+
query { str } : User query
|
| 58 |
+
|
| 59 |
+
context { str } : Retrieved context
|
| 60 |
+
|
| 61 |
+
retrieval_scores { list } : Scores of retrieved chunks
|
| 62 |
+
|
| 63 |
+
query_type { str } : Type of query ('qa', 'creative', 'analytical', 'summary')
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
--------
|
| 67 |
+
{ float } : Temperature value (0.0 - 1.0)
|
| 68 |
+
"""
|
| 69 |
+
if (self.strategy == TemperatureStrategy.FIXED):
|
| 70 |
+
return self._fixed_temperature()
|
| 71 |
+
|
| 72 |
+
elif (self.strategy == TemperatureStrategy.ADAPTIVE):
|
| 73 |
+
return self._adaptive_temperature(query = query,
|
| 74 |
+
context = context,
|
| 75 |
+
query_type = query_type,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
elif (self.strategy == TemperatureStrategy.CONFIDENCE):
|
| 79 |
+
return self._confidence_based_temperature(retrieval_scores = retrieval_scores,
|
| 80 |
+
query_type = query_type,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
elif (self.strategy == TemperatureStrategy.PROGRESSIVE):
|
| 84 |
+
return self._progressive_temperature(query_type = query_type,
|
| 85 |
+
query = query,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
else:
|
| 89 |
+
self.logger.warning(f"Unknown strategy: {self.strategy}, using fixed")
|
| 90 |
+
return self.base_temperature
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _fixed_temperature(self) -> float:
|
| 94 |
+
"""
|
| 95 |
+
Fixed temperature strategy
|
| 96 |
+
"""
|
| 97 |
+
return self.base_temperature
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _adaptive_temperature(self, query: str, context: str, query_type: str) -> float:
|
| 101 |
+
"""
|
| 102 |
+
Adaptive temperature based on query complexity and type
|
| 103 |
+
"""
|
| 104 |
+
base_temp = self.base_temperature
|
| 105 |
+
|
| 106 |
+
# Adjust based on query type
|
| 107 |
+
type_adjustments = {"qa" : -0.2, # More deterministic for Q&A
|
| 108 |
+
"creative" : 0.3, # More creative for creative tasks
|
| 109 |
+
"analytical" : -0.1, # Slightly deterministic for analysis
|
| 110 |
+
"summary" : -0.15, # Deterministic for summarization
|
| 111 |
+
"comparison" : 0.1, # Slightly creative for comparisons
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
adjustment = type_adjustments.get(query_type, 0.0)
|
| 115 |
+
temp = base_temp + adjustment
|
| 116 |
+
|
| 117 |
+
# Adjust based on query complexity
|
| 118 |
+
complexity = self._calculate_query_complexity(query = query)
|
| 119 |
+
|
| 120 |
+
if (complexity > 0.7):
|
| 121 |
+
# High complexity
|
| 122 |
+
temp += 0.1
|
| 123 |
+
|
| 124 |
+
elif (complexity < 0.3):
|
| 125 |
+
# Low complexity
|
| 126 |
+
temp -= 0.1
|
| 127 |
+
|
| 128 |
+
# Adjust based on context quality
|
| 129 |
+
if context:
|
| 130 |
+
context_quality = self._calculate_context_quality(context = context)
|
| 131 |
+
|
| 132 |
+
# Poor context
|
| 133 |
+
if (context_quality < 0.5):
|
| 134 |
+
# More creative when context is poor
|
| 135 |
+
temp += 0.15
|
| 136 |
+
|
| 137 |
+
return self._clamp_temperature(temperature = temp)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _confidence_based_temperature(self, retrieval_scores: Optional[list], query_type: str) -> float:
|
| 141 |
+
"""
|
| 142 |
+
Temperature based on retrieval confidence
|
| 143 |
+
"""
|
| 144 |
+
if not retrieval_scores:
|
| 145 |
+
self.logger.debug("No retrieval scores, using base temperature")
|
| 146 |
+
return self.base_temperature
|
| 147 |
+
|
| 148 |
+
# Calculate average confidence
|
| 149 |
+
avg_confidence = sum(retrieval_scores) / len(retrieval_scores)
|
| 150 |
+
|
| 151 |
+
config = self.strategy_configs[TemperatureStrategy.CONFIDENCE]
|
| 152 |
+
high_temp = config["high_confidence_temp"]
|
| 153 |
+
low_temp = config["low_confidence_temp"]
|
| 154 |
+
|
| 155 |
+
# High confidence -> low temperature (deterministic) & Low confidence -> high temperature (creative)
|
| 156 |
+
if (avg_confidence > 0.8):
|
| 157 |
+
temperature = high_temp
|
| 158 |
+
|
| 159 |
+
elif (avg_confidence < 0.3):
|
| 160 |
+
temperature = low_temp
|
| 161 |
+
|
| 162 |
+
else:
|
| 163 |
+
# Linear interpolation between high and low temps
|
| 164 |
+
normalized_confidence = (avg_confidence - 0.3) / (0.8 - 0.3)
|
| 165 |
+
temperature = high_temp + (low_temp - high_temp) * (1 - normalized_confidence)
|
| 166 |
+
|
| 167 |
+
# Adjust for query type
|
| 168 |
+
if (query_type == "creative"):
|
| 169 |
+
temperature = min(0.9, temperature + 0.2)
|
| 170 |
+
|
| 171 |
+
elif (query_type == "qa"):
|
| 172 |
+
temperature = max(0.1, temperature - 0.1)
|
| 173 |
+
|
| 174 |
+
return self._clamp_temperature(temperature = temperature)
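A worked instance of the interpolation above, repeated standalone so the numbers can be checked by hand; the 0.1 / 0.7 endpoints mirror the confidence-strategy defaults configured earlier in this file:

# Standalone check: average confidence 0.55 sits halfway between the 0.3 and 0.8 thresholds
high_temp, low_temp = 0.1, 0.7
avg_confidence      = 0.55

normalized_confidence = (avg_confidence - 0.3) / (0.8 - 0.3)    # 0.5
temperature           = high_temp + (low_temp - high_temp) * (1 - normalized_confidence)
print(round(temperature, 2))                                    # 0.4, midway between 0.1 and 0.7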
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _progressive_temperature(self, query_type: str, query: str) -> float:
|
| 178 |
+
"""
|
| 179 |
+
Progressive temperature based on task requirements
|
| 180 |
+
"""
|
| 181 |
+
base_temp = self.base_temperature
|
| 182 |
+
|
| 183 |
+
# Task-based progression
|
| 184 |
+
if (query_type == "creative"):
|
| 185 |
+
# High creativity
|
| 186 |
+
return self._clamp_temperature(temperature = 0.8)
|
| 187 |
+
|
| 188 |
+
elif (query_type == "analytical"):
|
| 189 |
+
# Balanced
|
| 190 |
+
return self._clamp_temperature(temperature = 0.3)
|
| 191 |
+
|
| 192 |
+
elif (query_type == "qa"):
|
| 193 |
+
# For factual Q&A, use lower temperature
|
| 194 |
+
if self._is_factual_query(query):
|
| 195 |
+
return self._clamp_temperature(temperature = 0.1)
|
| 196 |
+
|
| 197 |
+
else:
|
| 198 |
+
return self._clamp_temperature(temperature = 0.4)
|
| 199 |
+
|
| 200 |
+
elif (query_type == "summary"):
|
| 201 |
+
# Deterministic summaries
|
| 202 |
+
return self._clamp_temperature(temperature = 0.2)
|
| 203 |
+
|
| 204 |
+
else:
|
| 205 |
+
return self._clamp_temperature(temperature = base_temp)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _calculate_query_complexity(self, query: str) -> float:
|
| 209 |
+
"""
|
| 210 |
+
Simple, predictable complexity score
|
| 211 |
+
"""
|
| 212 |
+
if not query:
|
| 213 |
+
return 0.5
|
| 214 |
+
|
| 215 |
+
# Count words and questions
|
| 216 |
+
words = len(query.split())
|
| 217 |
+
has_why_how = any(word in query.lower() for word in ['why', 'how', 'explain'])
|
| 218 |
+
has_compare = any(word in query.lower() for word in ['compare', 'contrast', 'difference'])
|
| 219 |
+
|
| 220 |
+
# Simple rules
|
| 221 |
+
if has_compare:
|
| 222 |
+
# Complex
|
| 223 |
+
return 0.8
|
| 224 |
+
|
| 225 |
+
elif (has_why_how and (words > 15)):
|
| 226 |
+
return 0.7
|
| 227 |
+
|
| 228 |
+
elif words > 20:
|
| 229 |
+
return 0.6
|
| 230 |
+
|
| 231 |
+
else:
|
| 232 |
+
# Simple
|
| 233 |
+
return 0.3
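Hand-worked outcomes of the complexity rules above for a few invented queries:

# Hypothetical spot-checks of _calculate_query_complexity
# "Compare the two proposals"                                      -> 0.8 (comparison keyword)
# "How does the ingestion pipeline decide which chunking strategy
#  to apply when a single upload mixes PDFs and spreadsheets?"     -> 0.7 (how-style question, > 15 words)
# "What is the refund policy?"                                     -> 0.3 (short, simple)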
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _calculate_context_quality(self, context: str) -> float:
|
| 237 |
+
"""
|
| 238 |
+
Calculate context quality (0.0 - 1.0)
|
| 239 |
+
"""
|
| 240 |
+
if not context:
|
| 241 |
+
return 0.0
|
| 242 |
+
|
| 243 |
+
factors = list()
|
| 244 |
+
|
| 245 |
+
# Length factor (adequate context)
|
| 246 |
+
words = len(context.split())
|
| 247 |
+
|
| 248 |
+
# Normalize
|
| 249 |
+
length_factor = min(words / 500, 1.0)
|
| 250 |
+
|
| 251 |
+
factors.append(length_factor)
|
| 252 |
+
|
| 253 |
+
# Diversity factor (multiple sources/citations)
|
| 254 |
+
citation_count = context.count('[')
|
| 255 |
+
diversity_factor = min(citation_count / 5, 1.0)
|
| 256 |
+
|
| 257 |
+
factors.append(diversity_factor)
|
| 258 |
+
|
| 259 |
+
# Coherence factor (simple measure)
|
| 260 |
+
sentence_count = context.count('.')
|
| 261 |
+
|
| 262 |
+
if (sentence_count > 0):
|
| 263 |
+
avg_sentence_length = words / sentence_count
|
| 264 |
+
# Ideal ~20 words/sentence
|
| 265 |
+
coherence_factor = 1.0 - min(abs(avg_sentence_length - 20) / 50, 1.0)
|
| 266 |
+
|
| 267 |
+
factors.append(coherence_factor)
|
| 268 |
+
|
| 269 |
+
return sum(factors) / len(factors)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _is_factual_query(self, query: str) -> bool:
|
| 273 |
+
"""
|
| 274 |
+
Check if query is factual (requires precise answers)
|
| 275 |
+
"""
|
| 276 |
+
factual_indicators = ['what is', 'who is', 'when did', 'where is', 'how many', 'how much', 'definition of', 'meaning of', 'calculate', 'number of']
|
| 277 |
+
|
| 278 |
+
query_lower = query.lower()
|
| 279 |
+
|
| 280 |
+
return any(indicator in query_lower for indicator in factual_indicators)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _clamp_temperature(self, temperature: float) -> float:
|
| 284 |
+
"""
|
| 285 |
+
Clamp temperature to valid range
|
| 286 |
+
"""
|
| 287 |
+
strategy_config = self.strategy_configs.get(self.strategy, {})
|
| 288 |
+
temp_range = strategy_config.get("range", (0.0, 1.0))
|
| 289 |
+
|
| 290 |
+
clamped = max(temp_range[0], min(temperature, temp_range[1]))
|
| 291 |
+
|
| 292 |
+
# Round to 2 decimal places
|
| 293 |
+
clamped = round(clamped, 2)
|
| 294 |
+
|
| 295 |
+
return clamped
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def get_temperature_parameters(self, temperature: float) -> Dict[str, Any]:
|
| 299 |
+
"""
|
| 300 |
+
Get additional parameters based on temperature
|
| 301 |
+
|
| 302 |
+
Arguments:
|
| 303 |
+
----------
|
| 304 |
+
temperature { float } : Temperature value
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
--------
|
| 308 |
+
{ dict } : Additional generation parameters
|
| 309 |
+
"""
|
| 310 |
+
params = {"temperature" : temperature,
|
| 311 |
+
"top_p" : 0.9,
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
# Adjust top_p based on temperature
|
| 315 |
+
if (temperature < 0.3):
|
| 316 |
+
# Broader distribution for low temp
|
| 317 |
+
params["top_p"] = 0.95
|
| 318 |
+
|
| 319 |
+
elif (temperature > 0.7):
|
| 320 |
+
# Narrower distribution for high temp
|
| 321 |
+
params["top_p"] = 0.7
|
| 322 |
+
|
| 323 |
+
# Adjust presence_penalty based on temperature
|
| 324 |
+
if (temperature > 0.5):
|
| 325 |
+
# Encourage novelty for creative tasks
|
| 326 |
+
params["presence_penalty"] = 0.1
|
| 327 |
+
|
| 328 |
+
else:
|
| 329 |
+
params["presence_penalty"] = 0.0
|
| 330 |
+
|
| 331 |
+
return params
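Worked outputs of the mapping above for three representative temperatures, traced by hand through the branches:

# Hand-checked results of get_temperature_parameters
# temperature = 0.1 -> {"temperature": 0.1, "top_p": 0.95, "presence_penalty": 0.0}
# temperature = 0.4 -> {"temperature": 0.4, "top_p": 0.9,  "presence_penalty": 0.0}
# temperature = 0.8 -> {"temperature": 0.8, "top_p": 0.7,  "presence_penalty": 0.1}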
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def explain_temperature_choice(self, query: str, context: str, retrieval_scores: list, query_type: str, final_temperature: float) -> Dict[str, Any]:
|
| 335 |
+
"""
|
| 336 |
+
Explain why a particular temperature was chosen
|
| 337 |
+
|
| 338 |
+
Arguments:
|
| 339 |
+
----------
|
| 340 |
+
query { str } : User query
|
| 341 |
+
|
| 342 |
+
context { str } : Retrieved context
|
| 343 |
+
|
| 344 |
+
retrieval_scores { list } : Retrieval scores
|
| 345 |
+
|
| 346 |
+
query_type { str } : Query type
|
| 347 |
+
|
| 348 |
+
final_temperature { float } : Chosen temperature
|
| 349 |
+
|
| 350 |
+
Returns:
|
| 351 |
+
--------
|
| 352 |
+
{ dict } : Explanation dictionary
|
| 353 |
+
"""
|
| 354 |
+
explanation = {"strategy" : self.strategy.value,
|
| 355 |
+
"final_temperature" : final_temperature,
|
| 356 |
+
"base_temperature" : self.base_temperature,
|
| 357 |
+
"factors" : {},
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
if (self.strategy == TemperatureStrategy.ADAPTIVE):
|
| 361 |
+
complexity = self._calculate_query_complexity(query = query)
|
| 362 |
+
context_quality = self._calculate_context_quality(context = context)
|
| 363 |
+
|
| 364 |
+
explanation["factors"] = {"query_complexity" : round(complexity, 3),
|
| 365 |
+
"context_quality" : round(context_quality, 3),
|
| 366 |
+
"query_type" : query_type,
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
elif (self.strategy == TemperatureStrategy.CONFIDENCE):
|
| 370 |
+
if retrieval_scores:
|
| 371 |
+
avg_confidence = sum(retrieval_scores) / len(retrieval_scores)
|
| 372 |
+
explanation["factors"] = {"average_retrieval_confidence" : round(avg_confidence, 3),
|
| 373 |
+
"query_type" : query_type,
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
elif (self.strategy == TemperatureStrategy.PROGRESSIVE):
|
| 377 |
+
is_factual = self._is_factual_query(query)
|
| 378 |
+
explanation["factors"] = {"query_type" : query_type,
|
| 379 |
+
"is_factual_query" : is_factual,
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
return explanation
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# Global temperature controller instance
|
| 386 |
+
_temperature_controller = None
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def get_temperature_controller() -> TemperatureController:
|
| 390 |
+
"""
|
| 391 |
+
Get global temperature controller instance (singleton)
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
--------
|
| 395 |
+
{ TemperatureController } : TemperatureController instance
|
| 396 |
+
"""
|
| 397 |
+
global _temperature_controller
|
| 398 |
+
|
| 399 |
+
if _temperature_controller is None:
|
| 400 |
+
_temperature_controller = TemperatureController()
|
| 401 |
+
|
| 402 |
+
return _temperature_controller
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
@handle_errors(error_type=TemperatureControlError, log_error=True, reraise=False)
|
| 406 |
+
def get_adaptive_temperature(query: str = "", context: str = "", retrieval_scores: list = None, query_type: str = "qa") -> float:
|
| 407 |
+
"""
|
| 408 |
+
Convenience function for getting adaptive temperature
|
| 409 |
+
|
| 410 |
+
Arguments:
|
| 411 |
+
----------
|
| 412 |
+
query { str } : User query
|
| 413 |
+
|
| 414 |
+
context { str } : Retrieved context
|
| 415 |
+
|
| 416 |
+
retrieval_scores { list } : Retrieval scores
|
| 417 |
+
|
| 418 |
+
query_type { str } : Query type
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
--------
|
| 422 |
+
{ float } : Temperature value
|
| 423 |
+
"""
|
| 424 |
+
controller = get_temperature_controller()
|
| 425 |
+
|
| 426 |
+
return controller.get_temperature(query = query,
|
| 427 |
+
context = context,
|
| 428 |
+
retrieval_scores = retrieval_scores,
|
| 429 |
+
query_type = query_type,
|
| 430 |
+
)
|
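For orientation, a minimal usage sketch of the controller committed above (not part of the committed files; the query, context, and retrieval scores are made-up values):

# Illustrative only: pick a temperature for a factual QA query, then inspect why it was chosen
from generation.temperature_controller import get_adaptive_temperature, get_temperature_controller

temp = get_adaptive_temperature(query="What is the refund policy?",
                                context="Refunds are issued within 30 days of purchase.",
                                retrieval_scores=[0.82, 0.74],
                                query_type="qa")
controller = get_temperature_controller()
print(temp)
print(controller.explain_temperature_choice(query="What is the refund policy?",
                                            context="Refunds are issued within 30 days of purchase.",
                                            retrieval_scores=[0.82, 0.74],
                                            query_type="qa",
                                            final_temperature=temp))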
generation/token_manager.py
ADDED
@@ -0,0 +1,431 @@
# DEPENDENCIES
import tiktoken
from typing import List
from typing import Dict
from typing import Optional
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import TokenManagementError


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class TokenManager:
    """
    Token management for LLM context windows: Handles token counting, context window management, and optimization for different LLM providers (Ollama, OpenAI)
    """
    def __init__(self, model_name: str = None):
        """
        Initialize token manager

        Arguments:
        ----------
        model_name { str } : Model name for tokenizer selection
        """
        self.logger = logger
        self.settings = get_settings()
        self.model_name = model_name or self.settings.OLLAMA_MODEL
        self.encoding = None
        self.context_window = self._get_context_window()

        self._initialize_tokenizer()

    def _initialize_tokenizer(self):
        """
        Initialize appropriate tokenizer based on model
        """
        try:
            # Determine tokenizer based on model
            if self.model_name.startswith(('gpt-3.5', 'gpt-4')):
                # OpenAI models
                self.encoding = tiktoken.encoding_for_model(self.model_name)
                self.logger.debug(f"Initialized tiktoken for {self.model_name}")
            else:
                # Default for Ollama/local models
                self.encoding = tiktoken.get_encoding("cl100k_base")
                self.logger.debug(f"Using cl100k_base tokenizer for local model {self.model_name}")
        except Exception as e:
            self.logger.warning(f"Failed to initialize specific tokenizer: {repr(e)}, using approximation")
            self.encoding = None

    def _get_context_window(self) -> int:
        """
        Get context window size based on model

        Returns:
        --------
        { int } : Context window size in tokens
        """
        model_contexts = {"gpt-3.5-turbo": 4096,
                          "mistral:7b": 8192,
                          }

        # Find matching model
        for model_pattern, context_size in model_contexts.items():
            if model_pattern in self.model_name.lower():
                return context_size

        # Default context window
        default_context = self.settings.CONTEXT_WINDOW
        self.logger.info(f"Using default context window {default_context} for model {self.model_name}")

        return default_context

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { int } : Number of tokens
        """
        if not text:
            return 0

        if self.encoding is not None:
            try:
                return len(self.encoding.encode(text))
            except Exception as e:
                self.logger.warning(f"Tokenizer failed, using approximation: {repr(e)}")

        # Fallback approximation
        return self._approximate_token_count(text=text)

    def _approximate_token_count(self, text: str) -> int:
        """
        Approximate token count when tokenizer is unavailable

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { int } : Approximate token count
        """
        if not text:
            return 0

        # Use word-based approximation (more reliable than char-based)
        words = text.split()

        # English text averages ~1.3 tokens per word (accounting for punctuation and subword tokenization)
        estimated_tokens = int(len(words) * 1.3)

        # Add 5% buffer for punctuation and special tokens
        estimated_tokens = int(estimated_tokens * 1.05)

        return estimated_tokens

    def count_message_tokens(self, messages: List[Dict]) -> int:
        """
        Count tokens in chat messages

        Arguments:
        ----------
        messages { list } : List of message dictionaries

        Returns:
        --------
        { int } : Total tokens in messages
        """
        if not messages:
            return 0

        total_tokens = 0

        for message in messages:
            # Count content tokens
            content = message.get('content', '')
            total_tokens += self.count_tokens(text=content)

            # Count role tokens (approximate)
            role = message.get('role', '')
            total_tokens += self.count_tokens(text=role)

            # Add overhead for message structure: approximate overhead per message
            total_tokens += 5

        return total_tokens

    def fits_in_context(self, prompt: str, max_completion_tokens: int = 1000) -> bool:
        """
        Check if prompt fits in context window with room for completion

        Arguments:
        ----------
        prompt { str } : Prompt text
        max_completion_tokens { int } : Tokens to reserve for completion

        Returns:
        --------
        { bool } : True if prompt fits
        """
        prompt_tokens = self.count_tokens(text=prompt)
        total_required = prompt_tokens + max_completion_tokens

        return (total_required <= self.context_window)

    def truncate_to_fit(self, text: str, max_tokens: int, strategy: str = "end") -> str:
        """
        Truncate text to fit within token limit

        Arguments:
        ----------
        text { str } : Text to truncate
        max_tokens { int } : Maximum tokens allowed
        strategy { str } : Truncation strategy ('end', 'start', 'middle')

        Returns:
        --------
        { str } : Truncated text
        """
        current_tokens = self.count_tokens(text=text)

        if (current_tokens <= max_tokens):
            return text

        if (strategy == "end"):
            return self._truncate_from_end(text=text, max_tokens=max_tokens)
        elif (strategy == "start"):
            return self._truncate_from_start(text=text, max_tokens=max_tokens)
        elif (strategy == "middle"):
            return self._truncate_from_middle(text=text, max_tokens=max_tokens)
        else:
            self.logger.warning(f"Unknown truncation strategy: {strategy}, using 'end'")
            return self._truncate_from_end(text=text, max_tokens=max_tokens)

    def _truncate_from_end(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the end of the text
        """
        if self.encoding is not None:
            tokens = self.encoding.encode(text)
            truncated_tokens = tokens[:max_tokens]
            return self.encoding.decode(truncated_tokens)

        # Approximate truncation (conservative estimate)
        words = text.split()
        target_words = int(max_tokens * 0.75)
        truncated_words = words[:target_words]

        return " ".join(truncated_words)

    def _truncate_from_start(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the start of the text
        """
        if self.encoding is not None:
            tokens = self.encoding.encode(text)
            # Take from end
            truncated_tokens = tokens[-max_tokens:]
            return self.encoding.decode(truncated_tokens)

        # Approximate truncation: take from end
        words = text.split()
        target_words = int(max_tokens * 0.75)
        truncated_words = words[-target_words:]

        return " ".join(truncated_words)

    def _truncate_from_middle(self, text: str, max_tokens: int) -> str:
        """
        Truncate from the middle of the text
        """
        if self.encoding is not None:
            tokens = self.encoding.encode(text)

            if (len(tokens) <= max_tokens):
                return text

            # Keep beginning and end, remove middle
            keep_start = max_tokens // 3
            keep_end = max_tokens - keep_start

            start_tokens = tokens[:keep_start]
            end_tokens = tokens[-keep_end:]

            return self.encoding.decode(start_tokens) + " [...] " + self.encoding.decode(end_tokens)

        # Approximate truncation
        words = text.split()

        if (len(words) <= max_tokens):
            return text

        keep_start = max_tokens // 3
        keep_end = max_tokens - keep_start

        start_words = words[:keep_start]
        end_words = words[-keep_end:]

        return " ".join(start_words) + " [...] " + " ".join(end_words)

    def calculate_max_completion_tokens(self, prompt: str, reserve_tokens: int = 100) -> int:
        """
        Calculate maximum completion tokens given prompt length

        Arguments:
        ----------
        prompt { str } : Prompt text
        reserve_tokens { int } : Tokens to reserve for safety

        Returns:
        --------
        { int } : Maximum completion tokens
        """
        prompt_tokens = self.count_tokens(text=prompt)
        available_tokens = self.context_window - prompt_tokens - reserve_tokens

        return max(0, available_tokens)

    def optimize_context_usage(self, context: str, prompt: str, max_completion_tokens: int = 1000) -> str:
        """
        Optimize context to fit within context window

        Arguments:
        ----------
        context { str } : Context text
        prompt { str } : Prompt template
        max_completion_tokens { int } : Tokens needed for completion

        Returns:
        --------
        { str } : Optimized context
        """
        total_prompt_tokens = self.count_tokens(text=prompt.format(context=""))
        available_for_context = self.context_window - total_prompt_tokens - max_completion_tokens

        if (available_for_context <= 0):
            self.logger.warning("Prompt too large for context window")
            return ""

        context_tokens = self.count_tokens(text=context)

        if (context_tokens <= available_for_context):
            return context

        # Truncate context to fit
        optimized_context = self.truncate_to_fit(context, available_for_context, strategy="end")

        reduction_pct = ((context_tokens - self.count_tokens(text=optimized_context)) / context_tokens) * 100

        self.logger.info(f"Context reduced by {reduction_pct:.1f}% to fit context window")

        return optimized_context

    def get_token_stats(self, text: str) -> Dict:
        """
        Get detailed token statistics

        Arguments:
        ----------
        text { str } : Input text

        Returns:
        --------
        { dict } : Token statistics
        """
        tokens = self.count_tokens(text=text)
        chars = len(text)
        words = len(text.split())

        return {"tokens": tokens,
                "characters": chars,
                "words": words,
                "chars_per_token": chars / tokens if tokens > 0 else 0,
                "tokens_per_word": tokens / words if words > 0 else 0,
                "context_window": self.context_window,
                "model": self.model_name,
                }


# Global token manager instance
_token_manager = None


def get_token_manager(model_name: str = None) -> TokenManager:
    """
    Get global token manager instance

    Arguments:
    ----------
    model_name { str } : Model name for tokenizer selection

    Returns:
    --------
    { TokenManager } : TokenManager instance
    """
    global _token_manager

    if _token_manager is None or (model_name and _token_manager.model_name != model_name):
        _token_manager = TokenManager(model_name)

    return _token_manager


@handle_errors(error_type=TokenManagementError, log_error=True, reraise=False)
def count_tokens_safe(text: str, model_name: str = None) -> int:
    """
    Safe token counting with error handling

    Arguments:
    ----------
    text { str } : Text to count tokens for
    model_name { str } : Model name for tokenizer

    Returns:
    --------
    { int } : Token count (0 on error)
    """
    try:
        manager = get_token_manager(model_name=model_name)
        return manager.count_tokens(text=text)
    except Exception:
        return 0
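To make the intended call pattern concrete, a short usage sketch of the manager above (illustrative only; the sample strings and token budgets are invented):

# Count, check, and truncate against the model's context window
from generation.token_manager import get_token_manager, count_tokens_safe

manager = get_token_manager("gpt-3.5-turbo")
print(manager.count_tokens(text="How many tokens is this sentence?"))
print(manager.fits_in_context(prompt="Summarise the attached report.", max_completion_tokens=500))
print(manager.truncate_to_fit("a very long retrieved passage " * 200, max_tokens=128, strategy="middle"))
print(count_tokens_safe("works even if the tokenizer fails to load"))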
ingestion/__init__.py
ADDED
File without changes
ingestion/async_coordinator.py
ADDED
@@ -0,0 +1,414 @@
# DEPENDENCIES
import asyncio
from typing import Any
from typing import List
from typing import Dict
from pathlib import Path
import concurrent.futures
from typing import Optional
from datetime import datetime
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import ProcessingException
from ingestion.progress_tracker import get_progress_tracker
from document_parser.parser_factory import get_parser_factory


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class AsyncCoordinator:
    """
    Asynchronous document processing coordinator: Manages parallel processing of multiple documents with resource optimization
    """
    def __init__(self, max_workers: Optional[int] = None):
        """
        Initialize async coordinator

        Arguments:
        ----------
        max_workers { int } : Maximum parallel workers (default from settings)
        """
        self.logger = logger
        self.max_workers = max_workers or settings.MAX_WORKERS
        self.parser_factory = get_parser_factory()
        self.progress_tracker = get_progress_tracker()

        # Processing statistics
        self.total_processed = 0
        self.total_failed = 0
        self.avg_processing_time = 0.0

        self.logger.info(f"Initialized AsyncCoordinator: max_workers={self.max_workers}")

    @handle_errors(error_type=ProcessingException, log_error=True, reraise=True)
    async def process_documents_async(self, file_paths: List[Path], progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """
        Process multiple documents asynchronously with progress tracking

        Arguments:
        ----------
        file_paths { list } : List of file paths to process
        progress_callback { callable } : Callback for progress updates

        Returns:
        --------
        { dict } : Processing results
        """
        if not file_paths:
            return {"processed": 0,
                    "failed": 0,
                    "results": [],
                    }

        self.logger.info(f"Starting async processing of {len(file_paths)} documents")

        # Initialize progress tracking
        task_id = self.progress_tracker.start_task(total_items=len(file_paths),
                                                   description="Document processing",
                                                   )

        try:
            # Process files in parallel with semaphore for resource control
            semaphore = asyncio.Semaphore(self.max_workers)
            tasks = [self._process_single_file_async(file_path=file_path,
                                                     semaphore=semaphore,
                                                     task_id=task_id,
                                                     progress_callback=progress_callback,
                                                     ) for file_path in file_paths]

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Process results
            processed = list()
            failed = list()

            for file_path, result in zip(file_paths, results):
                if isinstance(result, Exception):
                    self.logger.error(f"Failed to process {file_path}: {repr(result)}")
                    failed.append({"file_path": file_path, "error": str(result)})
                    self.total_failed += 1
                else:
                    processed.append(result)
                    self.total_processed += 1

            # Update statistics
            self._update_statistics(processed_count=len(processed))

            final_result = {"processed": len(processed),
                            "failed": len(failed),
                            "success_rate": (len(processed) / len(file_paths)) * 100,
                            "results": processed,
                            "failures": failed,
                            "task_id": task_id,
                            }

            self.progress_tracker.complete_task(task_id)

            self.logger.info(f"Async processing completed: {len(processed)} successful, {len(failed)} failed")

            return final_result

        except Exception as e:
            self.progress_tracker.fail_task(task_id, str(e))
            raise ProcessingException(f"Async processing failed: {repr(e)}")

    async def _process_single_file_async(self, file_path: Path, semaphore: asyncio.Semaphore, task_id: str, progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """
        Process single file asynchronously

        Arguments:
        ----------
        file_path { Path } : File to process
        semaphore { Semaphore } : Resource semaphore
        task_id { str } : Progress task ID
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing result
        """
        async with semaphore:
            try:
                self.logger.debug(f"Processing file: {file_path}")

                # Update progress
                self.progress_tracker.update_task(task_id=task_id,
                                                  current_item=file_path.name,
                                                  current_status="parsing",
                                                  )

                # Refresh the processed-items count from the tracker (the original expression iterated
                # over a TaskProgress object, which is not iterable; read the counter directly instead)
                current_progress = self.progress_tracker.get_task_progress(task_id)
                processed_so_far = current_progress.processed_items if current_progress else 0
                self.progress_tracker.update_task(task_id=task_id, processed_items=processed_so_far)

                if progress_callback:
                    progress_callback(self.progress_tracker.get_task_progress(task_id))

                # Parse document
                start_time = datetime.now()
                text, metadata = await asyncio.to_thread(self.parser_factory.parse,
                                                         file_path=file_path,
                                                         extract_metadata=True,
                                                         clean_text=True,
                                                         )

                processing_time = (datetime.now() - start_time).total_seconds()

                # Update progress
                self.progress_tracker.update_task(task_id=task_id,
                                                  current_item=file_path.name,
                                                  current_status="completed",
                                                  )

                # Handle metadata (could be Pydantic model or dict)
                metadata_dict = metadata.dict() if hasattr(metadata, 'dict') else (metadata if isinstance(metadata, dict) else {})

                result = {"file_path": str(file_path),
                          "file_name": file_path.name,
                          "text_length": len(text),
                          "processing_time": processing_time,
                          "metadata": metadata_dict,
                          "success": True,
                          }

                # Increment processed items count
                current_progress = self.progress_tracker.get_task_progress(task_id)

                if current_progress:
                    self.progress_tracker.update_task(task_id=task_id,
                                                      processed_items=current_progress.processed_items + 1,
                                                      )

                if progress_callback:
                    progress_callback(self.progress_tracker.get_task_progress(task_id))

                return result

            except Exception as e:
                self.logger.error(f"Failed to process {file_path}: {repr(e)}")

                self.progress_tracker.update_task(task_id=task_id,
                                                  current_item=file_path.name,
                                                  current_status=f"failed: {str(e)}",
                                                  )

                raise ProcessingException(f"File processing failed: {repr(e)}")

    def process_documents_threaded(self, file_paths: List[Path], progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """
        Process documents using thread pool (alternative to async)

        Arguments:
        ----------
        file_paths { list } : List of file paths
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing results
        """
        self.logger.info(f"Starting threaded processing of {len(file_paths)} documents")

        task_id = self.progress_tracker.start_task(total_items=len(file_paths),
                                                   description="Threaded document processing",
                                                   )

        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                # Submit all tasks
                future_to_file = {executor.submit(self._process_single_file_sync, file_path, task_id, progress_callback): file_path for file_path in file_paths}

                results = list()
                failed = list()

                for future in concurrent.futures.as_completed(future_to_file):
                    file_path = future_to_file[future]

                    try:
                        result = future.result()
                        results.append(result)
                        self.total_processed += 1
                    except Exception as e:
                        self.logger.error(f"Failed to process {file_path}: {repr(e)}")
                        failed.append({"file_path": file_path, "error": str(e)})
                        self.total_failed += 1

            final_result = {"processed": len(results),
                            "failed": len(failed),
                            "success_rate": (len(results) / len(file_paths)) * 100,
                            "results": results,
                            "failures": failed,
                            "task_id": task_id,
                            }

            self.progress_tracker.complete_task(task_id)

            self.logger.info(f"Threaded processing completed: {len(results)} successful, {len(failed)} failed")

            return final_result

        except Exception as e:
            self.progress_tracker.fail_task(task_id, str(e))
            raise ProcessingException(f"Threaded processing failed: {repr(e)}")

    def _process_single_file_sync(self, file_path: Path, task_id: str, progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """
        Process single file synchronously (for thread pool)

        Arguments:
        ----------
        file_path { Path } : File to process
        task_id { str } : Progress task ID
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing result
        """
        self.logger.debug(f"Processing file (sync): {file_path}")

        # Update progress
        self.progress_tracker.update_task(task_id=task_id,
                                          current_item=file_path.name,
                                          current_status="parsing",
                                          )

        if progress_callback:
            progress_callback(self.progress_tracker.get_task_progress(task_id))

        # Parse document
        start_time = datetime.now()
        text, metadata = self.parser_factory.parse(file_path=file_path,
                                                   extract_metadata=True,
                                                   clean_text=True,
                                                   )

        processing_time = (datetime.now() - start_time).total_seconds()

        # Update progress
        self.progress_tracker.update_task(task_id=task_id,
                                          current_item=file_path.name,
                                          current_status="completed",
                                          )

        # Handle metadata (could be Pydantic model or dict)
        metadata_dict = metadata.dict() if hasattr(metadata, 'dict') else (metadata if isinstance(metadata, dict) else {})

        result = {"file_path": str(file_path),
                  "file_name": file_path.name,
                  "text_length": len(text),
                  "processing_time": processing_time,
                  "metadata": metadata_dict,
                  "success": True,
                  }

        # Increment processed items count
        current_progress = self.progress_tracker.get_task_progress(task_id)

        if current_progress:
            self.progress_tracker.update_task(task_id=task_id,
                                              processed_items=current_progress.processed_items + 1,
                                              )

        if progress_callback:
            progress_callback(self.progress_tracker.get_task_progress(task_id))

        return result

    def _update_statistics(self, processed_count: int):
        """
        Update processing statistics

        Arguments:
        ----------
        processed_count { int } : Number of documents processed in current batch
        """
        # Update average processing time (simplified)
        if (processed_count > 0):
            self.avg_processing_time = (self.avg_processing_time + (processed_count * 1.0)) / 2

    def get_coordinator_stats(self) -> Dict[str, Any]:
        """
        Get coordinator statistics

        Returns:
        --------
        { dict } : Statistics dictionary
        """
        return {"total_processed": self.total_processed,
                "total_failed": self.total_failed,
                "success_rate": (self.total_processed / (self.total_processed + self.total_failed)) * 100 if (self.total_processed + self.total_failed) > 0 else 0,
                "avg_processing_time": self.avg_processing_time,
                "max_workers": self.max_workers,
                "active_tasks": self.progress_tracker.get_active_task_count(),
                }

    def cleanup(self):
        """
        Cleanup resources
        """
        self.progress_tracker.cleanup_completed_tasks()

        self.logger.debug("AsyncCoordinator cleanup completed")


# Global async coordinator instance
_async_coordinator = None


def get_async_coordinator(max_workers: Optional[int] = None) -> AsyncCoordinator:
    """
    Get global async coordinator instance

    Arguments:
    ----------
    max_workers { int } : Maximum workers

    Returns:
    --------
    { AsyncCoordinator } : AsyncCoordinator instance
    """
    global _async_coordinator

    if _async_coordinator is None:
        _async_coordinator = AsyncCoordinator(max_workers)

    return _async_coordinator


async def process_documents_async(file_paths: List[Path], **kwargs) -> Dict[str, Any]:
    """
    Convenience function for async document processing

    Arguments:
    ----------
    file_paths { list } : List of file paths
    **kwargs : Additional arguments

    Returns:
    --------
    { dict } : Processing results
    """
    coordinator = get_async_coordinator()

    return await coordinator.process_documents_async(file_paths, **kwargs)
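A brief driver sketch for the coordinator above (illustrative; the file paths are placeholders and the worker count is arbitrary):

# Run a small batch through the async path and print the summary counts
import asyncio
from pathlib import Path
from ingestion.async_coordinator import get_async_coordinator

async def main():
    coordinator = get_async_coordinator(max_workers=2)
    result = await coordinator.process_documents_async([Path("docs/a.pdf"), Path("docs/b.docx")])
    print(result["processed"], result["failed"], f'{result["success_rate"]:.0f}%')

asyncio.run(main())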
ingestion/progress_tracker.py
ADDED
@@ -0,0 +1,518 @@
| 1 |
+
# DEPENDENCIES
|
| 2 |
+
import time
|
| 3 |
+
import uuid
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Any
|
| 6 |
+
from typing import Dict
|
| 7 |
+
from typing import List
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from config.settings import get_settings
|
| 12 |
+
from config.logging_config import get_logger
|
| 13 |
+
from utils.error_handler import handle_errors
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Setup Settings and Logging
|
| 17 |
+
settings = get_settings()
|
| 18 |
+
logger = get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TaskStatus(str, Enum):
|
| 22 |
+
"""
|
| 23 |
+
Task status enumeration
|
| 24 |
+
"""
|
| 25 |
+
PENDING = "pending"
|
| 26 |
+
RUNNING = "running"
|
| 27 |
+
COMPLETED = "completed"
|
| 28 |
+
FAILED = "failed"
|
| 29 |
+
CANCELLED = "cancelled"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class TaskProgress:
|
| 34 |
+
"""
|
| 35 |
+
Progress tracking data structure
|
| 36 |
+
"""
|
| 37 |
+
task_id : str
|
| 38 |
+
description : str
|
| 39 |
+
status : TaskStatus
|
| 40 |
+
total_items : int
|
| 41 |
+
processed_items : int
|
| 42 |
+
current_item : Optional[str]
|
| 43 |
+
current_status : str
|
| 44 |
+
start_time : datetime
|
| 45 |
+
end_time : Optional[datetime]
|
| 46 |
+
progress_percent : float
|
| 47 |
+
estimated_seconds_remaining : Optional[float]
|
| 48 |
+
metadata : Dict[str, Any]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ProgressTracker:
|
| 52 |
+
"""
|
| 53 |
+
Comprehensive progress tracking for long-running operations: Provides real-time progress monitoring and status updates
|
| 54 |
+
"""
|
| 55 |
+
def __init__(self, max_completed_tasks: int = 100):
|
| 56 |
+
"""
|
| 57 |
+
Initialize progress tracker
|
| 58 |
+
|
| 59 |
+
Arguments:
|
| 60 |
+
----------
|
| 61 |
+
max_completed_tasks { int } : Maximum number of completed tasks to keep in history
|
| 62 |
+
"""
|
| 63 |
+
self.logger = logger
|
| 64 |
+
self.max_completed_tasks = max_completed_tasks
|
| 65 |
+
|
| 66 |
+
# Task storage
|
| 67 |
+
self.active_tasks : Dict[str, TaskProgress] = dict()
|
| 68 |
+
self.completed_tasks : List[TaskProgress] = list()
|
| 69 |
+
self.failed_tasks : List[TaskProgress] = list()
|
| 70 |
+
|
| 71 |
+
self.logger.info("Initialized ProgressTracker")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def start_task(self, total_items: int, description: str, metadata: Optional[Dict[str, Any]] = None) -> str:
|
| 75 |
+
"""
|
| 76 |
+
Start tracking a new task
|
| 77 |
+
|
| 78 |
+
Arguments:
|
| 79 |
+
----------
|
| 80 |
+
total_items { int } : Total number of items to process
|
| 81 |
+
|
| 82 |
+
description { str } : Task description
|
| 83 |
+
|
| 84 |
+
metadata { dict } : Additional task metadata
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
--------
|
| 88 |
+
{ str } : Generated task ID
|
| 89 |
+
"""
|
| 90 |
+
task_id = self._generate_task_id()
|
| 91 |
+
|
| 92 |
+
task_progress = TaskProgress(task_id = task_id,
|
| 93 |
+
description = description,
|
| 94 |
+
status = TaskStatus.RUNNING,
|
| 95 |
+
total_items = total_items,
|
| 96 |
+
processed_items = 0,
|
| 97 |
+
current_item = None,
|
| 98 |
+
current_status = "Starting...",
|
| 99 |
+
start_time = datetime.now(),
|
| 100 |
+
end_time = None,
|
| 101 |
+
progress_percent = 0.0,
|
| 102 |
+
estimated_seconds_remaining = None,
|
| 103 |
+
metadata = metadata or {},
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.active_tasks[task_id] = task_progress
|
| 107 |
+
|
| 108 |
+
self.logger.info(f"Started task {task_id}: {description} ({total_items} items)")
|
| 109 |
+
|
| 110 |
+
return task_id
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def update_task(self, task_id: str, processed_items: Optional[int] = None, current_item: Optional[str] = None,
|
| 114 |
+
current_status: Optional[str] = None, metadata_update: Optional[Dict[str, Any]] = None) -> bool:
|
| 115 |
+
"""
|
| 116 |
+
Update task progress
|
| 117 |
+
|
| 118 |
+
Arguments:
|
| 119 |
+
----------
|
| 120 |
+
task_id { str } : Task ID to update
|
| 121 |
+
|
| 122 |
+
processed_items { int } : Number of items processed
|
| 123 |
+
|
| 124 |
+
current_item { str } : Current item being processed
|
| 125 |
+
|
| 126 |
+
current_status { str } : Current status message
|
| 127 |
+
|
| 128 |
+
metadata_update { dict } : Metadata updates
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
--------
|
| 132 |
+
{ bool } : True if update successful
|
| 133 |
+
"""
|
| 134 |
+
if task_id not in self.active_tasks:
|
| 135 |
+
self.logger.warning(f"Task {task_id} not found in active tasks")
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
task = self.active_tasks[task_id]
|
| 139 |
+
|
| 140 |
+
# Update fields
|
| 141 |
+
if processed_items is not None:
|
| 142 |
+
task.processed_items = processed_items
|
| 143 |
+
|
| 144 |
+
if current_item is not None:
|
| 145 |
+
task.current_item = current_item
|
| 146 |
+
|
| 147 |
+
if current_status is not None:
|
| 148 |
+
task.current_status = current_status
|
| 149 |
+
|
| 150 |
+
if metadata_update:
|
| 151 |
+
task.metadata.update(metadata_update)
|
| 152 |
+
|
| 153 |
+
# Calculate progress
|
| 154 |
+
if (task.total_items > 0):
|
| 155 |
+
task.progress_percent = (task.processed_items / task.total_items) * 100.0
|
| 156 |
+
|
| 157 |
+
# Estimate remaining time
|
| 158 |
+
if (task.processed_items > 0):
|
| 159 |
+
elapsed_seconds = (datetime.now() - task.start_time).total_seconds()
|
| 160 |
+
items_per_second = task.processed_items / elapsed_seconds
|
| 161 |
+
|
| 162 |
+
if (items_per_second > 0):
|
| 163 |
+
remaining_items = task.total_items - task.processed_items
|
| 164 |
+
task.estimated_seconds_remaining = remaining_items / items_per_second
|
| 165 |
+
|
| 166 |
+
# Ensure progress doesn't exceed 100%
|
| 167 |
+
if (task.progress_percent > 100.0):
|
| 168 |
+
task.progress_percent = 100.0
|
| 169 |
+
|
| 170 |
+
self.logger.debug(f"Updated task {task_id}: {task.progress_percent:.1f}% complete")
|
| 171 |
+
|
| 172 |
+
return True
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def complete_task(self, task_id: str, final_metadata: Optional[Dict[str, Any]] = None) -> bool:
|
| 176 |
+
"""
|
| 177 |
+
Mark task as completed
|
| 178 |
+
|
| 179 |
+
Arguments:
|
| 180 |
+
----------
|
| 181 |
+
task_id { str } : Task ID to complete
|
| 182 |
+
|
| 183 |
+
final_metadata { dict } : Final metadata updates
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
--------
|
| 187 |
+
{ bool } : True if completion successful
|
| 188 |
+
"""
|
| 189 |
+
if task_id not in self.active_tasks:
|
| 190 |
+
self.logger.warning(f"Task {task_id} not found for completion")
|
| 191 |
+
return False
|
| 192 |
+
|
| 193 |
+
task = self.active_tasks[task_id]
|
| 194 |
+
|
| 195 |
+
# Update task
|
| 196 |
+
task.status = TaskStatus.COMPLETED
|
| 197 |
+
task.end_time = datetime.now()
|
| 198 |
+
task.processed_items = task.total_items # Ensure 100% completion
|
| 199 |
+
task.progress_percent = 100.0
|
| 200 |
+
task.estimated_seconds_remaining = 0.0
|
| 201 |
+
task.current_status = "Completed"
|
| 202 |
+
|
| 203 |
+
if final_metadata:
|
| 204 |
+
task.metadata.update(final_metadata)
|
| 205 |
+
|
| 206 |
+
# Move to completed tasks
|
| 207 |
+
self.completed_tasks.append(task)
|
| 208 |
+
del self.active_tasks[task_id]
|
| 209 |
+
|
| 210 |
+
# Maintain history size
|
| 211 |
+
if (len(self.completed_tasks) > self.max_completed_tasks):
|
| 212 |
+
self.completed_tasks = self.completed_tasks[-self.max_completed_tasks:]
|
| 213 |
+
|
| 214 |
+
total_time = (task.end_time - task.start_time).total_seconds()
|
| 215 |
+
|
| 216 |
+
self.logger.info(f"Completed task {task_id}: {task.description} in {total_time:.2f}s")
|
| 217 |
+
|
| 218 |
+
return True
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def fail_task(self, task_id: str, error_message: str, error_details: Optional[Dict[str, Any]] = None) -> bool:
|
| 222 |
+
"""
|
| 223 |
+
Mark task as failed
|
| 224 |
+
|
| 225 |
+
Arguments:
|
| 226 |
+
----------
|
| 227 |
+
task_id { str } : Task ID to mark as failed
|
| 228 |
+
|
| 229 |
+
error_message { str } : Error message
|
| 230 |
+
|
| 231 |
+
error_details { dict } : Additional error details
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
--------
|
| 235 |
+
{ bool } : True if failure marking successful
|
| 236 |
+
"""
|
| 237 |
+
if (task_id not in self.active_tasks):
|
| 238 |
+
self.logger.warning(f"Task {task_id} not found for failure marking")
|
| 239 |
+
return False
|
| 240 |
+
|
| 241 |
+
task = self.active_tasks[task_id]
|
| 242 |
+
|
| 243 |
+
# Update task
|
| 244 |
+
task.status = TaskStatus.FAILED
|
| 245 |
+
task.end_time = datetime.now()
|
| 246 |
+
task.current_status = f"Failed: {error_message}"
|
| 247 |
+
|
| 248 |
+
if error_details:
|
| 249 |
+
task.metadata["error_details"] = error_details
|
| 250 |
+
|
| 251 |
+
task.metadata["error_message"] = error_message
|
| 252 |
+
|
| 253 |
+
# Move to failed tasks
|
| 254 |
+
self.failed_tasks.append(task)
|
| 255 |
+
del self.active_tasks[task_id]
|
| 256 |
+
|
| 257 |
+
total_time = (task.end_time - task.start_time).total_seconds()
|
| 258 |
+
|
| 259 |
+
self.logger.error(f"Task {task_id} failed after {total_time:.2f}s: {error_message}")
|
| 260 |
+
|
| 261 |
+
return True
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def cancel_task(self, task_id: str, reason: str = "User cancelled") -> bool:
|
| 265 |
+
"""
|
| 266 |
+
Cancel a running task
|
| 267 |
+
|
| 268 |
+
Arguments:
|
| 269 |
+
----------
|
| 270 |
+
task_id { str } : Task ID to cancel
|
| 271 |
+
|
| 272 |
+
reason { str } : Cancellation reason
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
--------
|
| 276 |
+
{ bool } : True if cancellation successful
|
| 277 |
+
"""
|
| 278 |
+
if task_id not in self.active_tasks:
|
| 279 |
+
self.logger.warning(f"Task {task_id} not found for cancellation")
|
| 280 |
+
return False
|
| 281 |
+
|
| 282 |
+
task = self.active_tasks[task_id]
|
| 283 |
+
|
| 284 |
+
# Update task
|
| 285 |
+
task.status = TaskStatus.CANCELLED
|
| 286 |
+
task.end_time = datetime.now()
|
| 287 |
+
task.current_status = f"Cancelled: {reason}"
|
| 288 |
+
task.metadata["cancellation_reason"] = reason
|
| 289 |
+
|
| 290 |
+
# Move to completed tasks (as cancelled)
|
| 291 |
+
self.completed_tasks.append(task)
|
| 292 |
+
del self.active_tasks[task_id]
|
| 293 |
+
|
| 294 |
+
total_time = (task.end_time - task.start_time).total_seconds()
|
| 295 |
+
|
| 296 |
+
self.logger.info(f"Cancelled task {task_id} after {total_time:.2f}s: {reason}")
|
| 297 |
+
|
| 298 |
+
return True
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def get_task_progress(self, task_id: str) -> Optional[TaskProgress]:
|
| 302 |
+
"""
|
| 303 |
+
Get current progress for a task
|
| 304 |
+
|
| 305 |
+
Arguments:
|
| 306 |
+
----------
|
| 307 |
+
task_id { str } : Task ID
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
--------
|
| 311 |
+
{ TaskProgress } : Task progress or None if not found
|
| 312 |
+
"""
|
| 313 |
+
# Check active tasks first
|
| 314 |
+
if task_id in self.active_tasks:
|
| 315 |
+
return self.active_tasks[task_id]
|
| 316 |
+
|
| 317 |
+
# Check completed tasks
|
| 318 |
+
for task in self.completed_tasks:
|
| 319 |
+
if (task.task_id == task_id):
|
| 320 |
+
return task
|
| 321 |
+
|
| 322 |
+
# Check failed tasks
|
| 323 |
+
for task in self.failed_tasks:
|
| 324 |
+
if (task.task_id == task_id):
|
| 325 |
+
return task
|
| 326 |
+
|
| 327 |
+
return None
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def get_all_active_tasks(self) -> List[TaskProgress]:
|
| 331 |
+
"""
|
| 332 |
+
Get all active tasks
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
--------
|
| 336 |
+
{ list } : List of active task progresses
|
| 337 |
+
"""
|
| 338 |
+
return list(self.active_tasks.values())
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def get_active_task_count(self) -> int:
|
| 342 |
+
"""
|
| 343 |
+
Get number of active tasks
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
--------
|
| 347 |
+
{ int } : Number of active tasks
|
| 348 |
+
"""
|
| 349 |
+
return len(self.active_tasks)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def get_recent_completed_tasks(self, limit: int = 10) -> List[TaskProgress]:
|
| 353 |
+
"""
|
| 354 |
+
Get recently completed tasks
|
| 355 |
+
|
| 356 |
+
Arguments:
|
| 357 |
+
----------
|
| 358 |
+
limit { int } : Maximum number of tasks to return
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
--------
|
| 362 |
+
{ list } : List of completed tasks (newest first)
|
| 363 |
+
"""
|
| 364 |
+
# Return newest first
|
| 365 |
+
return self.completed_tasks[-limit:][::-1]
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def get_task_statistics(self, task_id: str) -> Optional[Dict[str, Any]]:
|
| 369 |
+
"""
|
| 370 |
+
Get detailed statistics for a task
|
| 371 |
+
|
| 372 |
+
Arguments:
|
| 373 |
+
----------
|
| 374 |
+
task_id { str } : Task ID
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
--------
|
| 378 |
+
{ dict } : Task statistics or None
|
| 379 |
+
"""
|
| 380 |
+
task = self.get_task_progress(task_id)
|
| 381 |
+
|
| 382 |
+
if not task:
|
| 383 |
+
return None
|
| 384 |
+
|
| 385 |
+
stats = {"task_id" : task.task_id,
|
| 386 |
+
"description" : task.description,
|
| 387 |
+
"status" : task.status.value,
|
| 388 |
+
"progress_percent" : task.progress_percent,
|
| 389 |
+
"processed_items" : task.processed_items,
|
| 390 |
+
"total_items" : task.total_items,
|
| 391 |
+
"current_item" : task.current_item,
|
| 392 |
+
"current_status" : task.current_status,
|
| 393 |
+
"start_time" : task.start_time.isoformat(),
|
| 394 |
+
"metadata" : task.metadata,
                 }

        if task.end_time:
            total_seconds = (task.end_time - task.start_time).total_seconds()
            stats["total_time_seconds"] = total_seconds
            stats["end_time"] = task.end_time.isoformat()

        if task.estimated_seconds_remaining:
            stats["estimated_seconds_remaining"] = task.estimated_seconds_remaining

        return stats

    def get_global_statistics(self) -> Dict[str, Any]:
        """
        Get global progress statistics

        Returns:
        --------
        { dict } : Global statistics
        """
        total_completed = len(self.completed_tasks)
        total_failed = len(self.failed_tasks)
        total_tasks = total_completed + total_failed + len(self.active_tasks)

        # Calculate average completion time
        avg_completion_time = 0.0

        if (total_completed > 0):
            total_time = sum((task.end_time - task.start_time).total_seconds() for task in self.completed_tasks if task.end_time)
            avg_completion_time = total_time / total_completed

        return {"active_tasks"                : len(self.active_tasks),
                "completed_tasks"             : total_completed,
                "failed_tasks"                : total_failed,
                "total_tasks"                 : total_tasks,
                "success_rate"                : (total_completed / total_tasks * 100) if total_tasks > 0 else 0,
                "avg_completion_time_seconds" : avg_completion_time,
                "max_completed_tasks"         : self.max_completed_tasks,
                }

    def cleanup_completed_tasks(self, older_than_hours: int = 24):
        """
        Clean up old completed tasks

        Arguments:
        ----------
        older_than_hours { int } : Remove tasks older than this many hours
        """
        cutoff_time = datetime.now().timestamp() - (older_than_hours * 3600)
        initial_count = len(self.completed_tasks)

        self.completed_tasks = [task for task in self.completed_tasks if task.end_time and task.end_time.timestamp() > cutoff_time]

        removed_count = initial_count - len(self.completed_tasks)

        if (removed_count > 0):
            self.logger.info(f"Cleaned up {removed_count} completed tasks older than {older_than_hours} hours")

    @staticmethod
    def _generate_task_id() -> str:
        """
        Generate unique task ID

        Returns:
        --------
        { str } : Generated task ID
        """
        return f"task_{uuid.uuid4().hex[:8]}_{int(time.time())}"

    def __del__(self):
        """
        Cleanup on destruction
        """
        try:
            self.cleanup_completed_tasks()

        except Exception:
            # Ignore cleanup errors during destruction
            pass


# Global progress tracker instance
_progress_tracker = None


def get_progress_tracker() -> ProgressTracker:
    """
    Get global progress tracker instance

    Returns:
    --------
    { ProgressTracker } : ProgressTracker instance
    """
    global _progress_tracker

    if _progress_tracker is None:
        _progress_tracker = ProgressTracker()

    return _progress_tracker


def start_progress_task(total_items: int, description: str, **kwargs) -> str:
    """
    Convenience function to start a progress task

    Arguments:
    ----------
    total_items { int } : Total items
    description { str } : Task description
    **kwargs            : Additional arguments

    Returns:
    --------
    { str } : Task ID
    """
    tracker = get_progress_tracker()

    return tracker.start_task(total_items, description, **kwargs)
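The tracker above is consumed through the module-level helpers `get_progress_tracker()` and `start_progress_task()`. Below is a minimal, hedged usage sketch; the item names and status strings are illustrative, and it assumes `start_task`, `update_task`, `complete_task` and `get_task_progress` accept the keyword arguments they are called with elsewhere in this commit (e.g. in ingestion/router.py).

# Hedged usage sketch for the progress-tracking helpers (item names are illustrative).
from ingestion.progress_tracker import get_progress_tracker, start_progress_task

items = ["report.pdf", "notes.txt", "slides.docx"]          # hypothetical inputs

task_id = start_progress_task(total_items = len(items), description = "Demo batch")
tracker = get_progress_tracker()

for i, name in enumerate(items, start = 1):
    # ... parse / chunk / embed the item here ...
    tracker.update_task(task_id = task_id,
                        processed_items = i,
                        current_item = name,
                        current_status = "processing")
    print(tracker.get_task_progress(task_id))

tracker.complete_task(task_id, {"processed_items": len(items)})
print(tracker.get_global_statistics())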
ingestion/router.py
ADDED
@@ -0,0 +1,516 @@
# DEPENDENCIES
import os
import tempfile
from enum import Enum
from typing import Any
from typing import List
from typing import Dict
from pathlib import Path
from typing import Optional
from config.settings import get_settings
from utils.file_handler import FileHandler
from config.models import IngestionInputType
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.validators import validate_upload_file
from utils.error_handler import ProcessingException
from ingestion.progress_tracker import get_progress_tracker
from document_parser.zip_handler import get_archive_handler
from ingestion.async_coordinator import get_async_coordinator


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class IngestionRouter:
    """
    Intelligent ingestion router: Determines optimal processing strategy based on input type and characteristics
    """
    def __init__(self):
        """
        Initialize ingestion router
        """
        self.logger = logger
        self.async_coordinator = get_async_coordinator()
        self.progress_tracker = get_progress_tracker()
        self.file_handler = FileHandler()

        # Processing strategies
        self.processing_strategies = {IngestionInputType.FILE    : self._process_single_file,
                                      IngestionInputType.ARCHIVE : self._process_archive,
                                      IngestionInputType.URL     : self._process_url,
                                      IngestionInputType.TEXT    : self._process_text,
                                      }

        self.logger.info("Initialized IngestionRouter")

    @handle_errors(error_type = ProcessingException, log_error = True, reraise = True)
    def route_and_process(self, input_data: Any, input_type: IngestionInputType, metadata: Optional[Dict[str, Any]] = None, progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """
        Route input to appropriate processing strategy

        Arguments:
        ----------
        input_data { Any }                : Input data (file path, URL, text, etc.)
        input_type { IngestionInputType } : Type of input
        metadata { dict }                 : Additional metadata
        progress_callback { callable }    : Progress callback

        Returns:
        --------
        { dict } : Processing results
        """
        self.logger.info(f"Routing {input_type.value} input for processing")

        # Validate input
        self._validate_input(input_data, input_type)

        # Get processing strategy
        processor = self.processing_strategies.get(input_type)

        if not processor:
            raise ProcessingException(f"No processor available for input type: {input_type}")

        # Process with progress tracking
        task_id = self.progress_tracker.start_task(total_items = 1,
                                                   description = f"Processing {input_type.value}",
                                                   metadata = metadata or {},
                                                   )

        try:
            result = processor(input_data, metadata, task_id, progress_callback)

            self.progress_tracker.complete_task(task_id, {"processed_items": 1})

            self.logger.info(f"Successfully processed {input_type.value} input")

            return result

        except Exception as e:
            self.progress_tracker.fail_task(task_id, str(e))
            raise ProcessingException(f"Failed to process {input_type.value} input: {repr(e)}")

    def _process_single_file(self, file_path: Path, metadata: Optional[Dict[str, Any]], task_id: str, progress_callback: Optional[callable]) -> Dict[str, Any]:
        """
        Process single file

        Arguments:
        ----------
        file_path { Path }             : File path to process
        metadata { dict }              : File metadata
        task_id { str }                : Progress task ID
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing results
        """
        self.logger.info(f"Processing single file: {file_path}")

        # Update progress
        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = file_path.name,
                                          current_status = "validating_file",
                                          )

        if progress_callback:
            progress_callback(self.progress_tracker.get_task_progress(task_id))

        # Validate file
        validate_upload_file(file_path)

        # Process file
        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = file_path.name,
                                          current_status = "processing_file",
                                          )

        result = self.async_coordinator.process_documents_threaded([file_path], progress_callback)

        # Extract single result (since we processed one file)
        if result["results"]:
            file_result = result["results"][0]

        else:
            raise ProcessingException(f"File processing failed: {result.get('failures', [])}")

        return {"success"      : True,
                "file_path"    : str(file_path),
                "file_name"    : file_path.name,
                "text_length"  : file_result.get("text_length", 0),
                "content_type" : self._detect_content_type(file_path),
                "metadata"     : file_result.get("metadata", {}),
                }

    def _process_archive(self, archive_path: Path, metadata: Optional[Dict[str, Any]], task_id: str, progress_callback: Optional[callable]) -> Dict[str, Any]:
        """
        Process archive file (ZIP, RAR, etc.)

        Arguments:
        ----------
        archive_path { Path }          : Archive file path
        metadata { dict }              : Archive metadata
        task_id { str }                : Progress task ID
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing results
        """
        self.logger.info(f"Processing archive: {archive_path}")

        # Update progress
        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = archive_path.name,
                                          current_status = "extracting_archive",
                                          )

        if progress_callback:
            progress_callback(self.progress_tracker.get_task_progress(task_id))

        # Extract archive
        archive_handler = get_archive_handler()
        extracted_files = archive_handler.extract_archive(archive_path)

        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = archive_path.name,
                                          current_status = f"processing_{len(extracted_files)}_files",
                                          processed_items = 0,
                                          total_items = len(extracted_files),
                                          )

        # Process extracted files
        result = self.async_coordinator.process_documents_threaded(extracted_files, progress_callback)

        return {"success"         : True,
                "archive_path"    : str(archive_path),
                "archive_name"    : archive_path.name,
                "extracted_files" : len(extracted_files),
                "processed_files" : result["processed"],
                "failed_files"    : result["failed"],
                "success_rate"    : result["success_rate"],
                "results"         : result["results"],
                "failures"        : result["failures"],
                }

    def _process_url(self, url: str, metadata: Optional[Dict[str, Any]], task_id: str, progress_callback: Optional[callable]) -> Dict[str, Any]:
        """
        Process URL (web page scraping)

        WARNING: URL processing not implemented yet. Requires Playwright/BeautifulSoup integration
        TODO: Implement web scraping functionality

        Arguments:
        ----------
        url { str }                    : URL to process
        metadata { dict }              : URL metadata
        task_id { str }                : Progress task ID
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing results
        """
        self.logger.info(f"Processing URL: {url}")

        # Update progress
        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = url,
                                          current_status = "scraping_url",
                                          )

        if progress_callback:
            progress_callback(self.progress_tracker.get_task_progress(task_id))

        # Note: Web scraping would be implemented here: For now, return placeholder
        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = url,
                                          current_status = "processing_content",
                                          )

        # Placeholder implementation: In production, this would use playwright/beautifulsoup scrapers
        return {"success"      : True,
                "url"          : url,
                "content_type" : "web_page",
                "text_length"  : 0,  # Would be actual content length
                "message"      : "URL processing placeholder - implement web scraping",
                "metadata"     : metadata or {},
                }

    def _process_text(self, text: str, metadata: Optional[Dict[str, Any]], task_id: str, progress_callback: Optional[callable]) -> Dict[str, Any]:
        """
        Process raw text input

        Arguments:
        ----------
        text { str }                   : Text content to process
        metadata { dict }              : Text metadata
        task_id { str }                : Progress task ID
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Processing results
        """
        self.logger.info(f"Processing text input ({len(text)} characters)")

        # Update progress
        self.progress_tracker.update_task(task_id = task_id,
                                          current_item = "text_input",
                                          current_status = "processing_text",
                                          )

        if progress_callback:
            progress_callback(self.progress_tracker.get_task_progress(task_id))

        # For text input, create a temporary file and process it
        with tempfile.NamedTemporaryFile(mode = 'w', suffix = '.txt', delete = False, encoding = 'utf-8') as temp_file:
            temp_file.write(text)
            temp_path = Path(temp_file.name)

        try:
            # Process as a file
            file_result = self._process_single_file(file_path = temp_path,
                                                    metadata = metadata,
                                                    task_id = task_id,
                                                    progress_callback = progress_callback,
                                                    )

            # Add text-specific metadata
            file_result["input_type"] = "direct_text"
            file_result["original_text_length"] = len(text)

            return file_result

        finally:
            # Cleanup temporary file
            try:
                os.unlink(temp_path)

            except Exception as e:
                self.logger.warning(f"Failed to delete temporary file: {repr(e)}")

    def _validate_input(self, input_data: Any, input_type: IngestionInputType):
        """
        Validate input based on type

        Arguments:
        ----------
        input_data { Any }                : Input data to validate
        input_type { IngestionInputType } : Type of input

        Raises:
        -------
        ProcessingException : If validation fails
        """
        if (input_type == IngestionInputType.FILE):
            if not isinstance(input_data, (str, Path)):
                raise ProcessingException("File input must be a path string or Path object")

            file_path = Path(input_data)
            if not file_path.exists():
                raise ProcessingException(f"File not found: {file_path}")

        elif (input_type == IngestionInputType.URL):
            if not isinstance(input_data, str):
                raise ProcessingException("URL input must be a string")

            if not input_data.startswith(('http://', 'https://')):
                raise ProcessingException("URL must start with http:// or https://")

        elif (input_type == IngestionInputType.TEXT):
            if not isinstance(input_data, str):
                raise ProcessingException("Text input must be a string")

            if len(input_data.strip()) == 0:
                raise ProcessingException("Text input cannot be empty")

        elif (input_type == IngestionInputType.ARCHIVE):
            if not isinstance(input_data, (str, Path)):
                raise ProcessingException("Archive input must be a path string or Path object")

            file_path = Path(input_data)
            if not file_path.exists():
                raise ProcessingException(f"Archive file not found: {file_path}")

    def _detect_content_type(self, file_path: Path) -> str:
        """
        Detect content type from file extension

        Arguments:
        ----------
        file_path { Path } : File path

        Returns:
        --------
        { str } : Content type
        """
        extension = file_path.suffix.lower()

        content_types = {'.pdf'  : 'application/pdf',
                         '.docx' : 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
                         '.doc'  : 'application/msword',
                         '.txt'  : 'text/plain',
                         '.zip'  : 'application/zip',
                         '.rar'  : 'application/vnd.rar',
                         '.7z'   : 'application/x-7z-compressed',
                         }

        return content_types.get(extension, 'application/octet-stream')

    def batch_process(self, inputs: List[Dict[str, Any]], progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """
        Process multiple inputs in batch

        Arguments:
        ----------
        inputs { list }                : List of input dictionaries
        progress_callback { callable } : Progress callback

        Returns:
        --------
        { dict } : Batch processing results
        """
        self.logger.info(f"Starting batch processing of {len(inputs)} inputs")

        task_id = self.progress_tracker.start_task(total_items = len(inputs),
                                                   description = "Batch processing",
                                                   )

        results = list()
        failed = list()

        for i, input_config in enumerate(inputs):
            try:
                self.progress_tracker.update_task(task_id = task_id,
                                                  processed_items = i,
                                                  current_item = f"Input {i + 1}",
                                                  current_status = "processing",
                                                  )

                input_type = IngestionInputType(input_config.get("type"))
                input_data = input_config.get("data")
                metadata = input_config.get("metadata", {})

                result = self.route_and_process(input_data = input_data,
                                                input_type = input_type,
                                                metadata = metadata,
                                                progress_callback = None,
                                                )

                results.append(result)

                if progress_callback:
                    progress_callback(self.progress_tracker.get_task_progress(task_id))

            except Exception as e:
                self.logger.error(f"Failed to process input {i + 1}: {repr(e)}")
                failed.append({"input_index": i, "input_config": input_config, "error": str(e)})

        self.progress_tracker.complete_task(task_id, {"processed_items": len(results)})

        return {"processed"    : len(results),
                "failed"       : len(failed),
                "success_rate" : (len(results) / len(inputs)) * 100,
                "results"      : results,
                "failures"     : failed,
                }

    def cleanup(self):
        """
        Cleanup router resources
        """
        if hasattr(self, 'async_coordinator'):
            self.async_coordinator.cleanup()

        self.logger.debug("IngestionRouter cleanup completed")

    def get_processing_capabilities(self) -> Dict[str, Any]:
        """
        Get supported processing capabilities

        Returns:
        --------
        { dict } : Capabilities information
        """
        # URL type is not fully implemented yet
        supported_types = [t.value for t in IngestionInputType if (t != IngestionInputType.URL)]

        return {"supported_input_types" : supported_types,
                "max_file_size_mb"      : settings.MAX_FILE_SIZE_MB,
                "max_batch_files"       : settings.MAX_BATCH_FILES,
                "allowed_extensions"    : settings.ALLOWED_EXTENSIONS,
                "max_workers"           : settings.MAX_WORKERS,
                "async_supported"       : True,
                "batch_supported"       : True,
                }


# Global ingestion router instance
_ingestion_router = None


def get_ingestion_router() -> IngestionRouter:
    """
    Get global ingestion router instance

    Returns:
    --------
    { IngestionRouter } : IngestionRouter instance
    """
    global _ingestion_router

    if _ingestion_router is None:
        _ingestion_router = IngestionRouter()

    return _ingestion_router


def process_input(input_data: Any, input_type: IngestionInputType, **kwargs) -> Dict[str, Any]:
    """
    Convenience function for input processing

    Arguments:
    ----------
    input_data { Any }                : Input data
    input_type { IngestionInputType } : Input type
    **kwargs                          : Additional arguments

    Returns:
    --------
    { dict } : Processing results
    """
    router = get_ingestion_router()

    return router.route_and_process(input_data, input_type, **kwargs)
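For reference, a short sketch of how the router's public entry points might be exercised. The sample paths are illustrative, and the "type" strings in the batch payload are assumed to match the IngestionInputType enum values defined in config/models.py (not shown here).

# Hedged usage sketch for IngestionRouter (paths and "type" strings are assumptions).
from config.models import IngestionInputType
from ingestion.router import get_ingestion_router, process_input

# Single file: routed to _process_single_file after validation
file_result = process_input("./samples/report.pdf", IngestionInputType.FILE)
print(file_result["content_type"], file_result["text_length"])

# Raw text: written to a temporary .txt file and processed through the same path
text_result = process_input("Plain text to ingest.", IngestionInputType.TEXT)
print(text_result["input_type"], text_result["original_text_length"])

# Batch mode: list of {"type", "data", "metadata"} dictionaries
router = get_ingestion_router()
batch = router.batch_process([{"type": "file", "data": "./samples/report.pdf", "metadata": {"source": "demo"}},
                              {"type": "text", "data": "Another snippet.", "metadata": {}}])
print(batch["processed"], batch["failed"], batch["success_rate"])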
requirements.txt
ADDED
@@ -0,0 +1,135 @@
# Core Dependencies
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.12.5
pydantic-settings==2.12.0
python-multipart==0.0.6

# AI/ML & NLP
#ollama==0.1.7
sentence-transformers==2.2.2
transformers==4.37.2
torch==2.9.1
tokenizers==0.15.2
accelerate==0.24.1

# Vector Database & Search
faiss-cpu==1.7.4
rank-bm25==0.2.2

# Document Processing
PyMuPDF==1.23.8
PyPDF2==3.0.1
python-docx==1.1.0
python-pptx==0.6.23

# OCR & Image Processing
paddleocr==2.7.3
easyocr==1.7.2
paddlepaddle==3.2.2
Pillow==10.0.1
opencv-python==4.6.0.66

# Archive Handling
py7zr==0.20.6
rarfile==4.1
python-magic==0.4.27

# Text Processing & Utilities
chardet==5.2.0
numpy==1.24.3
nltk==3.8.1
tqdm==4.66.1
filetype==1.2.0
sentencepiece==0.2.1

# Caching & Performance
redis==5.0.1
psutil==5.9.6

# Ragas Evaluation Framework
ragas==0.1.2
datasets==2.14.6
evaluate==0.4.1

# Evaluation Metrics & Utilities
scikit-learn==1.3.2
scipy==1.11.4
pandas==2.0.3
seaborn==0.13.0
matplotlib==3.10.7

# Development
black==25.11.0
flake8==7.3.0
mypy==1.18.2

# Async & Concurrency
aiofiles==23.2.1

# System & Utilities
python-dateutil==2.8.2
typing-extensions==4.15.0
protobuf==4.25.8

# Additional dependencies from the working environment
aiohttp>=3.9.3
anyio==3.7.1
async-timeout==4.0.3
attrs==25.4.0
click==8.1.7
colorlog==6.10.1
cryptography==42.0.2
filelock==3.20.0
h11==0.16.0
huggingface-hub==0.20.0
httpx==0.25.2
idna==3.11
importlib-metadata==6.11.0
joblib==1.5.2
jsonpatch==1.33
jsonschema==4.25.0
langchain==0.1.7
langchain-community==0.0.20
langchain-core==0.1.23
langsmith==0.0.87
loguru==0.7.2
lxml==5.1.0
MarkupSafe==3.0.2
msgpack==1.0.7
networkx==3.4.2
openai>=1.54.0
orjson==3.11.4
packaging==23.2
pip==25.2
platformdirs==4.5.0
pluggy==1.6.0
pydantic_core==2.41.5
pygments==2.19.2
pypdf==4.3.1
pytest==9.0.1
python-dotenv==1.0.1
PyYAML==6.0.2
requests==2.32.5
rich==13.7.0
scikit-image==0.21.0
setuptools==80.9.0
six==1.17.0
sniffio==1.3.0
SQLAlchemy==2.0.27
starlette==0.27.0
sympy==1.14.0
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.5.2
tomli==2.2.1
torchvision==0.24.1
typer==0.9.4
urllib3==2.5.0
wasabi==1.1.3
Werkzeug==3.1.4
wheel==0.45.1
xxhash==3.6.0
yarl==1.22.0
zipp==3.23.0
retrieval/__init__.py
ADDED
File without changes
retrieval/citation_tracker.py
ADDED
@@ -0,0 +1,452 @@
# DEPENDENCIES
import re
from typing import List
from typing import Dict
from typing import Tuple
from typing import Optional
from collections import defaultdict
from config.models import DocumentChunk
from config.models import ChunkWithScore
from config.logging_config import get_logger
from utils.error_handler import CitationError
from utils.error_handler import handle_errors


# Setup Logging
logger = get_logger(__name__)


class CitationTracker:
    """
    Citation tracking and management: Tracks source citations in generated text and provides citation formatting and validation
    """
    def __init__(self):
        """
        Initialize citation tracker
        """
        self.logger = logger
        self.citation_pattern = re.compile(r'\[(\d+)\]')

    def extract_citations(self, text: str) -> List[int]:
        """
        Extract citation numbers from text

        Arguments:
        ----------
        text { str } : Text containing citations

        Returns:
        --------
        { list } : List of citation numbers found in text
        """
        if not text:
            return []

        try:
            matches = self.citation_pattern.findall(text)
            citation_numbers = [int(match) for match in matches]

            # Remove duplicates and sort
            unique_citations = sorted(set(citation_numbers))

            self.logger.debug(f"Extracted {len(unique_citations)} citations from text")

            return unique_citations

        except Exception as e:
            self.logger.error(f"Citation extraction failed: {repr(e)}")
            return []

    def validate_citations(self, text: str, sources: List[ChunkWithScore]) -> Tuple[bool, List[int]]:
        """
        Validate that all citations in text reference existing sources

        Arguments:
        ----------
        text { str }     : Text containing citations
        sources { list } : List of available sources

        Returns:
        --------
        { Tuple[bool, List[int]] } : (is_valid, invalid_citations)
        """
        citation_numbers = self.extract_citations(text = text)

        if not citation_numbers:
            return True, []

        # Check if all citation numbers are within valid range
        max_valid = len(sources)
        invalid_citations = [num for num in citation_numbers if (num < 1) or (num > max_valid)]

        if invalid_citations:
            self.logger.warning(f"Invalid citations found: {invalid_citations}. Valid range: 1-{max_valid}")
            return False, invalid_citations

        return True, []

    def format_citations(self, sources: List[ChunkWithScore], style: str = "numeric") -> str:
        """
        Format citations as reference list

        Arguments:
        ----------
        sources { list } : List of sources to format
        style { str }    : Citation style ('numeric', 'verbose')

        Returns:
        --------
        { str } : Formatted citation text
        """
        if not sources:
            return ""

        try:
            citations = list()

            for i, source in enumerate(sources, 1):
                if (style == "verbose"):
                    citation = self._format_verbose_citation(source = source,
                                                             number = i,
                                                             )

                else:
                    citation = self._format_numeric_citation(source = source,
                                                             number = i,
                                                             )

                citations.append(citation)

            citation_text = "\n".join(citations)

            self.logger.debug(f"Formatted {len(citations)} citations in {style} style")

            return citation_text

        except Exception as e:
            self.logger.error(f"Citation formatting failed: {repr(e)}")
            return ""

    def _format_numeric_citation(self, source: ChunkWithScore, number: int) -> str:
        """
        Format citation in numeric style with sanitization

        Arguments:
        ----------
        source { ChunkWithScore } : Source to format
        number { int }            : Citation number

        Returns:
        --------
        { str } : Formatted citation
        """
        chunk = source.chunk

        parts = [f"[{number}]"]

        # Add source information with proper sanitization
        if (hasattr(chunk, 'metadata') and chunk.metadata):
            if ('filename' in chunk.metadata):
                # Sanitize filename more thoroughly
                filename = str(chunk.metadata['filename'])

                # Remove problematic characters that could break citation parsing: Keep only alphanumeric, spaces, dots, hyphens, underscores
                filename = re.sub(r'[^\w\s\.\-]', '_', filename)

                # Limit length to prevent overflow
                if (len(filename) > 50):
                    filename = filename[:47] + "..."

                parts.append(f"Source: {filename}")

        if chunk.page_number:
            parts.append(f"Page {chunk.page_number}")

        if chunk.section_title:
            # Sanitize section title similarly
            section = str(chunk.section_title)
            section = re.sub(r'[^\w\s\.\-]', '_', section)

            if (len(section) > 40):
                section = section[:37] + "..."

            parts.append(f"Section: {section}")

        # Add relevance score if available
        if (source.score > 0):
            parts.append(f"(Relevance: {source.score:.2f})")

        return " ".join(parts)

    def _format_verbose_citation(self, source: ChunkWithScore, number: int) -> str:
        """
        Format citation in verbose style with sanitization

        Arguments:
        ----------
        source { ChunkWithScore } : Source to format
        number { int }            : Citation number

        Returns:
        --------
        { str } : Formatted citation
        """
        chunk = source.chunk

        parts = [f"Citation {number}:"]

        # Document information with sanitization
        if (hasattr(chunk, 'metadata')):
            meta = chunk.metadata

            if ('filename' in meta):
                filename = str(meta['filename'])
                filename = re.sub(r'[^\w\s\.\-]', '_', filename)

                if (len(filename) > 50):
                    filename = filename[:47] + "..."

                parts.append(f"Document: {filename}")

            if ('title' in meta):
                title = str(meta['title'])
                title = re.sub(r'[^\w\s\.\-]', '_', title)

                if (len(title) > 60):
                    title = title[:57] + "..."

                parts.append(f"Title: {title}")

            if ('author' in meta):
                author = str(meta['author'])
                author = re.sub(r'[^\w\s\.\-]', '_', author)

                if (len(author) > 40):
                    author = author[:37] + "..."

                parts.append(f"Author: {author}")

        # Location information
        location_parts = list()

        if chunk.page_number:
            location_parts.append(f"page {chunk.page_number}")

        if chunk.section_title:
            section = str(chunk.section_title)
            section = re.sub(r'[^\w\s\.\-]', '_', section)

            if (len(section) > 40):
                section = section[:37] + "..."

            location_parts.append(f"section '{section}'")

        if location_parts:
            parts.append("(" + ", ".join(location_parts) + ")")

        # Relevance information
        if (source.score > 0):
            parts.append(f"[Relevance score: {source.score:.3f}]")

        return " ".join(parts)

    def generate_citation_map(self, sources: List[ChunkWithScore]) -> Dict[int, Dict]:
        """
        Generate mapping from citation numbers to source details

        Arguments:
        ----------
        sources { list } : List of sources

        Returns:
        --------
        { dict } : Dictionary mapping citation numbers to source details
        """
        citation_map = dict()

        for i, source in enumerate(sources, 1):
            chunk = source.chunk

            citation_map[i] = {'chunk_id'      : chunk.chunk_id,
                               'document_id'   : chunk.document_id,
                               'score'         : source.score,
                               'text_preview'  : chunk.text[:200] + "..." if (len(chunk.text) > 200) else chunk.text,
                               'metadata'      : getattr(chunk, 'metadata', {}),
                               'page_number'   : chunk.page_number,
                               'section_title' : chunk.section_title,
                               }

        return citation_map

    def replace_citation_markers(self, text: str, citation_map: Dict[int, str]) -> str:
        """
        Replace citation markers with formatted citations

        Arguments:
        ----------
        text { str }          : Text containing citation markers
        citation_map { dict } : Mapping of citation numbers to formatted strings

        Returns:
        --------
        { str } : Text with replaced citations
        """
        def replacement(match):
            try:
                citation_num = int(match.group(1))

                # Get replacement text and sanitize it
                replacement_text = citation_map.get(citation_num, match.group(0))

                return str(replacement_text)

            except (ValueError, IndexError):
                # Return original match if parsing fails
                return match.group(0)

        try:
            return self.citation_pattern.sub(replacement, text)

        except Exception as e:
            self.logger.error(f"Citation replacement failed: {repr(e)}")
            # Return original text on error
            return text

    def get_citation_statistics(self, text: str, sources: List[ChunkWithScore]) -> Dict:
        """
        Get statistics about citations in text

        Arguments:
        ----------
        text { str }     : Text containing citations
        sources { list } : List of sources

        Returns:
        --------
        { dict } : Citation statistics
        """
        citation_numbers = self.extract_citations(text = text)

        if not citation_numbers:
            return {"total_citations": 0}

        # Calculate citation distribution
        citation_counts = defaultdict(int)

        for num in citation_numbers:
            if 1 <= num <= len(sources):
                source = sources[num - 1]
                doc_id = source.chunk.document_id
                citation_counts[doc_id] += 1

        return {"total_citations"       : len(citation_numbers),
                "unique_citations"      : len(set(citation_numbers)),
                "citation_distribution" : dict(citation_counts),
                "citations_per_source"  : {i: citation_numbers.count(i) for i in set(citation_numbers)},
                }

    def ensure_citation_consistency(self, text: str, sources: List[ChunkWithScore]) -> str:
        """
        Ensure citation numbers are consistent and sequential

        Arguments:
        ----------
        text { str }     : Text containing citations
        sources { list } : List of sources

        Returns:
        --------
        { str } : Text with consistent citations
        """
        is_valid, invalid_citations = self.validate_citations(text, sources)

        if not is_valid:
            self.logger.warning("Invalid citations found, attempting to fix consistency")

            # Extract current citations and create mapping
            current_citations = self.extract_citations(text = text)

            if not current_citations:
                return text

            # Create mapping from old to new citation numbers
            citation_mapping = dict()

            for i, old_num in enumerate(sorted(set(current_citations)), 1):
                if (old_num <= len(sources)):
                    citation_mapping[old_num] = i

            # Replace citations in text
            def consistent_replacement(match):
                old_num = int(match.group(1))
                new_num = citation_mapping.get(old_num, old_num)

                return f"[{new_num}]"

            fixed_text = self.citation_pattern.sub(consistent_replacement, text)

            self.logger.info(f"Fixed citation consistency: {current_citations} -> {list(citation_mapping.values())}")

            return fixed_text

        return text


# Global citation tracker instance
_citation_tracker = None


def get_citation_tracker() -> CitationTracker:
    """
    Get global citation tracker instance (singleton)

    Returns:
    --------
    { CitationTracker } : CitationTracker instance
    """
    global _citation_tracker

    if _citation_tracker is None:
        _citation_tracker = CitationTracker()

    return _citation_tracker


@handle_errors(error_type = CitationError, log_error = True, reraise = False)
def extract_and_validate_citations(text: str, sources: List[ChunkWithScore]) -> Tuple[List[int], bool]:
    """
    Convenience function for citation extraction and validation

    Arguments:
    ----------
    text { str }     : Text containing citations
    sources { list } : List of sources

    Returns:
    --------
    { Tuple[List[int], bool] } : (citation_numbers, is_valid)
    """
    tracker = get_citation_tracker()
    citations = tracker.extract_citations(text = text)
    is_valid, _ = tracker.validate_citations(text = text,
                                             sources = sources,
                                             )

    return citations, is_valid
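A brief, hedged sketch of the citation round-trip this module supports; the answer string and the replacement mapping are illustrative, and only the text-level methods (no ChunkWithScore construction) are shown.

# Hedged usage sketch for CitationTracker (answer text and mappings are illustrative).
from retrieval.citation_tracker import get_citation_tracker

tracker = get_citation_tracker()
answer = "Transformers use self-attention [1] and positional encodings [2] [1]."

# Duplicates removed and sorted -> [1, 2]
print(tracker.extract_citations(answer))

# Swap numeric markers for fuller references
pretty = tracker.replace_citation_markers(answer, {1: "[1: attention_notes.pdf, p. 3]",
                                                   2: "[2: encoding_summary.docx]"})
print(pretty)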