badal-12 committed · Commit 4388126 (verified) · Parent: d0c859a

Create app.py

Files changed (1): app.py (+203, -0)
app.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import json
+ import uuid
+ import numpy as np
+ from datetime import datetime
+ from flask import Flask, request, jsonify, send_from_directory
+ from flask_cors import CORS
+ from werkzeug.utils import secure_filename
+ import google.generativeai as genai
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ from transformers import pipeline
+ import faiss
+ import markdown
+
+ # Configuration
+ GEMINI_API_KEY = (
+     "API_KEY"  # Replace with your actual API key
+ )
+ genai.configure(api_key=GEMINI_API_KEY)
+
+ # Initialize Flask app
+ app = Flask(__name__, static_folder="../frontend", static_url_path="")
+ CORS(app)
+
+ # RAG Model Initialization
+ print("🚀 Initializing RAG System...")
+
+ # Load medical guidelines dataset
+ print("📂 Loading dataset...")
+ dataset = load_dataset("epfl-llm/guidelines", split="train")
+ TITLE_COL = "title"
+ CONTENT_COL = "clean_text"
+
+ # Initialize models
+ print("🤖 Loading AI models...")
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
+ qa_pipeline = pipeline(
+     "question-answering", model="distilbert-base-cased-distilled-squad"
+ )
+
+ # Build FAISS index
+ print("🔍 Building FAISS index...")
+
+
+ def embed_text(batch):
+     combined_texts = [
+         f"{title} {content[:200]}"
+         for title, content in zip(batch[TITLE_COL], batch[CONTENT_COL])
+     ]
+     return {"embeddings": embedder.encode(combined_texts, show_progress_bar=False)}
+
+
+ dataset = dataset.map(embed_text, batched=True, batch_size=32)
+ dataset.add_faiss_index(column="embeddings")
+
+
+ # Processing Functions
+ def format_response(text):
+     """Convert Markdown text to HTML for proper frontend display."""
+     return markdown.markdown(text)
+
+
+ def summarize_report(report):
+     """Generate a clinical summary using QA and Gemini model."""
+     questions = [
+         "Patient's age?",
+         "Patient's gender?",
+         "Current symptoms?",
+         "Medical history?",
+     ]
+
+     answers = []
+     for q in questions:
+         result = qa_pipeline(question=q, context=report)
+         answers.append(result["answer"] if result["score"] > 0.1 else "Not specified")
+
+     model = genai.GenerativeModel("gemini-1.5-flash")
+     prompt = f"""Create clinical summary from:
+ - Age: {answers[0]}
+ - Gender: {answers[1]}
+ - Symptoms: {answers[2]}
+ - History: {answers[3]}
+
+ Format: "[Age] [Gender] with [History], presenting with [Symptoms]"
+ Add relevant medical context."""
+     return format_response(model.generate_content(prompt).text.strip())
+
+
+ def rag_retrieval(query, k=3):
+     """Retrieve relevant guidelines using FAISS."""
+     query_embedding = embedder.encode([query])
+     scores, examples = dataset.get_nearest_examples("embeddings", query_embedding, k=k)
+     return [
+         {
+             "title": title,
+             "content": content[:1000],
+             "source": examples.get("source", ["N/A"] * len(examples[TITLE_COL]))[
+                 i
+             ],  # Default to "N/A" if no source field
+             "score": float(score),
+         }
+         for i, (title, content, score) in enumerate(
+             zip(
+                 examples[TITLE_COL],
+                 examples[CONTENT_COL],
+                 scores,
+             )
+         )
+     ]
+
+
+ def generate_recommendations(report):
+     """Generate treatment recommendations with RAG context."""
+     guidelines = rag_retrieval(report)
+     context = "Relevant Clinical Guidelines:\n" + "\n".join(
+         [f"• {g['title']}: {g['content']} [Source: {g['source']}]" for g in guidelines]
+     )
+
+     model = genai.GenerativeModel("gemini-1.5-flash")
+     prompt = f"""Generate treatment recommendations using these guidelines:
+ {context}
+
+ Patient Presentation:
+ {report}
+
+ Format with:
+ - Bold section headers
+ - Clear bullet points
+ - Evidence markers [Guideline #]
+ - Risk-benefit analysis
+ - Include references to the sources provided where applicable
+ """
+     recommendations = model.generate_content(prompt).text.strip()
+     # Extract references (non-"N/A" sources)
+     references = [g["source"] for g in guidelines if g["source"] != "N/A"]
+     return format_response(recommendations), references
+
+
+ def generate_risk_assessment(summary):
+     """Generate risk assessment using the summary."""
+     model = genai.GenerativeModel("gemini-1.5-flash")
+     prompt = f"""Analyze clinical risk:
+ {summary}
+
+ Output format:
+ Risk Score: 0-100
+ Alert Level: 🔴 High/🟡 Medium/🟢 Low
+ Key Risk Factors: bullet points
+ Recommended Actions: bullet points"""
+     return format_response(model.generate_content(prompt).text.strip())
+
+
+ # Flask Endpoints
+ @app.route("/upload-txt", methods=["POST"])
+ def handle_upload():
+     """Handle text file upload and return processed data."""
+     if "file" not in request.files:
+         return jsonify({"error": "No file provided"}), 400
+
+     file = request.files["file"]
+     if not file or not file.filename.endswith(".txt"):
+         return jsonify({"error": "Invalid file, must be a .txt file"}), 400
+
+     try:
+         content = file.read().decode("utf-8")
+         if not content.strip():
+             return jsonify({"error": "File is empty"}), 400
+
+         summary = summarize_report(content)
+         recommendations, references = generate_recommendations(content)
+         risk_assessment = generate_risk_assessment(summary)
+
+         return jsonify(
+             {
+                 "session_id": str(uuid.uuid4()),
+                 "timestamp": datetime.now().isoformat(),
+                 "summary": summary,
+                 "recommendations": recommendations,
+                 "risk_assessment": risk_assessment,
+                 "references": references,  # Added references field
+             }
+         )
+     except Exception as e:
+         return jsonify({"error": f"Processing failed: {str(e)}"}), 500
+
+
+ # Serve static files
+ @app.route("/")
+ def serve_index():
+     """Serve the index.html file."""
+     return send_from_directory(app.static_folder, "index.html")
+
+
+ @app.route("/<path:path>")
+ def serve_static(path):
+     """Serve other static files from the frontend directory."""
+     return send_from_directory(app.static_folder, path)
+
+
+ # Run the app
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=5000, debug=True)
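
The configuration block above hardcodes the Gemini key as a placeholder string. A minimal sketch of reading it from the environment instead is shown below; it assumes the key has been exported under the name GEMINI_API_KEY, which is this note's choice and not something defined in the commit.

import os
import google.generativeai as genai

# Assumed environment variable name; export it before launching the app.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY is not set")
genai.configure(api_key=GEMINI_API_KEY)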
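
A quick way to exercise the /upload-txt endpoint once the server is up: the sketch below posts a plain-text report and prints the returned fields. It assumes the app is reachable at http://localhost:5000 and that a local file named report.txt exists; both the host and the file name are assumptions, not part of the commit.

import requests  # third-party HTTP client, not used by app.py itself

# Upload a .txt clinical report; the endpoint expects the multipart field to be named "file".
with open("report.txt", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/upload-txt",
        files={"file": ("report.txt", f, "text/plain")},
    )
resp.raise_for_status()

data = resp.json()
# Summary, recommendations, and risk assessment come back as HTML fragments.
print(data["session_id"], data["timestamp"])
print(data["summary"])
print(data["recommendations"])
print(data["risk_assessment"])
print(data["references"])

Note that startup embeds the full epfl-llm/guidelines train split before the Flask server begins accepting requests, so the first launch can take a while and requires the dataset and model downloads to complete.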