Spaces:
Runtime error
Runtime error
| """Streamlit web app for radiological condition prediction from chest X-ray images""" | |
| import os | |
| os.environ['CUDA_VISIBLE_DEVICES'] = '-1' | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| import keras.backend.tensorflow_backend as tb | |
| tb._SYMBOLIC_SCOPE.value = True | |
| import numpy as np | |
| import time | |
| import cv2 | |
| import streamlit as st | |
| import pandas as pd | |
| from tensorflow.keras.models import load_model | |
| from keras.utils.data_utils import get_file | |
| from tensorflow import keras | |
| from PIL import Image | |
| from explain import get_gradcam | |
# Turn off the Streamlit "deprecation.showfileUploaderEncoding" notice.
# NOTE(review): presumably this is older Streamlit's warning about
# st.file_uploader's automatic encoding handling; newer Streamlit releases may
# not recognize this option — confirm against the pinned Streamlit version.
st.set_option("deprecation.showfileUploaderEncoding", False)
# Side length, in pixels, of the square image fed to the network
# (consistent with the "res224" checkpoint name).
IMAGE_SIZE = 224

# Display label for each model output. The list is paired with the model's
# output vector purely by position, so its order must not change.
#   * "Normal" stands in for CheXpert's "No Finding".
#   * The embedded newline splits the long cardiomediastinum label onto two
#     lines where the labels are drawn on the bar chart.
classes = [
    "Normal",
    "Enlarged- \nCardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
@st.cache(allow_output_mutation=True)
def cached_model():
    """Return the CheXpert DenseNet-121 classifier, loading it at most once.

    Streamlit re-executes this whole script on every widget interaction
    (uploading a file, clicking "Get Heatmap", ...). Without ``st.cache``
    the Keras model would be deserialized from the ``.h5`` file on each of
    those reruns. ``allow_output_mutation=True`` tells Streamlit not to hash
    the returned model on every cache hit.

    Returns:
        The Keras model, loaded with ``compile=False`` because it is only
        used for inference (``predict`` and Grad-CAM) in this app.
    """
    # Release asset that holds the trained weights.
    URL = "https://github.com/hasibzunair/cxr-predictor/releases/latest/download/CheXpert_DenseNet121_res224.h5"
    # get_file() reuses a copy already present in the local Keras cache
    # directory and only downloads when the file is missing.
    weights_path = get_file(
        "CheXpert_DenseNet121_res224.h5",
        URL)
    # compile=False skips restoring the training configuration
    # (optimizer / loss), which inference does not need.
    model = load_model(weights_path, compile=False)
    return model
def preprocess_image(uploaded_file):
    """Convert an uploaded chest X-ray into the model's input format.

    The steps are:
      1. Decode the image with PIL.
      2. Reduce it to one gray channel (the first channel of multi-channel
         files).
      3. Scale it to [0, 1].
      4. Replicate it into three identical channels.
      5. Resize it to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    Args:
        uploaded_file: Path or binary file-like object accepted by
            ``PIL.Image.open`` (Streamlit's uploaded file in this app).

    Returns:
        ``float32`` array of shape ``(IMAGE_SIZE, IMAGE_SIZE, 3)``.
    """
    img = Image.open(uploaded_file)
    # Palette ("P") and bilevel ("1") images would otherwise become palette
    # indices / booleans instead of gray levels; decode them to 8-bit gray.
    if img.mode in ("P", "1"):
        img = img.convert("L")
    img_array = np.array(img)

    # Keep a single 2-D channel; RGB(A) / LA inputs use their first channel.
    if img_array.ndim > 2:
        img_array = img_array[:, :, 0]

    # Normalize to [0, 1].
    # - 8-bit data is divided by 255.
    # - Higher bit-depth images (e.g. 16-bit PNG radiographs) would exceed 1
    #   that way, so they are divided by their own maximum instead.
    img_array = img_array.astype("float32")
    peak = float(img_array.max())
    img_array /= 255.0 if peak <= 255.0 else peak

    # Replicate the gray channel: the model is fed 3-channel input.
    img_array = np.stack((img_array, img_array, img_array), axis=-1)

    # Resize to the network's input resolution.
    img_array = cv2.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
    return img_array
def make_prediction(file):
    """Score a single uploaded image with the global ``model``.

    Args:
        file: Path or file-like object understood by ``preprocess_image``.

    Returns:
        The output of ``model.predict`` for a batch containing just this
        image. Row ``[0]`` holds its scores; they are presumably aligned
        with ``classes`` — confirm against the checkpoint.
    """
    # Prepend a batch axis of length 1 before handing the image to Keras.
    batch = preprocess_image(file)[np.newaxis, ...]
    return model.predict(batch)


# Build the classifier when the script module executes; make_prediction()
# and the heatmap code look it up as the module-level name ``model``.
model = cached_model()
if __name__ == '__main__':
    # ---- Header and usage instructions ---------------------------------
    logo = np.array(Image.open("media/logo_rs.png"))
    st.image(logo, use_column_width=True)
    # NOTE(review): the contact address below reads "[email protected]",
    # which looks like an e-mail-obfuscation artifact from copying the page;
    # restore the real address.
    st.write("""
    # AI Assisted Radiology Tool

    :red_circle: NOT FOR MEDICAL USE!

    This is a prototype application which demonstrates how artificial intelligence based systems can identify
    medical conditions from images. Using this tool, medical professionals can process an image to
    confirm or aid in their diagnosis which may serve as a second opinion.

    The tool predicts the presence of 14 different radiological conditions from a given chest X-ray image. This is built using data from a large public
    [database](https://stanfordmlgroup.github.io/competitions/chexpert/).

    Questions? Email me at `[email protected]`.

    If you continue, you assume all liability when using the system.

    Please upload a posterior to anterior (PA) view chest X-ray image file (PNG, JPG, JPEG)
    to predict the presence of the radiological conditions. Here's an example.
    """)
    example_image = np.array(Image.open("media/example.jpg"))
    st.image(example_image, caption="An example input.", width=100)

    # Accept only the formats promised in the instructions, so other files
    # are rejected by the widget instead of crashing PIL further down.
    uploaded_file = st.file_uploader(
        "Upload file either by clicking 'Browse files' or drag and drop the image.",
        type=["png", "jpg", "jpeg"],
    )

    if uploaded_file is not None:
        # ---- Show the input and run the classifier ---------------------
        original_image = np.array(Image.open(uploaded_file))
        st.image(original_image, caption="Input chest X-ray image", use_column_width=True)
        st.write("")
        st.write("Analyzing the input image. Please wait...")

        # make_prediction() does its own preprocessing, so time only that call.
        start_time = time.perf_counter()
        predictions = make_prediction(uploaded_file)
        st.write("Took {} seconds to run.".format(
            round(time.perf_counter() - start_time, 3)))

        # ---- Bar chart of per-class confidence -------------------------
        # Convert probability scores to percent for easy interpretation.
        predictions_percent = [x * 100 for x in predictions[0]]
        df = pd.DataFrame({'classes': classes, 'predictions': predictions_percent})
        df = df.sort_values('predictions')
        # Ascending sort, so the last row is the most confident class.
        top_predicted_class = df['classes'].iloc[-1]

        fig, ax = plt.subplots()
        ax.grid(False)
        # Scores are percentages, so the confidence axis spans 0-100.
        ax.set_xlim(0, 100)
        ax.set_xticks(list(range(0, 101, 10)))
        ax.tick_params(axis='y', labelcolor='r', which='major', labelsize=8)
        ax.barh(df['classes'], df['predictions'], color='green')
        ax.set_xlabel('Confidence (in percent)', fontsize=15)
        ax.set_ylabel('Radiological conditions', fontsize=15)
        st.pyplot(fig)
        # Streamlit reruns the script on every interaction; release the
        # figure so pyplot does not accumulate one open figure per rerun.
        plt.close(fig)

        # ---- Optional Grad-CAM heatmap ---------------------------------
        st.write("""
        For more information about the top most predicted radiological finding,
        you can click on 'Get Heatmap' button which will highlight the most influential features
        in the image affecting the prediction.
        """)
        if st.button('Get Heatmap'):
            st.write('Generating heatmap for regions predictive of {}. Indicated by red.'.format(top_predicted_class))
            # NOTE(review): "conv5_block16_2_conv" is presumably the last
            # convolution of DenseNet-121's final dense block, used as the
            # Grad-CAM target layer — confirm against the model summary.
            heatmap = get_gradcam(uploaded_file, model, "conv5_block16_2_conv", predictions)
            # Resize the heatmap to the input's size; cv2 takes (width, height).
            orig_heatmap = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
            # Reduce the input to its first channel, then stack that into a
            # 3-channel gray image so it can sit beside the heatmap.
            if len(original_image.shape) > 2:
                original_image = original_image[:, :, 0]
            original_image = np.stack((original_image, original_image, original_image), axis=-1)
            st.image(np.concatenate((original_image, orig_heatmap), axis=1), caption="Input image + Heatmap ", use_column_width=True)