import streamlit as st
from utils import convert_to_base64, convert_to_html
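# convert_to_base64 / convert_to_html live in this Space's utils.py (not shown here);
# presumably they base64-encode an image (file path or uploaded file) and wrap a
# base64 string in an <img> tag for inline display.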
import requests
import boto3
import sagemaker
import os
import json

region = os.getenv("region")
sm_endpoint_name = os.getenv("sm_endpoint_name")
access_key = os.getenv("access_key")
secret_key = os.getenv("secret_key")
hf_token = os.getenv("hf_read_access")
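
# Build a boto3 session from the credentials above; smr is the SageMaker runtime
# client used to invoke the endpoint.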
session = boto3.Session(
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
    region_name=region
)
sess = sagemaker.Session(boto_session=session)
smr = session.client("sagemaker-runtime")

headers = {'Content-Type': 'application/json'}
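# 'requests' and these JSON headers appear to be left over from calling the model over
# plain HTTP; the SageMaker runtime client above is what the app actually uses.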

st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
st.text("LLaVA (Large Language and Vision Assistant) is an open-source large multimodal model. This demo runs LLaVA-1.6 on AWS Inferentia2.")


def upload_image():
    image_list = ["./images/view.jpg",
                  "./images/cat.jpg",
                  "./images/olympic.jpg",
                  "./images/usa.jpg",
                  "./images/box.jpg"]
    name_list = ["view(from internet)",
                 "cat(from internet)",
                 "paris 2024(from internet)",
                 "statue of liberty(from internet)",
                 "box(from my camera)"]
    images_all = dict(zip(name_list, image_list))

    user_option = st.selectbox("Select a preset image", ["–Select–"] + name_list)
    print(user_option)
    if user_option != "–Select–":
        image_names = [images_all[user_option]]
    else:
        image_names = []

    st.text("OR")
    images = st.file_uploader("Upload an image to chat about", type=["png", "jpg", "jpeg"], accept_multiple_files=True)

    # allow at most one uploaded image
    if len(images) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    if images or image_names:
        if images:
            image_names = []
        # convert images to base64
        images_b64 = []
        for image in images + image_names:
            image_b64 = convert_to_base64(image)
            images_b64.append(image_b64)
        # display images in multiple columns
        cols = st.columns(len(images_b64))  # only process first image
        for i, col in enumerate(cols):
            col.markdown(f"**Image {i+1}**")
            col.markdown(convert_to_html(images_b64[i]), unsafe_allow_html=True)
            break  # only process first image
        st.markdown("---")
        return images_b64[0]  # only process first image
    st.stop()


def ask_llm(prompt, byte_image):
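    # Build the request for the SageMaker endpoint: the prompt text, the
    # base64-encoded image, and sampling parameters for generation.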
    payload = {
        "prompt": prompt,
        "image": byte_image,
        "parameters": {
            "top_k": 100,
            "top_p": 0.1,
            "temperature": 0.2,
        }
    }
    #response = requests.post(url, json=payload, headers=headers)
    response_model = smr.invoke_endpoint(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    #return response.text
    return response_model['Body'].read().decode('utf8')


def app():
    st.markdown("---")
    c1, c2 = st.columns(2)
    with c2:
        image_b64 = upload_image()
    with c1:
        question = st.chat_input("Ask a question about this image")
    if not question:
        st.stop()
    with c1:
        with st.chat_message("question"):
            st.markdown(question, unsafe_allow_html=True)
        with st.spinner("Thinking..."):
            res = ask_llm(question, image_b64)
        with st.chat_message("response"):
            st.write(res)


if __name__ == "__main__":
    app()