|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
from g4f.client import Client |
|
|
import markdown2 |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import io |
|
|
|
|
|
client = Client() |
|
|
|
|
|
|
|
|
def truncate_history(history, max_tokens=2048):
    """Trim a chat history to fit a rough token budget.

    Walks the (user, assistant) pairs from newest to oldest, counting
    whitespace-separated words as a cheap token estimate, and keeps the
    most recent pairs whose combined count stays within *max_tokens*.
    Original chronological order is preserved in the result.

    Args:
        history: list of (user_message, assistant_message) tuples.
        max_tokens: approximate word budget for the kept pairs.

    Returns:
        A new list containing the newest pairs that fit the budget.
    """
    kept = []
    budget_used = 0

    for user_text, assistant_text in reversed(history):
        # NOTE: word count is only a rough proxy for model tokens.
        pair_cost = len(user_text.split()) + len(assistant_text.split())

        if budget_used + pair_cost > max_tokens:
            break

        budget_used += pair_cost
        kept.insert(0, (user_text, assistant_text))

    return kept
|
|
|
|
|
|
|
|
def format_output(text):
    """Render *text* as ChatGPT-style HTML via markdown2.

    Enables the extras needed for fenced code blocks, tables, task
    lists, strikethrough, spoilers, and Markdown nested inside HTML.
    """
    markdown_extras = [
        "fenced-code-blocks",
        "tables",
        "task_list",
        "strike",
        "spoiler",
        "markdown-in-html",
    ]
    return markdown2.markdown(text, extras=markdown_extras)
|
|
|
|
|
def convert_image_to_dataurl(image):
    """Serialize an uploaded PIL image into a PNG data URL string.

    Args:
        image: object with a PIL-style ``save(buffer, format=...)`` method.

    Returns:
        A ``data:image/png;base64,...`` string suitable for embedding.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
|
|
|
|
|
def respond(message, history, system_message, max_tokens, temperature, top_p, model_choice, web_search, image=None):
    """Send one chat turn to the model and return HTML-formatted output.

    Builds an OpenAI-style message list from the system prompt, the
    (truncated) history, and the current user turn, then calls the g4f
    client.

    Args:
        message: the user's text for this turn (may be empty when only
            an image is sent).
        history: list of (user, assistant) message tuples.
        system_message: system prompt placed first in the message list.
        max_tokens / temperature / top_p: generation parameters.
        model_choice: model name passed to the client.
        web_search: flag forwarded to the client.
        image: optional data-URL string of an attached image.

    Returns:
        HTML string of the assistant reply, or a Japanese error notice
        on failure.
    """
    # History truncation budget is fixed at 2048 regardless of the
    # max_tokens generation limit — presumably intentional; confirm.
    history = truncate_history(history, max_tokens=2048)

    messages = [{"role": "system", "content": system_message}]

    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    if image:
        # Decode the data URL, downscale to 512x512 to bound payload
        # size, and re-encode as PNG.
        img = Image.open(io.BytesIO(base64.b64decode(image.split(",")[1])))
        img = img.resize((512, 512))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_data = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # BUGFIX: previously the user's text was dropped whenever an
        # image was attached; include both text and image parts.
        content = [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_data}"}}]
        if message:
            content.insert(0, {"type": "text", "text": message})
        messages.append({"role": "user", "content": content})
    else:
        messages.append({"role": "user", "content": message})

    try:
        response = client.chat.completions.create(
            model=model_choice,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            web_search=web_search
        )
        return format_output(response.choices[0].message.content)
    except Exception as e:
        # Best-effort UI: report the failure instead of crashing the app.
        print(f"エラー発生: {e}")
        return "エラーが発生しました。再試行してください。"
|
|
|
|
|
|
|
|
def chat(message, history, system_message, max_tokens, temperature, top_p, model_choice, web_search, image):
    """Gradio callback for one chat turn.

    Ignores empty submissions, converts an attached image to a data
    URL, asks the model for a reply, and appends the exchange to the
    shared history.

    Returns:
        ("", updated_history, updated_history) — clears the input box
        and refreshes both the chatbot display and the state.
    """
    # Nothing to do when there is neither text nor an image.
    if not message.strip() and not image:
        return "", history, history

    image_data_url = convert_image_to_dataurl(image) if image else None

    print("メッセージ送信直後の履歴:", history)

    assistant_reply = respond(message, history, system_message, max_tokens, temperature, top_p, model_choice, web_search, image_data_url)
    history = history + [(message, assistant_reply)]

    print("AIの回答後の履歴:", history)

    return "", history, history
|
|
|
|
|
# ---- Gradio UI wiring (runs at import time) ----
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: generation settings and user input.
        with gr.Column():
            # System prompt shown to the user; default text is a
            # Japanese medical-assistant persona.
            system_message = gr.Textbox(
                value="あなたは日本語しか話せません。あなたは最新の医療支援AIです。薬の紹介、薬の提案、薬の作成など、さまざまなタスクに答えます。また、新しい薬を開発する際は、既存のものに頼らずに画期的なアイデアを出します。",
                label="システムメッセージ"
            )
            # Generation parameters forwarded to respond().
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="トークン制限")
            # NOTE(review): default temperature of 2 is unusually high
            # for chat models — confirm this is intentional.
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=2, step=0.1, label="Temperature (数値が大きいほど様々な回答をします。)")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling) (数値が低いと回答候補が上位のみになります。)")
            model_choice = gr.Radio(choices=["gpt-4o-mini", "gpt-4o", "o3-mini"], value="gpt-4o", label="モデル選択")
            web_search = gr.Checkbox(value=True, label="WEB検索")

            # User inputs: free text, optional image, submit button.
            chatbot_input = gr.Textbox(show_label=False, placeholder="ここにメッセージを入力してください...", lines=2)
            image_input = gr.Image(type="pil", label="画像をアップロード")
            submit_btn = gr.Button("送信")

        # Right column: conversation display.
        with gr.Column():
            chat_history_display = gr.Chatbot(label="チャット履歴")

    # Shared conversation state: list of (user, assistant) tuples.
    state = gr.State([])

    # Button click and textbox Enter both trigger the same chat() call
    # with identical inputs/outputs.
    submit_btn.click(
        chat,
        inputs=[chatbot_input, state, system_message, max_tokens, temperature, top_p, model_choice, web_search, image_input],
        outputs=[chatbot_input, chat_history_display, state]
    )

    chatbot_input.submit(
        chat,
        inputs=[chatbot_input, state, system_message, max_tokens, temperature, top_p, model_choice, web_search, image_input],
        outputs=[chatbot_input, chat_history_display, state]
    )

demo.launch()