GyroScope / use_with_UI.py
LH-Tech-AI's picture
Update use_with_UI.py
97805b4 verified
Raw
History Blame Contribute Delete
4.06 kB
import streamlit as st
import torch
import requests
from io import BytesIO
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification
# --- 1. UI Configuration ---
# Browser-tab title/icon for the app. layout="centered" keeps the content at a
# readable width instead of stretching it across very wide screens.
st.set_page_config(
    page_title="GyroScope Rotation Corrector",
    page_icon="🔄",
    layout="centered",
)
# --- 2. Model Caching ---
# @st.cache_resource keeps one model instance alive across Streamlit reruns,
# so every widget interaction does not reload the weights.
@st.cache_resource
def load_model():
    """Load the GyroScope ResNet classifier onto the best available device.

    Returns:
        A ``(model, device)`` tuple: the model switched to eval mode and the
        ``torch.device`` it was moved to (CUDA if available, otherwise CPU).
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNetForImageClassification.from_pretrained("LH-Tech-AI/GyroScope")
    net.eval()  # inference mode for layers such as batch-norm
    net.to(target)
    return net, target


model, device = load_model()
# --- 3. Preprocessing & Logic ---
# Standard ImageNet per-channel mean / std used to normalize model inputs.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize(256),       # scale so the shorter side is 256 px
        transforms.CenterCrop(224),   # take the central 224x224 patch
        transforms.ToTensor(),        # PIL image -> float tensor in [0, 1]
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Output class index i of the model corresponds to ANGLES[i] degrees.
ANGLES = [0, 90, 180, 270]
def predict_and_correct(img):
    """Classify an image's rotation and return an upright copy.

    Args:
        img: A PIL image in any mode; it is converted to RGB before use.

    Returns:
        A 4-tuple ``(corrected_img, detected, correction, prob_dict)``:
        ``corrected_img`` is the RGB image rotated ``correction`` degrees
        counter-clockwise with the canvas expanded; ``detected`` is the
        predicted angle taken from ``ANGLES``; ``correction`` equals
        ``(360 - detected) % 360``; ``prob_dict`` maps labels such as
        ``"90°"`` to the softmax probability formatted with 4 decimals.
    """
    rgb = img.convert("RGB")
    batch = preprocess(rgb).unsqueeze(0).to(device)  # add a batch dimension of 1

    with torch.no_grad():
        scores = model(pixel_values=batch).logits

    # Only one image in the batch, so keep row 0 of the class probabilities.
    class_probs = torch.softmax(scores, dim=1)[0]
    detected = ANGLES[class_probs.argmax().item()]
    correction = (360 - detected) % 360

    # PIL's rotate() turns counter-clockwise; expand=True avoids cropping corners.
    corrected_img = rgb.rotate(correction, expand=True)

    # One host transfer for all probabilities, then format each for display.
    prob_dict = {
        f"{angle}°": f"{p:.4f}"
        for angle, p in zip(ANGLES, class_probs.tolist())
    }
    return corrected_img, detected, correction, prob_dict
# --- 4. Frontend Layout ---
def _open_image(fp):
    """Open an image with PIL and decode its pixel data immediately.

    ``Image.open`` is lazy: it reads the header, and pixel data is decoded
    on first use. Calling ``load()`` here makes corrupt or truncated files
    fail inside the caller's ``try`` block. Otherwise they would fail later,
    during preview or inference, where the error is not handled.

    Args:
        fp: Anything ``PIL.Image.open`` accepts (here: a Streamlit upload
            or a ``BytesIO`` wrapping downloaded bytes).

    Returns:
        The fully loaded PIL image.
    """
    image = Image.open(fp)
    image.load()
    return image


st.title("🔄 Auto Rotation Corrector")
st.markdown("Upload an image or provide a URL to automatically fix its orientation.")
st.divider()

# Input Selection
input_method = st.radio("Select Image Source:", ["Upload a File", "Enter Image URL"], horizontal=True)
img = None  # stays None until an image has been loaded successfully

# Input Handling
# The broad `except Exception` clauses below are deliberate: this is the UI
# boundary, and any load failure should be shown to the user rather than
# aborting the whole script run.
if input_method == "Upload a File":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # A file with an image extension can still be corrupt or truncated.
        try:
            img = _open_image(uploaded_file)
        except Exception as e:
            st.error(f"Could not read the uploaded image. Error: {e}")
else:
    url = st.text_input("Enter Image URL:", placeholder="https://example.com/image.jpg").strip()
    if url:
        # NOTE(review): the URL is user-supplied and fetched server-side with
        # no size limit; restrict hosts / cap the download size if this app
        # is exposed publicly.
        try:
            response = requests.get(url, timeout=5)
            # Turn 4xx/5xx responses into a clear HTTP error instead of handing
            # an HTML error page to PIL ("cannot identify image file").
            response.raise_for_status()
            img = _open_image(BytesIO(response.content))
        except Exception as e:
            st.error(f"Could not load image from URL. Error: {e}")

# Preview & Processing Section
if img is not None:
    st.divider()
    manual_angle = st.slider("Manual Pre-rotation", min_value=0, max_value=360, value=0, step=90)
    if manual_angle != 0:
        # PIL rotates counter-clockwise; expand=True prevents cropping.
        img = img.rotate(manual_angle, expand=True)

    # Use columns to keep the UI compact and side-by-side
    col_left, col_right = st.columns(2)

    with col_left:
        st.subheader("Input Preview")
        st.image(img, use_container_width=True)
        # The primary action button; True only on the rerun triggered by a click.
        process_btn = st.button("✨ Correct Rotation", type="primary", use_container_width=True)

    with col_right:
        st.subheader("Output Preview")
        if process_btn:
            with st.spinner("Analyzing..."):
                corrected_img, detected, correction, prob_dict = predict_and_correct(img)
            # Show result
            st.image(corrected_img, use_container_width=True)
            # Show stats
            st.success(f"✅ Detected: **{detected}°** | Correction: **{correction}°**")
            # Collapsed expander keeps the UI clean while exposing the details.
            with st.expander("📊 View Probability Details"):
                st.json(prob_dict)
        else:
            # Placeholder shown until the button is clicked
            st.info("Waiting for processing... Click the button on the left to correct the rotation.")