-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit.py
121 lines (104 loc) · 3.97 KB
/
streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
from aws_controls import AwsControl
import mlflow
from mlflow import MlflowClient
import dagshub
def predict(image):
    """Classify *image* with the model cached in ``st.session_state["model"]``.

    The image is resized to 64x64, turned into a tensor and given a leading
    batch dimension of size 1. The index of the largest model output along
    dim 1 is then mapped to one of six class names.

    Args:
        image: a PIL image (``main`` passes the result of ``Image.open``).

    Returns:
        The predicted class name, e.g. ``"ChestCT"``.

    NOTE(review): assumes the model returns scores shaped (1, 6), one
    column per class below; confirm against the registered model.
    """
    labels = dict(
        enumerate(("AbdomenCT", "BreastMRI", "ChestCT", "CXR", "Hand", "HeadCT"))
    )
    preprocess = transforms.Compose(
        [transforms.Resize((64, 64)), transforms.ToTensor()]
    )
    batch = preprocess(image).unsqueeze(dim=0)  # model expects a batch axis

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = st.session_state["model"](batch)
        best = scores.max(dim=1).indices
    return labels[best.item()]
# Streamlit app


def _load_resources():
    """Create the AWS client and load the "prod" model into session state.

    Runs at most once per browser session: the "loaded" flag written at the
    end short-circuits every later rerun of the script.
    """
    if "loaded" in st.session_state:
        return
    if "aws" not in st.session_state:
        st.session_state["aws"] = AwsControl(
            st.secrets["aws_key"], st.secrets["aws_secret"]
        )
    dagshub.init("Medical-MNIST-MLOPs-CT-CD", "JatinSingh28", mlflow=True)
    client = MlflowClient()
    registered_model_name = "Production Model V1"
    # Resolve whichever registry version currently carries the "prod" alias.
    version = client.get_model_version_by_alias(
        registered_model_name, "prod"
    ).version
    model_uri = f"models:/{registered_model_name}/{version}"
    if "model" not in st.session_state:
        st.session_state["model"] = mlflow.pytorch.load_model(model_uri)
    st.session_state["model"].eval()  # inference mode; idempotent
    st.session_state["loaded"] = True


def _reset_feedback_state():
    """Forget the current prediction and its retraining-upload status."""
    st.session_state["prediction"] = None
    st.session_state["uploaded"] = False


def _image_fingerprint(uploaded_image):
    """Return a SHA-256 hex digest of the uploaded file's bytes."""
    import hashlib

    # getvalue() returns the whole buffer without moving the read position,
    # so Image.open() afterwards still sees the file from the start.
    return hashlib.sha256(uploaded_image.getvalue()).hexdigest()


def _save_for_retraining(image, path="uploaded_img.jpeg"):
    """Write *image* to *path* as JPEG, converting modes JPEG cannot store."""
    # PNG uploads are often RGBA or palette ("P") images; Pillow raises
    # OSError when asked to write those modes as JPEG.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")
    image.save(path)


def _render_feedback(image):
    """Ask whether the prediction was right and queue corrections on AWS."""
    options = ["YES", "NO"]
    st.write("Select NO for queuing image for retraining")
    selected_option = st.radio("Is the prediction correct?", options)
    if selected_option != "NO":
        return

    classes = [
        "AbdomenCT",
        "BreastMRI",
        "ChestCT",
        "CXR",
        "Hand",
        "HeadCT",
    ]  # must match the label order used by predict()
    corrected_class = st.selectbox("Select Correct Class", classes)
    if not st.button("Submit"):
        return
    if st.session_state["uploaded"]:
        st.warning("Image already uploaded for retraining")
        return

    # NOTE(review): "uploaded_img.jpeg" is one file shared by every session
    # on this server. Writing it only here, right before the upload, keeps
    # the window in which another user can overwrite it as small as possible.
    _save_for_retraining(image)
    st.session_state["aws"].upload_image("uploaded_img.jpeg", corrected_class)
    st.success(
        f"Image and corrected class {corrected_class} uploaded to AWS for retraining"
    )
    st.session_state["uploaded"] = True


def main():
    """Render the Medical MNIST page for one Streamlit script run.

    The user uploads an image and classifies it with the production model.
    If the prediction is marked wrong, the image is sent to AWS together
    with the corrected label for retraining.
    """
    if "prediction" not in st.session_state:
        _reset_feedback_state()
    _load_resources()

    st.title("Medical MNIST")
    # TODO: restore the "Download sample images" expander (st.download_button
    # serving sample_imgs.zip) once the sample archive is available.

    uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_image is None:
        _reset_feedback_state()
        st.session_state.pop("image_id", None)
        return

    # Session state survives uploads, so a different file must drop the
    # previous image's prediction and "already uploaded" flag.
    image_id = _image_fingerprint(uploaded_image)
    if st.session_state.get("image_id") != image_id:
        _reset_feedback_state()
        st.session_state["image_id"] = image_id

    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)
    if st.button("Classify"):
        st.session_state["prediction"] = predict(image)
        prediction = st.session_state["prediction"]
        st.success(f"Prediction Class: {prediction}")

    if st.session_state["prediction"] is not None:
        _render_feedback(image)
if __name__ == "__main__":
    # Configure the browser tab before any other Streamlit element is drawn
    # (older Streamlit releases require set_page_config to come first).
    st.set_page_config(page_title="Medical MNIST", page_icon="🧠")
    main()