diff --git a/README.md b/README.md
index 9d2aeb0b87046f32cebd4c63496a456a71efdff6..3cc9821b153e7d673f0c8cbc12de0a91a772fd88 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
---
-title: AioMedica2
-emoji: 👁
-colorFrom: red
-colorTo: indigo
+title: AioMedica
+emoji: 🏃
+colorFrom: purple
+colorTo: yellow
sdk: streamlit
sdk_version: 1.10.0
app_file: app.py
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..43f465f240d3a5412e33f0126eb2c23df0729e94
--- /dev/null
+++ b/app.py
@@ -0,0 +1,175 @@
+import streamlit as st
+import openslide
+import os
+import requests
+from streamlit_option_menu import option_menu
+import torch
+
+
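+# cached so the Predictor (and its weights) is built only once per session instead of on every Streamlit rerun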
+@st.cache(suppress_st_warning=True)
+def load_model():
+    from predict import Predictor
+    predictor = Predictor()
+    return predictor
+
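+# installs the torch-geometric extension wheels built against the torch 1.7.1 CPU runtime at startup;
+# st.cache ensures the pip commands only run once per session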
+@st.cache(suppress_st_warning=True)
+def load_dependencies():
+    os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
+    os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
+    os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
+
+
+def main():
+
+    # environment variables for the inference api
+    os.environ['DATA_DIR'] = 'queries'
+    os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
+    os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
+    os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
+    os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
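+    # working directories used by the inference API for query slides, extracted patches and GraphCAM plots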
+
+
+    # manually put the metadata in the metadata folder
+    os.environ['CLASS_METADATA'] = 'metadata/label_map.pkl'
+
+    # manually put the desired weights in the weights folder
+    os.environ['WEIGHTS_PATH'] = 'weights'
+    os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
+    os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
+
+
+ #st.set_page_config(page_title="",layout='wide')
+    predictor = load_model()
+
+ ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
+ CONTACT_TEXT = """
+ _Built by Christian Cancedda and LabLab lads with love_ ❤️
+ [![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
+ [![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
+ Star project repository:
+ [![GitHub stars](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
+ """
+ VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
+ DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
+
+
+
+ with st.sidebar:
+ choice = option_menu("LastMinute - Diagnosis",
+ ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
+ icons=['house', 'upload', 'activity', 'person lines fill'],
+ menu_icon="app-indicator", default_index=0,
+ styles={
+ # "container": {"padding": "5!important", "background-color": "#fafafa", },
+ "container": {"border-radius": ".0rem"},
+ # "icon": {"color": "orange", "font-size": "25px"},
+ # "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
+ # "--hover-color": "#eee"},
+ # "nav-link-selected": {"background-color": "#02ab21"},
+ }
+ )
+    st.sidebar.markdown(
+        """
+        [Project Repository](https://github.com/Chris1nexus/inference-graph-transformer)
+        """,
+        unsafe_allow_html=True,
+    )
+
+
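+    # the About page renders the project README fetched from GitHub, shrinking embedded images to 700px wide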
+ if choice == "About":
+ st.title(choice)
+ README = requests.get("https://raw.githubusercontent.com/Chris1nexus/inference-graph-transformer/master/README.md").text
+ README = str(README).replace('width="1200"','width="700"')
+ # st.title(choose)
+ st.markdown(README, unsafe_allow_html=True)
+
+ if choice == "Visualize WSI slide":
+ st.title(choice)
+ st.markdown(VISUALIZE_TEXT)
+
+ uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
+        if uploaded_file is not None:
+            # write the uploaded bytes to disk first, since OpenSlide opens slides by file path
+            with open(os.path.join(uploaded_file.name), "wb") as f:
+                f.write(uploaded_file.getbuffer())
+            ori = openslide.OpenSlide(uploaded_file.name)
+            width, height = ori.dimensions
+
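+            # WSIs are gigapixel images, so only a thumbnail downscaled by REDUCTION_FACTOR is rendered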
+            REDUCTION_FACTOR = 20
+            w, h = int(width/512), int(height/512)
+            w_r, h_r = int(width/REDUCTION_FACTOR), int(height/REDUCTION_FACTOR)
+            resized_img = ori.get_thumbnail((w_r, h_r))
+            resized_img = resized_img.resize((w_r, h_r))
+            ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
+            w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
+            st.image(resized_img, use_column_width='never')
+
+ if choice == "Cancer Detection":
+ state = dict()
+
+ st.title(choice)
+ st.markdown(DETECT_TEXT)
+ uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
+ st.markdown("Examples can be chosen at the [GDC Data repository](https://portal.gdc.cancer.gov/repository?facetTab=cases&filters=%7B%22op%22%3A%22and%22%2C%22content%22%3A%5B%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.primary_site%22%2C%22value%22%3A%5B%22bronchus%20and%20lung%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.program.name%22%2C%22value%22%3A%5B%22TCGA%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.project_id%22%2C%22value%22%3A%5B%22TCGA-LUAD%22%2C%22TCGA-LUSC%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22files.experimental_strategy%22%2C%22value%22%3A%5B%22Tissue%20Slide%22%5D%7D%7D%5D%7D)")
+ st.markdown("Alternatively, for simplicity few test cases are provided at the [drive link](https://drive.google.com/drive/folders/1u3SQa2dytZBHHh6eXTlMKY-pZGZ-pwkk?usp=share_link)")
+
+
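+        # the uploaded .svs is written to disk (the predictor opens slides by file path) before running
+        # the graph-transformer inference; the original slide and the class heatmap are then displayed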
+        if uploaded_file is not None:
+            with open(os.path.join(uploaded_file.name), "wb") as f:
+                f.write(uploaded_file.getbuffer())
+            with st.spinner(text="Computation is running"):
+                predicted_class, viz_dict = predictor.predict(uploaded_file.name)
+ st.info('Computation completed.')
+ st.header(f'Predicted to be: {predicted_class}')
+ st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
+ state['cur'] = predicted_class
+ mapper = {'ORI': predicted_class, predicted_class:'ORI'}
+ readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
+ #def fn():
+ # st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
+ # state['cur'] = mapper[state['cur']]
+ # return
+
+ #st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
+ #st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
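+            # viz_dict maps 'ORI' to the original slide image and the predicted class to its heatmap (BGR arrays)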
+            st.image([viz_dict['ORI'], viz_dict[state['cur']]],
+                     caption=['Original', f'{predicted_class} heatmap'],
+                     channels='BGR')
+
+
+    if choice == "Contact":
+        st.title(choice)
+        st.markdown(CONTACT_TEXT)
+
+
+if __name__ == '__main__':
+    load_dependencies()
+    main()
diff --git a/feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb b/feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..c91a0bd0b15954274975eaea110bd1110e834ff9
--- /dev/null
+++ b/feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb
@@ -0,0 +1,3503 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "11c4fe3c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+     "<output elided: OrderedDict of feature-extractor weight tensors -- module.features.0.weight plus the conv1/conv2 and downsample weights of the module.features.4.* and module.features.5.* blocks, all on cuda:0>"
+ " ...,\n",
+ " \n",
+ " [[-5.1502e-02, 1.1618e-02, 4.4762e-02],\n",
+ " [ 1.3293e-02, -2.8515e-02, 2.3477e-02],\n",
+ " [ 1.0963e-02, 6.3643e-03, 6.7758e-03]],\n",
+ " \n",
+ " [[ 4.9450e-02, -1.1083e-02, 4.1855e-02],\n",
+ " [ 2.1341e-02, 5.4715e-02, 5.2060e-02],\n",
+ " [ 1.9619e-02, 8.7005e-03, -1.0872e-02]],\n",
+ " \n",
+ " [[ 1.6183e-03, 1.7625e-02, -7.1724e-02],\n",
+ " [ 5.6331e-02, 1.0753e-01, 1.2052e-02],\n",
+ " [ 2.0485e-02, 1.2266e-02, 6.9537e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[ 1.4428e-02, 7.6514e-03, 2.5245e-02],\n",
+ " [-3.3150e-02, 4.3417e-02, 1.9428e-02],\n",
+ " [ 4.7255e-02, -3.4982e-02, -4.3335e-02]],\n",
+ " \n",
+ " [[-1.1168e-02, -4.7031e-02, -3.6381e-02],\n",
+ " [ 4.1839e-02, 3.1655e-02, -1.7358e-02],\n",
+ " [-1.9284e-02, -3.4662e-03, -4.0539e-02]],\n",
+ " \n",
+ " [[ 5.0484e-02, -7.0366e-02, 7.1903e-02],\n",
+ " [ 2.5097e-02, 1.6200e-02, 2.6682e-02],\n",
+ " [-2.6516e-02, -4.4044e-02, -1.1757e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 2.7789e-02, -2.4881e-02, -1.3087e-02],\n",
+ " [-2.1391e-02, 1.0493e-02, -5.0449e-02],\n",
+ " [-3.7873e-02, 2.2887e-02, 2.3353e-02]],\n",
+ " \n",
+ " [[-2.7992e-02, -9.8221e-03, -2.6284e-02],\n",
+ " [ 5.5802e-02, -1.9979e-02, -5.3550e-02],\n",
+ " [-8.4603e-02, -4.5323e-02, -5.7951e-02]],\n",
+ " \n",
+ " [[ 1.6301e-02, 6.3142e-02, 9.8293e-02],\n",
+ " [ 4.3805e-03, 3.1449e-02, 5.8051e-02],\n",
+ " [-1.9028e-02, -1.3127e-02, -1.3450e-02]]]], device='cuda:0')),\n",
+ " ('module.features.6.0.conv1.weight',\n",
+ " tensor([[[[ 0.0240, -0.0227, -0.0309],\n",
+ " [-0.0247, 0.0033, 0.0432],\n",
+ " [ 0.0297, -0.0698, 0.0268]],\n",
+ " \n",
+ " [[-0.0087, -0.0099, -0.0024],\n",
+ " [ 0.0354, -0.0023, 0.0237],\n",
+ " [-0.0012, 0.0157, 0.0071]],\n",
+ " \n",
+ " [[ 0.0058, -0.0504, -0.0078],\n",
+ " [-0.0192, 0.0444, -0.0287],\n",
+ " [ 0.0058, -0.0301, 0.0103]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0248, 0.0194, -0.0008],\n",
+ " [-0.0152, 0.0511, 0.0284],\n",
+ " [-0.0017, -0.0710, -0.0117]],\n",
+ " \n",
+ " [[ 0.0231, -0.0283, 0.0355],\n",
+ " [ 0.0168, -0.0039, -0.0019],\n",
+ " [ 0.0184, -0.0179, -0.0213]],\n",
+ " \n",
+ " [[ 0.0283, -0.0127, -0.0369],\n",
+ " [ 0.0271, -0.0027, -0.0016],\n",
+ " [ 0.0173, -0.0237, 0.0033]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0057, -0.0025, 0.0064],\n",
+ " [ 0.0161, -0.0475, 0.0262],\n",
+ " [ 0.0325, -0.0095, -0.0054]],\n",
+ " \n",
+ " [[-0.0346, -0.0113, 0.0190],\n",
+ " [-0.0364, 0.0418, 0.0298],\n",
+ " [ 0.0088, 0.0369, 0.0125]],\n",
+ " \n",
+ " [[ 0.0016, 0.0035, 0.0206],\n",
+ " [ 0.0233, 0.0029, -0.0143],\n",
+ " [-0.0005, -0.0228, 0.0294]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0209, -0.0253, 0.0257],\n",
+ " [-0.0404, -0.0363, 0.0393],\n",
+ " [-0.0203, 0.0136, 0.0041]],\n",
+ " \n",
+ " [[-0.0301, -0.0184, -0.0200],\n",
+ " [-0.0319, -0.0195, -0.0332],\n",
+ " [-0.0231, 0.0407, -0.0170]],\n",
+ " \n",
+ " [[-0.0198, 0.0130, -0.0241],\n",
+ " [-0.0224, 0.0324, -0.0427],\n",
+ " [ 0.0380, -0.0024, 0.0291]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0466, -0.0423, -0.0005],\n",
+ " [-0.0131, -0.0138, -0.0047],\n",
+ " [-0.0888, 0.0136, 0.0141]],\n",
+ " \n",
+ " [[-0.0107, -0.0105, -0.0020],\n",
+ " [ 0.0112, -0.0028, -0.0252],\n",
+ " [ 0.0050, 0.0019, 0.0032]],\n",
+ " \n",
+ " [[-0.0417, 0.0396, -0.0192],\n",
+ " [-0.0302, -0.0239, -0.0329],\n",
+ " [-0.0322, -0.0501, 0.0183]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0651, -0.0232, -0.0549],\n",
+ " [-0.0417, -0.0100, 0.0518],\n",
+ " [-0.0001, 0.0329, -0.0203]],\n",
+ " \n",
+ " [[ 0.0449, 0.0181, -0.0199],\n",
+ " [-0.0355, -0.0602, 0.0449],\n",
+ " [-0.0516, -0.0057, 0.0202]],\n",
+ " \n",
+ " [[ 0.0048, -0.0382, -0.0481],\n",
+ " [ 0.0446, -0.0534, 0.0196],\n",
+ " [ 0.0420, -0.0488, -0.0025]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[ 0.0370, 0.0267, -0.0290],\n",
+ " [ 0.0086, 0.0268, 0.0252],\n",
+ " [ 0.0014, 0.0049, 0.0030]],\n",
+ " \n",
+ " [[-0.0172, 0.0166, -0.0116],\n",
+ " [ 0.0058, 0.0025, 0.0096],\n",
+ " [ 0.0088, -0.0045, 0.0109]],\n",
+ " \n",
+ " [[-0.0301, -0.0272, -0.0277],\n",
+ " [-0.0049, 0.0645, 0.0125],\n",
+ " [-0.0520, -0.0035, -0.0437]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0150, -0.0118, -0.0402],\n",
+ " [-0.0030, 0.0220, -0.0568],\n",
+ " [-0.0604, -0.0319, 0.0031]],\n",
+ " \n",
+ " [[ 0.0017, -0.0345, 0.0604],\n",
+ " [ 0.0126, -0.0026, 0.0381],\n",
+ " [-0.0294, 0.0549, -0.0118]],\n",
+ " \n",
+ " [[-0.0526, -0.0483, 0.0265],\n",
+ " [ 0.0209, -0.0104, 0.0020],\n",
+ " [ 0.0164, 0.0114, 0.0161]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0355, 0.0480, -0.0118],\n",
+ " [ 0.0319, -0.0545, 0.0136],\n",
+ " [ 0.0028, -0.0574, 0.0331]],\n",
+ " \n",
+ " [[ 0.0020, 0.0231, -0.0033],\n",
+ " [ 0.0219, 0.0202, -0.0166],\n",
+ " [-0.0336, 0.0322, -0.0502]],\n",
+ " \n",
+ " [[ 0.0146, -0.0157, 0.0381],\n",
+ " [ 0.0105, 0.0060, -0.0070],\n",
+ " [-0.0114, 0.0013, -0.0418]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0059, 0.0099, 0.0060],\n",
+ " [-0.0165, 0.0038, -0.0169],\n",
+ " [ 0.0343, -0.0244, -0.0167]],\n",
+ " \n",
+ " [[ 0.0479, 0.0010, 0.0205],\n",
+ " [-0.0019, 0.0591, -0.0672],\n",
+ " [-0.0070, -0.0364, 0.0232]],\n",
+ " \n",
+ " [[-0.0080, 0.0137, -0.0213],\n",
+ " [-0.0774, 0.0105, -0.0237],\n",
+ " [ 0.0721, -0.0167, 0.0276]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0220, -0.0279, 0.0012],\n",
+ " [-0.0209, 0.0293, -0.0029],\n",
+ " [-0.0057, 0.0015, 0.0172]],\n",
+ " \n",
+ " [[-0.0106, -0.0302, -0.0075],\n",
+ " [ 0.0008, -0.0210, 0.0442],\n",
+ " [-0.0106, 0.0031, -0.0311]],\n",
+ " \n",
+ " [[-0.0157, -0.0408, 0.0793],\n",
+ " [-0.0214, 0.0764, -0.0372],\n",
+ " [ 0.0025, 0.0271, -0.0315]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0563, 0.0262, -0.0354],\n",
+ " [ 0.0141, -0.0001, -0.0292],\n",
+ " [-0.0230, 0.0063, -0.0463]],\n",
+ " \n",
+ " [[-0.0258, 0.0125, -0.0095],\n",
+ " [ 0.0223, -0.0436, 0.0133],\n",
+ " [ 0.0052, -0.0080, -0.0041]],\n",
+ " \n",
+ " [[-0.0072, 0.0499, -0.0308],\n",
+ " [-0.0131, -0.0604, 0.0236],\n",
+ " [-0.0735, 0.0252, -0.0268]]]], device='cuda:0')),\n",
+ " ('module.features.6.0.conv2.weight',\n",
+ " tensor([[[[ 0.0456, 0.0217, -0.0140],\n",
+ " [ 0.0218, -0.0242, -0.0408],\n",
+ " [-0.0064, 0.0213, -0.0625]],\n",
+ " \n",
+ " [[-0.0043, -0.0048, 0.0029],\n",
+ " [-0.0304, -0.0281, 0.0088],\n",
+ " [ 0.0585, 0.0351, 0.0145]],\n",
+ " \n",
+ " [[-0.0666, -0.0023, 0.0278],\n",
+ " [-0.0346, -0.0102, 0.0055],\n",
+ " [-0.0033, -0.0292, 0.0276]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0203, 0.0189, 0.0237],\n",
+ " [ 0.0598, -0.0031, 0.0022],\n",
+ " [-0.0101, 0.0254, -0.0024]],\n",
+ " \n",
+ " [[ 0.0243, -0.0321, 0.0007],\n",
+ " [ 0.0324, -0.0349, 0.0275],\n",
+ " [-0.0017, 0.0128, 0.0202]],\n",
+ " \n",
+ " [[-0.0222, 0.0059, -0.0872],\n",
+ " [-0.0068, -0.0591, 0.0200],\n",
+ " [ 0.0156, 0.0124, 0.0116]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0435, 0.0144, -0.0088],\n",
+ " [-0.0421, -0.0291, 0.0273],\n",
+ " [ 0.0186, -0.0065, -0.0051]],\n",
+ " \n",
+ " [[-0.0224, -0.0085, -0.0016],\n",
+ " [-0.0155, -0.0116, 0.0089],\n",
+ " [ 0.0052, -0.0223, 0.0146]],\n",
+ " \n",
+ " [[ 0.0664, 0.0152, 0.0241],\n",
+ " [ 0.0502, -0.0051, 0.0327],\n",
+ " [ 0.0381, -0.0349, -0.0250]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0974, 0.0089, -0.0157],\n",
+ " [ 0.0427, 0.0091, -0.0036],\n",
+ " [-0.0220, -0.0030, -0.0207]],\n",
+ " \n",
+ " [[ 0.0463, -0.0679, 0.0149],\n",
+ " [-0.0382, -0.0128, -0.0297],\n",
+ " [ 0.0492, 0.0189, -0.0443]],\n",
+ " \n",
+ " [[ 0.0432, -0.0122, -0.0390],\n",
+ " [-0.0299, 0.0153, 0.0116],\n",
+ " [ 0.0074, 0.0139, 0.0156]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0378, -0.0024, 0.0227],\n",
+ " [-0.0338, 0.0147, 0.0021],\n",
+ " [ 0.0113, 0.0399, -0.0064]],\n",
+ " \n",
+ " [[-0.0599, -0.0307, -0.0259],\n",
+ " [ 0.0257, 0.0076, 0.0498],\n",
+ " [-0.0048, -0.0039, -0.0475]],\n",
+ " \n",
+ " [[ 0.0055, -0.0252, -0.0048],\n",
+ " [ 0.0249, -0.0032, 0.0166],\n",
+ " [-0.0380, 0.0109, 0.0167]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0421, -0.0173, -0.0114],\n",
+ " [ 0.0343, -0.0060, 0.0394],\n",
+ " [-0.0232, -0.0279, 0.0052]],\n",
+ " \n",
+ " [[ 0.0079, -0.0103, -0.0056],\n",
+ " [ 0.0265, 0.0216, -0.0492],\n",
+ " [ 0.0082, 0.0359, -0.0071]],\n",
+ " \n",
+ " [[-0.0195, 0.0216, -0.0235],\n",
+ " [ 0.0362, 0.0314, 0.0027],\n",
+ " [ 0.0388, 0.0462, -0.0083]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[ 0.0158, 0.0483, 0.0320],\n",
+ " [ 0.0072, 0.0244, 0.0348],\n",
+ " [ 0.0027, 0.0525, 0.0169]],\n",
+ " \n",
+ " [[-0.0024, 0.0197, 0.0254],\n",
+ " [-0.0142, 0.0002, -0.0231],\n",
+ " [-0.0176, 0.0244, 0.0119]],\n",
+ " \n",
+ " [[-0.0151, 0.0136, 0.0562],\n",
+ " [ 0.0227, -0.0010, -0.0176],\n",
+ " [ 0.0216, 0.0213, -0.0401]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0267, -0.0435, 0.0050],\n",
+ " [-0.0382, -0.0170, 0.0197],\n",
+ " [-0.0110, -0.0479, 0.0207]],\n",
+ " \n",
+ " [[ 0.0468, 0.0084, -0.0388],\n",
+ " [-0.0114, 0.0255, -0.0155],\n",
+ " [-0.0160, -0.0051, -0.0084]],\n",
+ " \n",
+ " [[-0.0102, 0.0076, 0.0167],\n",
+ " [-0.0097, 0.0808, 0.0072],\n",
+ " [-0.0422, -0.0090, 0.0205]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0127, -0.0120, -0.0155],\n",
+ " [-0.0229, -0.0039, -0.0077],\n",
+ " [ 0.0269, 0.0339, 0.0376]],\n",
+ " \n",
+ " [[ 0.0109, -0.0058, -0.0114],\n",
+ " [ 0.0051, 0.0078, 0.0334],\n",
+ " [ 0.0142, 0.0040, -0.0676]],\n",
+ " \n",
+ " [[ 0.0029, -0.0156, 0.0024],\n",
+ " [-0.0088, 0.0022, 0.0056],\n",
+ " [ 0.0235, -0.0165, -0.0713]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0219, -0.0182, 0.0225],\n",
+ " [ 0.0015, -0.0273, -0.0245],\n",
+ " [ 0.0080, -0.0202, -0.0027]],\n",
+ " \n",
+ " [[ 0.0078, 0.0203, -0.0222],\n",
+ " [ 0.0223, 0.0386, 0.0078],\n",
+ " [ 0.0222, 0.0257, -0.0107]],\n",
+ " \n",
+ " [[ 0.0005, -0.0625, 0.0093],\n",
+ " [ 0.0340, -0.0411, -0.0146],\n",
+ " [ 0.0081, 0.0240, 0.0127]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0027, -0.0428, -0.0054],\n",
+ " [-0.0341, -0.0059, -0.0241],\n",
+ " [-0.0252, -0.0168, 0.0242]],\n",
+ " \n",
+ " [[-0.0126, -0.0144, -0.0143],\n",
+ " [ 0.0255, 0.0032, -0.0261],\n",
+ " [-0.0114, 0.0082, -0.0139]],\n",
+ " \n",
+ " [[-0.0032, -0.0282, 0.0255],\n",
+ " [-0.0109, -0.0130, 0.0422],\n",
+ " [ 0.0156, 0.0132, -0.0362]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0280, -0.0159, 0.0197],\n",
+ " [-0.0053, 0.0227, 0.0105],\n",
+ " [ 0.0252, -0.0133, 0.0017]],\n",
+ " \n",
+ " [[ 0.0088, -0.0180, -0.0219],\n",
+ " [-0.0258, -0.0302, 0.0063],\n",
+ " [-0.0330, 0.0104, 0.0190]],\n",
+ " \n",
+ " [[-0.0133, 0.0672, -0.0083],\n",
+ " [-0.0084, -0.0127, 0.0435],\n",
+ " [-0.0250, 0.0217, -0.0196]]]], device='cuda:0')),\n",
+ " ('module.features.6.0.downsample.0.weight',\n",
+ " tensor([[[[ 0.0167]],\n",
+ " \n",
+ " [[ 0.0906]],\n",
+ " \n",
+ " [[-0.0019]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0699]],\n",
+ " \n",
+ " [[-0.0978]],\n",
+ " \n",
+ " [[ 0.0776]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0425]],\n",
+ " \n",
+ " [[ 0.0029]],\n",
+ " \n",
+ " [[-0.0786]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0315]],\n",
+ " \n",
+ " [[-0.0649]],\n",
+ " \n",
+ " [[ 0.0829]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.1541]],\n",
+ " \n",
+ " [[-0.1738]],\n",
+ " \n",
+ " [[-0.0216]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0183]],\n",
+ " \n",
+ " [[-0.0233]],\n",
+ " \n",
+ " [[-0.0739]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[ 0.0135]],\n",
+ " \n",
+ " [[-0.1001]],\n",
+ " \n",
+ " [[-0.0375]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0250]],\n",
+ " \n",
+ " [[ 0.1578]],\n",
+ " \n",
+ " [[-0.0817]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.1361]],\n",
+ " \n",
+ " [[-0.1274]],\n",
+ " \n",
+ " [[ 0.1172]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0530]],\n",
+ " \n",
+ " [[-0.0503]],\n",
+ " \n",
+ " [[ 0.1324]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0932]],\n",
+ " \n",
+ " [[ 0.2297]],\n",
+ " \n",
+ " [[-0.1584]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0520]],\n",
+ " \n",
+ " [[-0.0518]],\n",
+ " \n",
+ " [[ 0.1083]]]], device='cuda:0')),\n",
+ " ('module.features.6.1.conv1.weight',\n",
+ " tensor([[[[-1.0663e-02, -6.4592e-02, -1.7643e-02],\n",
+ " [-1.9915e-02, -2.3672e-02, -2.6418e-02],\n",
+ " [-3.7928e-03, 1.4260e-02, 8.1776e-04]],\n",
+ " \n",
+ " [[-1.3990e-02, 4.2493e-02, -4.5754e-02],\n",
+ " [ 6.5685e-02, -9.7300e-03, -4.2949e-03],\n",
+ " [ 5.9361e-03, 2.8047e-02, -7.9141e-03]],\n",
+ " \n",
+ " [[ 7.7116e-03, 5.7022e-02, -3.6133e-02],\n",
+ " [-1.7684e-02, -3.1758e-02, -3.3286e-02],\n",
+ " [-4.7406e-02, 2.8379e-02, -2.8364e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-5.6726e-04, 1.2521e-02, -2.0426e-02],\n",
+ " [ 4.0585e-02, -4.6295e-02, 1.5691e-02],\n",
+ " [-2.1484e-02, -3.0863e-03, -4.5604e-02]],\n",
+ " \n",
+ " [[ 3.8378e-03, -3.6587e-03, 1.3426e-02],\n",
+ " [ 1.0853e-03, 8.8641e-03, -4.1009e-02],\n",
+ " [ 3.7223e-03, 9.7395e-03, -5.4633e-03]],\n",
+ " \n",
+ " [[-1.3050e-02, -3.5338e-02, 2.1541e-03],\n",
+ " [ 1.8484e-02, -1.7885e-02, 3.4567e-02],\n",
+ " [ 1.0993e-02, -2.7531e-02, 4.5742e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[ 7.5872e-02, -3.7636e-02, 1.7796e-02],\n",
+ " [ 9.2064e-03, 4.1637e-03, 1.0417e-02],\n",
+ " [-3.1397e-02, -3.2680e-03, -4.1829e-02]],\n",
+ " \n",
+ " [[ 6.7461e-02, -9.0697e-03, -1.5819e-02],\n",
+ " [ 2.0520e-02, 1.0564e-02, 1.5427e-02],\n",
+ " [-3.1541e-02, -1.0761e-01, 1.0140e-02]],\n",
+ " \n",
+ " [[ 2.0647e-02, 3.4147e-03, 5.1170e-03],\n",
+ " [ 6.2912e-02, -2.5438e-02, -7.0194e-03],\n",
+ " [-5.7876e-03, -2.5348e-02, 6.9223e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-1.8312e-02, -1.5583e-02, 1.9101e-02],\n",
+ " [ 1.1937e-02, -1.1945e-02, 3.0733e-02],\n",
+ " [ 8.7930e-02, -8.5641e-03, -3.9050e-02]],\n",
+ " \n",
+ " [[-4.3650e-03, 3.8903e-02, -7.1692e-02],\n",
+ " [ 2.3118e-02, 7.5008e-03, -2.0700e-02],\n",
+ " [-5.0258e-02, 1.1372e-02, -1.7813e-02]],\n",
+ " \n",
+ " [[-8.8078e-04, 3.8180e-02, -2.6889e-02],\n",
+ " [ 1.2815e-02, 2.0097e-02, -3.1433e-02],\n",
+ " [-1.8129e-02, -1.6484e-02, -1.1331e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 1.1774e-02, -4.3228e-02, -2.4902e-02],\n",
+ " [ 3.0643e-02, -7.8646e-03, 7.8888e-03],\n",
+ " [ 2.7589e-02, 3.7252e-03, -1.6763e-02]],\n",
+ " \n",
+ " [[ 3.5904e-02, 9.2736e-03, 1.5518e-03],\n",
+ " [-2.8347e-02, 7.6164e-03, -2.6277e-02],\n",
+ " [-2.5395e-03, 5.1337e-02, 3.6514e-02]],\n",
+ " \n",
+ " [[-8.9175e-03, 2.8995e-02, 3.6913e-03],\n",
+ " [ 4.7412e-02, -2.0663e-02, -7.5921e-02],\n",
+ " [ 4.4291e-02, 2.7490e-02, 3.7849e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-4.9216e-02, 2.1112e-02, 2.7770e-02],\n",
+ " [-5.3211e-03, -1.9268e-02, 5.1660e-03],\n",
+ " [ 5.1669e-03, 9.7053e-03, -3.8086e-02]],\n",
+ " \n",
+ " [[-2.4785e-02, 1.6300e-02, -1.5496e-02],\n",
+ " [ 1.5307e-02, -2.9707e-02, 2.7436e-02],\n",
+ " [-1.9306e-02, 7.0717e-02, -1.0173e-02]],\n",
+ " \n",
+ " [[ 3.0778e-02, 4.9702e-04, 3.1833e-02],\n",
+ " [-2.5441e-02, -1.4679e-02, 3.4316e-02],\n",
+ " [-2.3471e-03, -3.3221e-03, -2.5022e-02]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[ 2.0778e-04, -2.3316e-03, -2.8760e-02],\n",
+ " [-4.4530e-02, -2.4329e-02, -3.9114e-02],\n",
+ " [-1.5464e-02, 1.0834e-02, -8.8251e-03]],\n",
+ " \n",
+ " [[ 1.0525e-04, 6.5905e-02, 3.1255e-02],\n",
+ " [ 4.8631e-02, 1.9551e-02, -1.5353e-02],\n",
+ " [-1.9567e-03, 2.6212e-04, -3.2542e-02]],\n",
+ " \n",
+ " [[-2.9586e-02, 3.8948e-02, 1.4704e-02],\n",
+ " [ 7.8863e-03, -2.3816e-02, 9.8579e-03],\n",
+ " [-2.5092e-02, 6.2358e-03, -6.0422e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-1.6740e-02, 2.9511e-02, -1.9156e-02],\n",
+ " [ 1.6318e-02, -3.4487e-04, -4.0301e-03],\n",
+ " [ 1.0853e-02, 2.3323e-02, -2.9527e-03]],\n",
+ " \n",
+ " [[ 1.6024e-02, -1.4094e-02, -1.2549e-02],\n",
+ " [ 5.7991e-03, 3.7044e-03, -3.6310e-02],\n",
+ " [-5.9241e-02, -2.2843e-02, 3.7137e-02]],\n",
+ " \n",
+ " [[-1.7307e-02, -3.0526e-02, -6.5269e-03],\n",
+ " [ 3.6899e-02, 3.5466e-02, -3.6425e-02],\n",
+ " [ 1.6369e-02, 2.1150e-03, 2.2940e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[-4.8313e-02, -1.7220e-02, 2.5480e-03],\n",
+ " [-7.8973e-04, 1.3135e-02, 8.3401e-03],\n",
+ " [-3.0900e-03, -2.5085e-02, 2.9801e-02]],\n",
+ " \n",
+ " [[ 4.8227e-02, 1.4280e-02, 1.7196e-02],\n",
+ " [-4.0403e-03, -5.7785e-03, -3.2870e-02],\n",
+ " [-1.3931e-02, -7.0277e-02, -3.6799e-02]],\n",
+ " \n",
+ " [[ 1.6242e-02, -1.8689e-02, 1.0376e-03],\n",
+ " [ 3.9966e-02, 2.1135e-02, 2.2993e-02],\n",
+ " [-9.3555e-03, 5.1732e-02, -3.1082e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-1.3077e-02, 2.8440e-02, 4.4365e-02],\n",
+ " [-1.3948e-02, -6.9563e-04, 2.8177e-02],\n",
+ " [ 3.3243e-02, -3.3940e-02, -5.2344e-02]],\n",
+ " \n",
+ " [[-3.1947e-02, -9.0099e-03, 4.4178e-02],\n",
+ " [ 8.0787e-03, 7.7619e-02, 5.6659e-04],\n",
+ " [-1.0522e-02, -6.2181e-03, -4.4128e-02]],\n",
+ " \n",
+ " [[ 4.1340e-03, -3.2932e-02, -6.5304e-02],\n",
+ " [-2.3988e-02, -2.4525e-03, 9.7465e-03],\n",
+ " [-3.9158e-03, -4.7578e-02, -1.2476e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 3.4115e-02, 2.7232e-02, -4.2150e-02],\n",
+ " [-1.9656e-02, -2.1527e-02, -3.7997e-02],\n",
+ " [ 4.6947e-02, -4.2265e-02, 1.0091e-02]],\n",
+ " \n",
+ " [[ 2.7583e-03, -7.0453e-04, 2.0036e-02],\n",
+ " [ 2.1210e-02, 1.0332e-02, 3.0095e-02],\n",
+ " [-8.7497e-03, 2.3021e-02, 3.6112e-02]],\n",
+ " \n",
+ " [[ 2.3547e-03, -2.3150e-02, -3.4876e-02],\n",
+ " [ 8.3291e-03, 8.2892e-03, -2.3248e-02],\n",
+ " [-1.4488e-02, -2.4721e-02, 5.8742e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-8.3425e-02, -2.4922e-02, -6.0979e-03],\n",
+ " [ 2.5739e-02, 5.1047e-02, -1.5323e-02],\n",
+ " [ 9.1807e-03, 2.0354e-03, -3.6563e-04]],\n",
+ " \n",
+ " [[-1.8426e-02, 3.0250e-02, -3.0054e-03],\n",
+ " [ 2.6835e-02, -1.9758e-02, 2.7468e-02],\n",
+ " [ 3.3452e-02, -1.6044e-02, -1.7455e-02]],\n",
+ " \n",
+ " [[-5.7080e-03, 6.3050e-02, 3.1424e-02],\n",
+ " [ 9.9518e-03, 1.4794e-02, -1.4960e-02],\n",
+ " [ 1.0899e-02, 4.6267e-03, 1.7051e-02]]]], device='cuda:0')),\n",
+ " ('module.features.6.1.conv2.weight',\n",
+ " tensor([[[[ 3.4787e-02, 3.8060e-02, -1.5990e-02],\n",
+ " [ 1.7950e-02, -8.4159e-03, -6.6534e-02],\n",
+ " [-1.4080e-03, -3.2043e-03, 2.0568e-02]],\n",
+ " \n",
+ " [[-3.1033e-02, -2.5325e-02, 1.4106e-02],\n",
+ " [-2.0858e-02, -2.9413e-02, -1.6169e-02],\n",
+ " [ 4.9200e-03, 2.5553e-02, 1.8689e-02]],\n",
+ " \n",
+ " [[-7.8053e-03, -5.6583e-02, 1.0117e-02],\n",
+ " [ 2.1717e-02, -3.9589e-02, 8.9629e-03],\n",
+ " [ 7.4012e-03, 5.4506e-02, -2.3082e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-6.1332e-02, 2.4637e-02, 8.5247e-03],\n",
+ " [-7.8043e-03, -5.0832e-02, -3.5482e-02],\n",
+ " [-1.0316e-02, -2.4707e-02, -5.4953e-02]],\n",
+ " \n",
+ " [[-3.7195e-02, -2.7308e-02, -9.6095e-04],\n",
+ " [-1.1219e-02, 5.7739e-02, -2.0120e-02],\n",
+ " [ 4.1952e-02, 3.4357e-02, -2.2346e-02]],\n",
+ " \n",
+ " [[ 6.2356e-02, 5.6113e-02, -1.4165e-02],\n",
+ " [ 3.5778e-03, -1.5435e-02, 2.8328e-02],\n",
+ " [ 1.1567e-02, 1.4654e-02, 8.2448e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[-1.2194e-02, -5.0170e-03, 3.5250e-02],\n",
+ " [ 1.9931e-03, 6.0486e-03, 9.3887e-03],\n",
+ " [ 2.1655e-04, 1.3419e-02, 1.5752e-02]],\n",
+ " \n",
+ " [[ 5.1833e-02, -4.1004e-02, 1.4427e-02],\n",
+ " [ 5.7461e-03, 3.6509e-02, 4.9289e-02],\n",
+ " [ 6.3431e-02, 2.2885e-02, 2.8438e-02]],\n",
+ " \n",
+ " [[ 1.4997e-02, 5.6978e-02, -2.3945e-02],\n",
+ " [ 4.4951e-02, 3.1703e-02, -1.1541e-02],\n",
+ " [ 3.0593e-02, 1.4636e-02, 3.1538e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-2.0998e-02, -2.3779e-02, -2.7013e-02],\n",
+ " [ 4.0677e-03, 1.0402e-02, 2.4015e-03],\n",
+ " [ 8.1995e-02, 4.0297e-02, -5.9132e-02]],\n",
+ " \n",
+ " [[-1.5711e-02, 2.0702e-02, 1.6739e-02],\n",
+ " [-1.7031e-02, -8.0594e-03, 8.3916e-03],\n",
+ " [-3.0904e-02, -1.6253e-02, -3.8727e-02]],\n",
+ " \n",
+ " [[-4.1036e-03, 2.9179e-02, -1.2883e-02],\n",
+ " [-6.3859e-03, 2.3857e-03, 5.9913e-03],\n",
+ " [ 3.0187e-02, -3.4250e-02, 4.2737e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[-1.1786e-02, 2.8669e-02, 3.1567e-02],\n",
+ " [ 4.9293e-02, -4.7991e-03, -6.0374e-02],\n",
+ " [ 1.3316e-02, 6.5624e-02, -3.3356e-02]],\n",
+ " \n",
+ " [[-9.2111e-03, 1.4936e-02, 1.5386e-02],\n",
+ " [-3.3757e-02, 1.1189e-02, -3.3450e-02],\n",
+ " [ 3.7834e-02, -1.4629e-02, 1.1069e-02]],\n",
+ " \n",
+ " [[-5.7206e-03, -1.9249e-02, -7.7542e-03],\n",
+ " [ 2.9950e-02, 2.9959e-02, 2.7115e-02],\n",
+ " [-3.3033e-02, -4.4068e-02, -1.1164e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 2.6510e-03, -1.7953e-02, -1.1683e-02],\n",
+ " [ 3.0550e-03, 2.4446e-02, 1.6834e-02],\n",
+ " [ 4.1480e-02, -4.5584e-02, -5.1714e-02]],\n",
+ " \n",
+ " [[-2.7675e-02, 5.1731e-02, 2.0568e-02],\n",
+ " [ 4.1835e-02, 1.6490e-03, -3.9423e-02],\n",
+ " [ 1.4456e-02, 4.6297e-02, -2.3125e-02]],\n",
+ " \n",
+ " [[-5.5972e-02, 7.2672e-03, -4.7699e-02],\n",
+ " [ 1.6370e-02, 5.0866e-02, 3.3071e-02],\n",
+ " [-3.0254e-02, -4.6505e-03, -1.3263e-02]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[-5.5671e-02, -1.4973e-02, 2.4598e-02],\n",
+ " [-1.5751e-02, -1.9332e-02, -1.1173e-02],\n",
+ " [-1.1446e-02, 3.4020e-02, -5.6328e-03]],\n",
+ " \n",
+ " [[ 1.4628e-02, 3.8786e-02, -1.1804e-02],\n",
+ " [ 1.1848e-02, 1.8123e-02, -1.0171e-02],\n",
+ " [ 6.0197e-02, 4.3300e-02, 5.8398e-02]],\n",
+ " \n",
+ " [[-6.3173e-02, -2.6803e-02, 1.3401e-03],\n",
+ " [ 3.0209e-02, 3.8472e-02, 3.5204e-02],\n",
+ " [-1.4885e-02, 3.1834e-02, -7.7356e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 1.1209e-02, -5.0194e-03, 4.2764e-02],\n",
+ " [-2.4707e-02, 3.6524e-04, -5.5653e-03],\n",
+ " [ 1.7874e-02, 1.0252e-02, 6.7133e-02]],\n",
+ " \n",
+ " [[ 1.4919e-02, -2.0242e-03, 1.3058e-02],\n",
+ " [-3.0284e-03, 4.6720e-02, 5.9795e-02],\n",
+ " [ 2.8785e-02, -1.5592e-02, 1.6045e-02]],\n",
+ " \n",
+ " [[-2.5472e-02, -8.5856e-02, -3.7504e-02],\n",
+ " [-3.0099e-02, -2.3069e-02, 1.2823e-02],\n",
+ " [-4.1428e-02, 1.5843e-02, 1.4451e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 7.7966e-03, 1.6178e-02, 2.6000e-02],\n",
+ " [-6.4233e-02, -1.7636e-02, 1.2902e-03],\n",
+ " [-4.1026e-04, 2.8500e-02, -3.6673e-02]],\n",
+ " \n",
+ " [[-3.1332e-02, -1.7827e-02, 4.2891e-02],\n",
+ " [-2.4205e-03, -2.4863e-02, 2.2896e-03],\n",
+ " [ 1.0987e-02, -3.0397e-02, -4.4000e-02]],\n",
+ " \n",
+ " [[ 5.1481e-03, 7.3660e-04, 3.3303e-03],\n",
+ " [ 3.6612e-04, -5.4450e-02, -8.3123e-03],\n",
+ " [ 2.0301e-03, -2.1761e-02, 3.3226e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 1.2331e-03, 8.0546e-03, -3.6006e-02],\n",
+ " [ 3.8699e-02, -2.6648e-02, -1.8826e-02],\n",
+ " [ 2.7367e-02, 1.4244e-02, 5.2926e-04]],\n",
+ " \n",
+ " [[-6.9364e-02, -1.0560e-02, 9.4717e-03],\n",
+ " [ 5.0586e-02, 1.5329e-03, 3.8197e-02],\n",
+ " [ 2.4806e-02, 7.5918e-02, -1.8269e-02]],\n",
+ " \n",
+ " [[ 3.2278e-02, -1.0672e-03, 6.8488e-03],\n",
+ " [-4.4616e-02, -3.5674e-02, 9.5346e-04],\n",
+ " [-1.3379e-02, 6.8442e-03, 9.0560e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[ 2.0115e-02, -5.3358e-03, -3.5381e-02],\n",
+ " [ 7.1608e-04, -6.9627e-03, -1.9737e-02],\n",
+ " [-8.2062e-03, -3.7454e-02, -7.4117e-02]],\n",
+ " \n",
+ " [[ 1.4927e-02, 7.1709e-02, 1.1718e-02],\n",
+ " [ 8.2372e-02, 5.8646e-03, 3.4174e-03],\n",
+ " [ 1.0936e-03, 3.0345e-02, -1.7796e-02]],\n",
+ " \n",
+ " [[-5.0297e-02, 2.2410e-02, 5.7437e-03],\n",
+ " [ 3.7350e-02, 1.3494e-02, -2.9290e-04],\n",
+ " [-6.8438e-03, 3.8460e-05, -9.2413e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 9.5568e-03, 2.5887e-03, 1.0262e-02],\n",
+ " [-3.2448e-03, -1.7702e-04, 1.8214e-02],\n",
+ " [-8.0327e-03, 6.6512e-04, -1.5375e-02]],\n",
+ " \n",
+ " [[ 3.9076e-02, 1.3856e-03, 1.0307e-02],\n",
+ " [ 2.3322e-02, 4.0026e-03, 3.5763e-02],\n",
+ " [ 1.3618e-02, 3.7627e-02, 1.0824e-02]],\n",
+ " \n",
+ " [[-1.6677e-03, -2.2723e-02, -5.5824e-02],\n",
+ " [-1.3569e-02, -1.2928e-02, -6.1438e-03],\n",
+ " [ 6.2071e-02, 1.5035e-03, -7.6897e-02]]]], device='cuda:0')),\n",
+ " ('module.features.7.0.conv1.weight',\n",
+ " tensor([[[[-3.4441e-03, 1.0994e-02, 1.4760e-03],\n",
+ " [ 1.9767e-02, -1.8906e-02, -2.6181e-02],\n",
+ " [ 1.4694e-02, 1.9673e-02, -1.7175e-03]],\n",
+ " \n",
+ " [[-3.7813e-03, -3.9056e-02, 2.6098e-02],\n",
+ " [-3.7201e-03, -4.2006e-03, -9.9549e-03],\n",
+ " [ 4.3169e-02, -1.2532e-02, 5.1302e-03]],\n",
+ " \n",
+ " [[-4.2043e-02, 1.4136e-02, -9.0424e-03],\n",
+ " [-2.1283e-03, -8.4711e-03, -2.1926e-03],\n",
+ " [-6.4922e-03, 7.2286e-03, 8.8581e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 1.8439e-02, -1.1141e-03, 2.0971e-02],\n",
+ " [-4.0137e-02, 1.4446e-02, -1.3450e-03],\n",
+ " [ 1.5027e-02, -7.5354e-03, -3.5068e-04]],\n",
+ " \n",
+ " [[ 4.2114e-03, 1.1973e-02, -4.3907e-03],\n",
+ " [ 7.0268e-03, 1.9686e-02, -2.6050e-02],\n",
+ " [ 1.4946e-03, -7.8286e-03, 9.1873e-03]],\n",
+ " \n",
+ " [[ 3.9229e-02, -4.6599e-03, -2.6167e-02],\n",
+ " [-1.6986e-02, -4.5638e-02, 3.0573e-03],\n",
+ " [-1.8197e-02, -2.7810e-02, -8.9829e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[-2.4756e-02, 1.4825e-02, 1.5259e-02],\n",
+ " [-1.4752e-02, 1.0032e-02, -4.7665e-02],\n",
+ " [ 1.0060e-02, 8.5348e-03, 3.7074e-02]],\n",
+ " \n",
+ " [[ 1.5102e-03, -2.1007e-02, 7.4820e-03],\n",
+ " [-1.0579e-03, -9.3815e-03, -6.6800e-04],\n",
+ " [-2.1778e-02, 2.0677e-02, -1.1825e-02]],\n",
+ " \n",
+ " [[-1.0584e-02, 3.6972e-03, -4.1202e-03],\n",
+ " [ 1.5519e-02, -4.3264e-03, 6.3549e-03],\n",
+ " [ 2.9458e-03, 1.1394e-02, -9.6818e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 4.2069e-03, -8.4145e-03, -1.2281e-02],\n",
+ " [-1.1881e-02, 2.5911e-02, 3.5466e-02],\n",
+ " [ 2.7593e-02, -4.0577e-02, 1.0283e-02]],\n",
+ " \n",
+ " [[-3.7335e-02, 3.5848e-02, -2.9818e-02],\n",
+ " [ 1.6793e-03, 6.0743e-04, -1.1339e-02],\n",
+ " [-5.2570e-02, -7.8037e-03, -3.1148e-02]],\n",
+ " \n",
+ " [[-4.5868e-02, 1.4489e-02, 5.0106e-03],\n",
+ " [ 1.0891e-02, -1.3956e-02, -7.5098e-03],\n",
+ " [-1.5168e-02, -6.3514e-04, -7.5874e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[ 1.2017e-02, -3.1987e-03, -1.2760e-02],\n",
+ " [ 9.7807e-03, -5.9038e-03, -3.6333e-02],\n",
+ " [ 1.5559e-02, -1.4835e-02, 1.2544e-02]],\n",
+ " \n",
+ " [[-1.5986e-03, -3.9330e-03, -1.9158e-02],\n",
+ " [ 1.5867e-02, -2.2501e-03, 1.3388e-02],\n",
+ " [ 1.9915e-02, 1.5831e-02, 8.9993e-03]],\n",
+ " \n",
+ " [[-2.7614e-02, 3.2708e-02, -2.1841e-02],\n",
+ " [-9.4685e-03, -2.4966e-02, 1.2511e-02],\n",
+ " [-1.9887e-03, -4.3608e-02, 3.9580e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 9.3413e-03, -5.7094e-03, 4.3542e-03],\n",
+ " [-6.8274e-03, 7.4210e-03, 1.1561e-02],\n",
+ " [ 6.4388e-03, 1.9781e-02, 4.2751e-03]],\n",
+ " \n",
+ " [[ 5.0592e-03, 2.7730e-03, -2.4055e-02],\n",
+ " [ 2.4276e-03, 8.3152e-04, 3.7686e-02],\n",
+ " [ 2.7420e-02, -9.2692e-03, 2.3494e-02]],\n",
+ " \n",
+ " [[ 5.9059e-03, -4.4244e-02, -1.3735e-03],\n",
+ " [-2.5886e-02, -3.9441e-02, 6.0072e-03],\n",
+ " [-1.1696e-02, -4.1958e-03, 6.9575e-03]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[ 3.8198e-02, -1.2132e-02, -1.1683e-02],\n",
+ " [-2.4794e-03, 2.5955e-02, 1.8615e-02],\n",
+ " [-1.7031e-03, 4.1369e-02, -7.3895e-03]],\n",
+ " \n",
+ " [[ 1.4308e-03, 4.2879e-03, -1.3985e-02],\n",
+ " [ 1.5767e-02, 4.8289e-03, -3.0731e-02],\n",
+ " [ 1.2513e-02, 5.6250e-02, 4.5197e-04]],\n",
+ " \n",
+ " [[ 2.4120e-02, -5.4435e-03, 4.5873e-03],\n",
+ " [ 2.0246e-03, 2.8319e-02, 9.0150e-03],\n",
+ " [-1.1607e-02, 1.8807e-02, 2.4154e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-3.2622e-03, 1.1339e-02, 4.2190e-02],\n",
+ " [ 2.8931e-02, 1.8660e-02, -3.6523e-02],\n",
+ " [-9.0465e-03, 3.1880e-02, 3.1114e-02]],\n",
+ " \n",
+ " [[-4.3627e-03, 2.1465e-02, 6.3580e-03],\n",
+ " [ 3.1705e-03, 1.1819e-02, 3.9138e-02],\n",
+ " [ 2.9341e-03, 1.3085e-02, 2.7232e-02]],\n",
+ " \n",
+ " [[-1.6039e-02, -2.7102e-02, -3.2196e-02],\n",
+ " [-1.0371e-02, 3.2571e-03, 4.9135e-03],\n",
+ " [-4.6609e-04, -2.5075e-03, 2.4381e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 1.0021e-02, 5.6412e-03, -2.6135e-02],\n",
+ " [-2.0356e-02, -1.7683e-04, 3.3079e-03],\n",
+ " [-1.4637e-02, 6.8626e-02, -4.9217e-02]],\n",
+ " \n",
+ " [[ 1.9138e-03, -2.6581e-02, -1.5232e-03],\n",
+ " [-1.0672e-02, 5.2147e-03, 1.7318e-02],\n",
+ " [ 2.1448e-03, -1.5667e-02, 6.1177e-03]],\n",
+ " \n",
+ " [[-2.5615e-02, 2.5841e-02, -2.5348e-03],\n",
+ " [-3.3380e-03, 1.4092e-02, -1.3205e-02],\n",
+ " [-6.7538e-03, 3.6098e-03, -1.2386e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-1.8762e-03, -1.0771e-02, -2.6408e-02],\n",
+ " [ 2.6240e-02, 2.8887e-02, 2.0198e-02],\n",
+ " [ 6.9128e-05, -1.2740e-02, -1.1935e-03]],\n",
+ " \n",
+ " [[-1.0645e-02, -2.0356e-02, 3.9404e-02],\n",
+ " [ 6.0474e-03, 1.1074e-02, 2.7843e-02],\n",
+ " [ 7.6534e-03, -2.0945e-02, 3.9055e-02]],\n",
+ " \n",
+ " [[-2.9758e-02, -9.6599e-03, 1.1365e-02],\n",
+ " [ 6.4649e-03, 1.7953e-02, -2.0368e-02],\n",
+ " [-9.9998e-03, -4.2007e-03, 8.8341e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[ 3.6774e-02, 6.3172e-03, 2.2152e-02],\n",
+ " [-2.1888e-03, 1.4285e-02, 1.0440e-02],\n",
+ " [ 1.6240e-02, -2.2892e-02, -1.0568e-02]],\n",
+ " \n",
+ " [[-1.1015e-02, 2.5463e-02, 9.1917e-04],\n",
+ " [ 3.3877e-02, 2.0663e-03, 1.8499e-02],\n",
+ " [ 4.4987e-02, -5.6371e-02, -1.6414e-02]],\n",
+ " \n",
+ " [[ 3.9439e-04, -3.3985e-02, -1.5881e-02],\n",
+ " [ 5.9595e-03, 3.1137e-02, -1.7259e-03],\n",
+ " [-2.1612e-02, -1.1537e-02, 6.6468e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 2.9433e-02, -1.1377e-02, 8.4307e-03],\n",
+ " [ 4.2296e-03, 1.1831e-02, 1.6850e-02],\n",
+ " [-1.2929e-02, -7.1755e-03, 1.8415e-02]],\n",
+ " \n",
+ " [[-1.6618e-02, 7.8405e-03, 4.0864e-02],\n",
+ " [ 4.8490e-02, 4.2489e-02, 6.0160e-03],\n",
+ " [ 3.7563e-02, 1.0776e-02, -6.1331e-03]],\n",
+ " \n",
+ " [[ 2.6469e-02, -6.6585e-04, -6.7010e-04],\n",
+ " [ 3.8410e-03, -1.4263e-02, 6.1104e-03],\n",
+ " [-1.0586e-02, 2.6340e-02, -3.0821e-02]]]], device='cuda:0')),\n",
+ " ('module.features.7.0.conv2.weight',\n",
+ " tensor([[[[-1.2200e-02, 2.4671e-02, -9.8238e-03],\n",
+ " [-1.2586e-02, -1.3236e-02, 2.5329e-03],\n",
+ " [-1.6542e-03, 1.4174e-03, 8.2470e-03]],\n",
+ " \n",
+ " [[-4.4317e-02, 1.0416e-02, 3.3988e-02],\n",
+ " [-3.3332e-02, 8.5336e-03, -4.7323e-02],\n",
+ " [ 5.7513e-03, -1.8317e-02, -2.3702e-02]],\n",
+ " \n",
+ " [[-2.7567e-02, 3.0532e-02, 3.4585e-02],\n",
+ " [-1.6679e-02, -2.7876e-02, -5.6227e-03],\n",
+ " [-6.7172e-03, -9.3788e-03, 3.3215e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-2.6134e-02, -3.7077e-03, -2.7437e-02],\n",
+ " [ 8.5012e-06, 2.5855e-02, 7.1096e-03],\n",
+ " [-5.5645e-03, 5.7466e-03, -2.4518e-02]],\n",
+ " \n",
+ " [[-4.4125e-02, 2.1510e-03, -3.1742e-04],\n",
+ " [ 1.0338e-02, -1.1641e-02, 5.6042e-02],\n",
+ " [-8.5308e-03, -2.9704e-02, 2.9754e-02]],\n",
+ " \n",
+ " [[ 1.4525e-02, 7.8045e-03, 1.7892e-02],\n",
+ " [ 8.9423e-03, -2.1311e-02, 1.8905e-02],\n",
+ " [ 1.4926e-03, 5.7177e-02, 4.0705e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 1.4293e-02, 2.2090e-02, 3.3967e-03],\n",
+ " [-3.1053e-02, -4.6200e-03, 1.2056e-02],\n",
+ " [-7.9522e-03, -4.3661e-02, -4.2358e-03]],\n",
+ " \n",
+ " [[ 4.4703e-02, -2.4620e-03, 3.6090e-03],\n",
+ " [-5.4236e-04, -3.4849e-02, 6.2990e-03],\n",
+ " [ 1.1801e-02, -6.0881e-03, -6.4435e-03]],\n",
+ " \n",
+ " [[-1.1111e-02, -4.8533e-03, -4.5175e-02],\n",
+ " [-1.0993e-02, 2.3395e-02, -2.0765e-02],\n",
+ " [-1.9418e-02, -1.3892e-03, -1.1269e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 7.2273e-03, -1.7110e-02, 5.6988e-03],\n",
+ " [ 2.1792e-02, -9.0769e-03, -1.5590e-03],\n",
+ " [ 2.9187e-02, 2.6378e-02, -4.0534e-03]],\n",
+ " \n",
+ " [[ 8.8414e-03, 2.6818e-02, 1.2076e-04],\n",
+ " [-8.8425e-04, 1.2134e-02, -1.3035e-02],\n",
+ " [ 1.3764e-02, 4.9568e-02, 7.7859e-03]],\n",
+ " \n",
+ " [[-3.3327e-02, -2.3628e-02, 3.3143e-02],\n",
+ " [ 2.0608e-02, 6.5762e-03, 8.5704e-03],\n",
+ " [ 4.0431e-02, -6.4119e-03, -2.8803e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 3.0359e-03, 2.2854e-02, 3.4083e-02],\n",
+ " [-2.1899e-02, -4.9271e-03, 2.5522e-02],\n",
+ " [-2.8607e-02, -1.9181e-02, 3.9501e-03]],\n",
+ " \n",
+ " [[ 1.4994e-02, -2.1828e-02, 6.4722e-03],\n",
+ " [ 1.9912e-02, 4.4057e-03, 1.0549e-03],\n",
+ " [-2.5813e-02, -2.5785e-02, 1.4741e-02]],\n",
+ " \n",
+ " [[-2.3405e-03, -5.0771e-03, -1.5741e-02],\n",
+ " [-7.1172e-03, -2.1527e-02, 1.0617e-02],\n",
+ " [ 2.0363e-02, -1.7201e-02, 5.4293e-04]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-8.7831e-03, 4.9798e-03, 2.3290e-02],\n",
+ " [-2.3432e-02, 3.7738e-03, -1.0715e-03],\n",
+ " [ 2.0324e-02, -1.2834e-02, 2.3588e-02]],\n",
+ " \n",
+ " [[-9.9962e-03, -2.8735e-02, -5.6838e-03],\n",
+ " [ 2.4581e-02, -7.0371e-03, -1.3003e-02],\n",
+ " [-4.5345e-02, -1.1756e-02, -2.2176e-02]],\n",
+ " \n",
+ " [[-3.6717e-02, 1.0350e-02, 1.0031e-02],\n",
+ " [ 1.9931e-02, 1.5897e-02, 8.8945e-04],\n",
+ " [-3.0208e-02, 2.8018e-02, 3.4711e-03]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[ 6.2676e-04, -5.8382e-03, 4.8320e-04],\n",
+ " [ 1.6859e-02, 2.4176e-02, 9.4217e-03],\n",
+ " [-3.6895e-03, 5.9570e-03, 3.2383e-02]],\n",
+ " \n",
+ " [[-5.2046e-02, 3.5914e-02, -2.2961e-03],\n",
+ " [-3.3788e-02, 1.8930e-02, -1.9586e-02],\n",
+ " [-2.6040e-02, -1.1494e-02, -1.3791e-02]],\n",
+ " \n",
+ " [[-2.6943e-02, 2.2764e-02, -8.1082e-03],\n",
+ " [ 1.9537e-02, 2.3907e-03, 2.4450e-02],\n",
+ " [ 1.1038e-02, -1.3544e-03, 8.8689e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-2.9265e-02, 2.9907e-03, -1.9701e-02],\n",
+ " [ 7.6547e-03, 2.3468e-02, 3.0149e-03],\n",
+ " [-5.6129e-02, -5.5228e-02, -1.4498e-02]],\n",
+ " \n",
+ " [[-7.0285e-03, -1.4398e-02, -8.0260e-03],\n",
+ " [ 8.6078e-03, 1.7272e-02, 3.7922e-03],\n",
+ " [ 1.7967e-02, -5.8396e-03, 1.2388e-02]],\n",
+ " \n",
+ " [[-1.9122e-02, 8.0995e-03, 1.4047e-02],\n",
+ " [-2.0732e-02, -1.7510e-03, -7.0565e-03],\n",
+ " [-8.9632e-03, 2.5373e-02, -2.1325e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[-7.2139e-03, 7.5723e-04, 1.5858e-02],\n",
+ " [ 4.4614e-03, 2.5813e-03, 7.7735e-03],\n",
+ " [-3.8430e-02, -1.2453e-02, -1.4860e-03]],\n",
+ " \n",
+ " [[ 1.5345e-02, -1.1651e-02, -4.7176e-02],\n",
+ " [-1.6323e-02, -3.9558e-02, -6.2256e-02],\n",
+ " [-2.7249e-02, -3.9055e-03, 3.3146e-02]],\n",
+ " \n",
+ " [[ 2.7576e-03, -1.1254e-02, 1.5793e-02],\n",
+ " [ 9.0961e-04, -1.9929e-02, -3.9376e-02],\n",
+ " [-4.5218e-02, 6.8002e-03, 2.4895e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-3.2264e-02, -4.5119e-02, -2.3231e-02],\n",
+ " [ 4.5156e-02, 1.1634e-02, -1.9521e-02],\n",
+ " [ 8.7846e-03, 1.1269e-02, -2.7752e-02]],\n",
+ " \n",
+ " [[-2.6495e-02, 2.2894e-02, 1.4803e-02],\n",
+ " [ 4.4075e-03, 2.3500e-02, 3.7441e-03],\n",
+ " [ 9.3967e-03, 1.2344e-02, 3.9956e-02]],\n",
+ " \n",
+ " [[ 4.6042e-04, -9.5107e-04, 5.8999e-02],\n",
+ " [ 1.6742e-03, 8.3862e-03, -1.2621e-02],\n",
+ " [-9.1627e-03, 3.0100e-02, 1.9895e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 2.0569e-02, -1.0795e-02, -8.7080e-03],\n",
+ " [ 2.4528e-02, 1.6483e-02, 9.9702e-04],\n",
+ " [ 3.6408e-03, -7.5199e-04, -5.9227e-02]],\n",
+ " \n",
+ " [[ 1.1502e-02, 2.2854e-02, -4.8116e-03],\n",
+ " [-1.7260e-02, -6.9348e-04, 8.0554e-03],\n",
+ " [ 8.5994e-03, -2.4679e-02, -3.9365e-02]],\n",
+ " \n",
+ " [[ 4.1262e-02, 1.6001e-02, 1.1125e-02],\n",
+ " [ 1.9232e-02, -4.0470e-02, -4.2124e-03],\n",
+ " [-3.1845e-02, -2.8374e-03, 1.2675e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-2.6239e-02, -3.9554e-02, 1.6393e-02],\n",
+ " [-6.0580e-03, 2.7392e-02, 2.6700e-02],\n",
+ " [-9.8317e-03, 2.6180e-02, -1.0239e-02]],\n",
+ " \n",
+ " [[-2.6586e-02, 3.0612e-02, -1.3597e-03],\n",
+ " [ 4.8483e-02, -1.3060e-02, -1.8707e-02],\n",
+ " [-6.5954e-03, 1.6304e-02, -2.2056e-02]],\n",
+ " \n",
+ " [[-2.0903e-03, -2.5995e-02, 4.7070e-02],\n",
+ " [ 1.9098e-02, -1.4131e-02, 1.0743e-02],\n",
+ " [-7.3639e-03, 3.8980e-02, -1.8740e-02]]]], device='cuda:0')),\n",
+ " ('module.features.7.0.downsample.0.weight',\n",
+ " tensor([[[[ 0.0280]],\n",
+ " \n",
+ " [[ 0.1087]],\n",
+ " \n",
+ " [[ 0.0847]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0122]],\n",
+ " \n",
+ " [[-0.0237]],\n",
+ " \n",
+ " [[ 0.0892]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.1502]],\n",
+ " \n",
+ " [[-0.0710]],\n",
+ " \n",
+ " [[-0.0160]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0163]],\n",
+ " \n",
+ " [[ 0.0837]],\n",
+ " \n",
+ " [[-0.0358]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0068]],\n",
+ " \n",
+ " [[ 0.0029]],\n",
+ " \n",
+ " [[ 0.0566]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0893]],\n",
+ " \n",
+ " [[ 0.0697]],\n",
+ " \n",
+ " [[ 0.0071]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[-0.1321]],\n",
+ " \n",
+ " [[-0.0198]],\n",
+ " \n",
+ " [[-0.0812]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0576]],\n",
+ " \n",
+ " [[ 0.0471]],\n",
+ " \n",
+ " [[-0.0246]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.1082]],\n",
+ " \n",
+ " [[ 0.0021]],\n",
+ " \n",
+ " [[ 0.0406]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0061]],\n",
+ " \n",
+ " [[-0.0029]],\n",
+ " \n",
+ " [[ 0.0266]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.1127]],\n",
+ " \n",
+ " [[-0.0219]],\n",
+ " \n",
+ " [[-0.0168]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0372]],\n",
+ " \n",
+ " [[-0.0544]],\n",
+ " \n",
+ " [[-0.0213]]]], device='cuda:0')),\n",
+ " ('module.features.7.1.conv1.weight',\n",
+ " tensor([[[[-0.0412, 0.0398, 0.0164],\n",
+ " [ 0.0117, 0.0329, 0.0236],\n",
+ " [ 0.0188, -0.0364, -0.0361]],\n",
+ " \n",
+ " [[-0.0310, -0.0105, -0.0407],\n",
+ " [-0.0084, -0.0072, 0.0195],\n",
+ " [ 0.0076, 0.0183, 0.0473]],\n",
+ " \n",
+ " [[ 0.0089, 0.0133, -0.0262],\n",
+ " [-0.0185, -0.0104, -0.0033],\n",
+ " [ 0.0133, 0.0122, -0.0572]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0086, 0.0117, -0.0214],\n",
+ " [ 0.0113, -0.0197, 0.0212],\n",
+ " [-0.0350, 0.0330, -0.0123]],\n",
+ " \n",
+ " [[ 0.0009, 0.0155, 0.0001],\n",
+ " [-0.0108, 0.0034, -0.0038],\n",
+ " [ 0.0087, 0.0137, -0.0251]],\n",
+ " \n",
+ " [[-0.0321, -0.0180, 0.0369],\n",
+ " [ 0.0048, 0.0178, 0.0107],\n",
+ " [-0.0123, 0.0056, 0.0049]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0111, -0.0211, -0.0223],\n",
+ " [-0.0191, -0.0136, 0.0010],\n",
+ " [-0.0029, 0.0223, 0.0629]],\n",
+ " \n",
+ " [[-0.0196, 0.0049, -0.0111],\n",
+ " [ 0.0151, -0.0025, -0.0261],\n",
+ " [ 0.0095, -0.0071, 0.0136]],\n",
+ " \n",
+ " [[ 0.0367, 0.0324, -0.0228],\n",
+ " [-0.0251, 0.0110, -0.0151],\n",
+ " [ 0.0457, -0.0145, 0.0162]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0022, -0.0046, -0.0142],\n",
+ " [ 0.0224, 0.0348, 0.0015],\n",
+ " [-0.0036, 0.0051, -0.0091]],\n",
+ " \n",
+ " [[ 0.0116, 0.0166, 0.0132],\n",
+ " [-0.0150, 0.0014, 0.0032],\n",
+ " [-0.0103, 0.0153, -0.0235]],\n",
+ " \n",
+ " [[-0.0142, -0.0050, -0.0244],\n",
+ " [ 0.0185, -0.0016, 0.0012],\n",
+ " [-0.0144, 0.0141, -0.0072]]],\n",
+ " \n",
+ " \n",
+ " [[[-0.0123, 0.0130, -0.0075],\n",
+ " [-0.0071, -0.0082, -0.0034],\n",
+ " [ 0.0143, 0.0157, -0.0249]],\n",
+ " \n",
+ " [[-0.0159, 0.0027, -0.0388],\n",
+ " [ 0.0021, 0.0024, -0.0288],\n",
+ " [-0.0225, -0.0067, 0.0192]],\n",
+ " \n",
+ " [[-0.0156, -0.0163, 0.0312],\n",
+ " [ 0.0416, 0.0107, -0.0375],\n",
+ " [ 0.0049, -0.0461, -0.0219]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0265, -0.0160, -0.0259],\n",
+ " [-0.0189, -0.0011, -0.0063],\n",
+ " [ 0.0012, 0.0453, -0.0172]],\n",
+ " \n",
+ " [[ 0.0210, 0.0088, -0.0411],\n",
+ " [-0.0089, 0.0210, -0.0035],\n",
+ " [ 0.0183, -0.0061, 0.0316]],\n",
+ " \n",
+ " [[-0.0117, -0.0214, -0.0032],\n",
+ " [-0.0275, -0.0606, -0.0150],\n",
+ " [ 0.0040, -0.0216, -0.0044]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[-0.0001, 0.0213, -0.0171],\n",
+ " [ 0.0011, -0.0231, 0.0122],\n",
+ " [-0.0306, 0.0143, 0.0080]],\n",
+ " \n",
+ " [[-0.0027, 0.0032, -0.0124],\n",
+ " [ 0.0044, -0.0284, 0.0182],\n",
+ " [ 0.0099, 0.0445, 0.0078]],\n",
+ " \n",
+ " [[-0.0175, 0.0329, -0.0161],\n",
+ " [-0.0447, -0.0261, -0.0158],\n",
+ " [-0.0196, 0.0309, -0.0215]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0195, 0.0048, -0.0415],\n",
+ " [-0.0131, -0.0159, -0.0024],\n",
+ " [ 0.0295, 0.0165, -0.0192]],\n",
+ " \n",
+ " [[ 0.0355, 0.0102, 0.0040],\n",
+ " [ 0.0007, 0.0181, -0.0201],\n",
+ " [-0.0097, 0.0173, -0.0389]],\n",
+ " \n",
+ " [[ 0.0024, -0.0168, -0.0181],\n",
+ " [ 0.0059, 0.0004, -0.0522],\n",
+ " [ 0.0194, 0.0180, -0.0293]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0062, 0.0105, -0.0109],\n",
+ " [-0.0213, 0.0117, 0.0115],\n",
+ " [-0.0232, -0.0389, 0.0185]],\n",
+ " \n",
+ " [[-0.0121, -0.0008, -0.0179],\n",
+ " [-0.0042, -0.0131, 0.0363],\n",
+ " [-0.0141, 0.0162, 0.0122]],\n",
+ " \n",
+ " [[-0.0172, 0.0188, -0.0150],\n",
+ " [-0.0093, 0.0270, 0.0506],\n",
+ " [-0.0337, -0.0070, -0.0267]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 0.0118, -0.0035, 0.0332],\n",
+ " [-0.0185, -0.0073, -0.0262],\n",
+ " [-0.0110, -0.0086, 0.0152]],\n",
+ " \n",
+ " [[ 0.0105, -0.0218, -0.0302],\n",
+ " [-0.0114, 0.0171, -0.0384],\n",
+ " [ 0.0255, -0.0277, 0.0629]],\n",
+ " \n",
+ " [[ 0.0015, 0.0425, -0.0147],\n",
+ " [-0.0056, -0.0232, -0.0242],\n",
+ " [ 0.0015, 0.0060, -0.0036]]],\n",
+ " \n",
+ " \n",
+ " [[[ 0.0189, 0.0194, 0.0064],\n",
+ " [-0.0075, -0.0141, -0.0050],\n",
+ " [ 0.0002, -0.0243, -0.0248]],\n",
+ " \n",
+ " [[ 0.0173, 0.0066, -0.0278],\n",
+ " [-0.0158, -0.0061, 0.0161],\n",
+ " [-0.0176, -0.0237, -0.0293]],\n",
+ " \n",
+ " [[ 0.0067, -0.0371, 0.0001],\n",
+ " [-0.0122, 0.0012, -0.0346],\n",
+ " [-0.0239, -0.0195, 0.0066]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-0.0004, -0.0357, -0.0282],\n",
+ " [-0.0071, -0.0012, -0.0346],\n",
+ " [-0.0103, -0.0152, -0.0183]],\n",
+ " \n",
+ " [[-0.0087, 0.0358, 0.0211],\n",
+ " [ 0.0090, -0.0186, 0.0573],\n",
+ " [-0.0072, 0.0191, -0.0075]],\n",
+ " \n",
+ " [[-0.0091, 0.0155, 0.0092],\n",
+ " [ 0.0244, 0.0233, 0.0293],\n",
+ " [ 0.0218, -0.0016, -0.0111]]]], device='cuda:0')),\n",
+ " ('module.features.7.1.conv2.weight',\n",
+ " tensor([[[[ 2.7669e-02, 1.0102e-02, -1.5389e-02],\n",
+ " [-7.7896e-03, 1.7454e-02, -5.2838e-03],\n",
+ " [ 3.2319e-02, -9.5478e-03, 2.1955e-02]],\n",
+ " \n",
+ " [[ 1.3091e-02, 1.2744e-02, -6.0428e-03],\n",
+ " [ 7.9706e-04, 4.0532e-03, -2.7187e-03],\n",
+ " [ 5.6271e-03, -2.0450e-02, -8.3630e-04]],\n",
+ " \n",
+ " [[ 2.6977e-02, -1.3292e-02, 9.1527e-03],\n",
+ " [-2.3563e-02, -2.7924e-02, -1.6096e-02],\n",
+ " [ 5.5006e-04, 1.0982e-02, 1.1167e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 3.2128e-03, -1.1398e-02, -1.7091e-02],\n",
+ " [-5.3630e-02, -1.8246e-02, -7.6490e-03],\n",
+ " [-1.8577e-02, 9.7755e-03, -2.1825e-02]],\n",
+ " \n",
+ " [[ 2.5462e-02, -4.8757e-03, 1.7404e-03],\n",
+ " [ 6.9207e-03, 2.9026e-02, 1.9597e-02],\n",
+ " [-2.6660e-03, -1.3143e-02, 1.3166e-02]],\n",
+ " \n",
+ " [[-9.2363e-03, 1.4657e-02, -1.7753e-02],\n",
+ " [-1.5085e-02, -8.5969e-03, 1.7411e-02],\n",
+ " [ 3.1844e-02, -1.2665e-02, 3.8360e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 1.6494e-02, 4.0510e-02, -1.1034e-02],\n",
+ " [ 1.0801e-02, -4.6805e-02, 3.1894e-02],\n",
+ " [ 3.1285e-02, -2.7874e-02, 3.1512e-03]],\n",
+ " \n",
+ " [[ 1.3249e-02, -2.4612e-02, 8.0424e-03],\n",
+ " [ 1.4087e-02, 8.3113e-03, -1.6492e-02],\n",
+ " [-1.6472e-02, 1.4578e-02, -1.0045e-02]],\n",
+ " \n",
+ " [[-2.6437e-02, -1.5198e-02, 8.9365e-03],\n",
+ " [ 4.9480e-03, 3.0334e-02, 7.5042e-03],\n",
+ " [ 1.6577e-02, -5.5669e-03, -8.8701e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-3.6576e-02, -2.9337e-02, 3.2167e-03],\n",
+ " [-8.1815e-04, -1.7954e-02, -1.2874e-02],\n",
+ " [ 8.4388e-05, -2.0595e-02, -4.4706e-02]],\n",
+ " \n",
+ " [[-1.6041e-02, 1.0156e-02, -9.1008e-03],\n",
+ " [ 4.5301e-02, 4.7988e-04, 8.6943e-03],\n",
+ " [-1.1701e-02, 4.7244e-02, 1.7367e-02]],\n",
+ " \n",
+ " [[ 1.3813e-02, -4.0753e-02, -2.4185e-03],\n",
+ " [-2.4735e-02, -2.6664e-02, 9.9198e-04],\n",
+ " [ 1.4143e-02, -9.8058e-03, 2.7278e-03]]],\n",
+ " \n",
+ " \n",
+ " [[[-3.9752e-03, 2.7291e-03, 2.9844e-02],\n",
+ " [-8.5228e-03, -6.5568e-03, -1.9087e-02],\n",
+ " [ 5.2401e-03, 5.0350e-04, 1.7048e-02]],\n",
+ " \n",
+ " [[-9.2543e-03, 1.8186e-02, -1.7019e-02],\n",
+ " [-3.3054e-02, 2.0240e-03, 5.5067e-03],\n",
+ " [-4.1597e-03, 9.8459e-03, -8.0485e-03]],\n",
+ " \n",
+ " [[-3.1671e-02, -3.4744e-03, -2.2173e-02],\n",
+ " [ 1.9146e-02, 5.9635e-03, 2.6835e-02],\n",
+ " [-2.7219e-02, -1.4855e-02, 1.3716e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-1.3893e-03, -2.3909e-03, 3.7731e-03],\n",
+ " [ 2.5951e-02, 2.2350e-02, 3.3649e-02],\n",
+ " [ 8.8714e-03, -3.2832e-03, 3.0970e-02]],\n",
+ " \n",
+ " [[ 1.6667e-02, -7.0636e-03, -8.0601e-05],\n",
+ " [-3.8572e-03, -2.2319e-02, -2.1308e-02],\n",
+ " [-2.4768e-02, -1.5440e-02, -1.5375e-02]],\n",
+ " \n",
+ " [[ 4.0461e-04, -1.9128e-02, 2.8521e-02],\n",
+ " [ 2.7569e-03, 1.2434e-02, 1.1612e-02],\n",
+ " [ 4.9260e-02, 2.2169e-02, -2.4459e-02]]],\n",
+ " \n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " \n",
+ " [[[-2.1874e-02, 1.4048e-03, 2.3160e-02],\n",
+ " [-2.2029e-02, 1.3565e-02, -1.1423e-02],\n",
+ " [ 1.0922e-02, 1.6705e-02, 4.6132e-03]],\n",
+ " \n",
+ " [[ 1.5474e-03, 4.6071e-02, 9.5328e-03],\n",
+ " [ 4.8634e-03, 3.2872e-02, -4.5178e-04],\n",
+ " [-2.0327e-02, 5.2714e-03, 2.2405e-02]],\n",
+ " \n",
+ " [[-1.3570e-02, -5.1997e-02, -2.5401e-02],\n",
+ " [ 2.3371e-02, -2.0209e-02, -1.3148e-02],\n",
+ " [ 1.3217e-02, -9.4160e-03, -1.7271e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-9.3603e-03, -1.0252e-02, -2.1926e-03],\n",
+ " [ 8.3288e-03, 1.4618e-02, 2.2519e-02],\n",
+ " [ 4.3726e-03, 1.5235e-02, 2.4148e-02]],\n",
+ " \n",
+ " [[ 2.1585e-02, -2.3678e-03, 1.3339e-03],\n",
+ " [-1.2002e-02, -1.8972e-02, -8.1394e-03],\n",
+ " [ 4.2480e-03, -1.7132e-02, 2.2052e-02]],\n",
+ " \n",
+ " [[ 1.0347e-02, 2.2348e-02, -9.3872e-03],\n",
+ " [-1.3332e-02, 1.9236e-02, -5.8983e-03],\n",
+ " [-2.6036e-02, -1.4431e-02, -4.9069e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 3.4349e-03, 3.0972e-02, -2.8677e-02],\n",
+ " [ 7.1175e-03, 2.1691e-02, 4.0088e-03],\n",
+ " [ 2.0570e-02, 4.5587e-03, 5.5360e-03]],\n",
+ " \n",
+ " [[-2.6303e-02, -6.6955e-03, -3.8755e-02],\n",
+ " [-1.7761e-02, -9.3154e-03, 1.6708e-02],\n",
+ " [-2.8337e-02, 1.5408e-02, -1.2784e-02]],\n",
+ " \n",
+ " [[-1.0890e-02, 1.0057e-03, 4.9444e-03],\n",
+ " [ 1.5475e-02, 2.7684e-02, 3.2346e-03],\n",
+ " [-1.2827e-02, -8.9223e-03, 1.1571e-02]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[ 2.9467e-02, -1.7513e-02, 2.1790e-02],\n",
+ " [ 9.5623e-03, 1.5912e-03, -2.8656e-02],\n",
+ " [-2.0841e-02, 3.4052e-02, 1.4477e-02]],\n",
+ " \n",
+ " [[ 2.5121e-02, 1.3353e-02, 1.7004e-02],\n",
+ " [-2.2956e-02, -7.9317e-03, -7.7189e-03],\n",
+ " [-2.2805e-02, 5.5981e-03, -9.2481e-03]],\n",
+ " \n",
+ " [[-7.4896e-03, 1.5322e-03, -1.4587e-02],\n",
+ " [ 4.1217e-02, 1.2760e-02, -3.7466e-02],\n",
+ " [ 4.6187e-02, 3.8541e-03, -3.0564e-02]]],\n",
+ " \n",
+ " \n",
+ " [[[ 2.1435e-02, 4.5467e-02, 4.5000e-03],\n",
+ " [ 7.6490e-03, -2.5208e-03, 2.3527e-02],\n",
+ " [ 2.5580e-02, 1.6698e-02, -1.8860e-02]],\n",
+ " \n",
+ " [[ 1.4487e-02, -3.4432e-03, -7.5392e-03],\n",
+ " [ 1.8938e-02, 4.7559e-02, -2.5380e-02],\n",
+ " [ 2.1595e-02, -8.7880e-03, 2.2249e-03]],\n",
+ " \n",
+ " [[ 1.0577e-02, 2.2048e-02, -1.5811e-03],\n",
+ " [-3.6731e-02, -4.3263e-03, 9.8155e-03],\n",
+ " [ 5.3485e-03, 3.3231e-02, 4.9549e-03]],\n",
+ " \n",
+ " ...,\n",
+ " \n",
+ " [[-1.4723e-02, -1.8430e-02, 2.0312e-02],\n",
+ " [-4.0386e-02, 1.3468e-02, -9.7934e-03],\n",
+ " [-7.6757e-03, 5.5895e-02, -6.2044e-03]],\n",
+ " \n",
+ " [[-4.6726e-03, -1.3059e-02, -6.3235e-03],\n",
+ " [ 2.0474e-02, 9.5724e-03, -1.0297e-02],\n",
+ " [-2.9894e-02, -2.8968e-02, 1.5460e-02]],\n",
+ " \n",
+ " [[ 3.6078e-03, 4.0985e-03, -1.9945e-02],\n",
+ " [ 8.5199e-03, -1.5623e-02, 1.6469e-02],\n",
+ " [-4.2940e-03, 3.3834e-03, -2.3444e-02]]]], device='cuda:0')),\n",
+ " ('module.l1.weight',\n",
+ " tensor([[-2.8849e-02, -1.9119e-02, 3.8300e-02, ..., 4.1686e-02,\n",
+ " 3.6721e-02, 1.5111e-03],\n",
+ " [-2.2924e-03, -7.3991e-05, -4.6714e-04, ..., 2.8410e-04,\n",
+ " -7.1595e-05, 1.2524e-03],\n",
+ " [-4.3433e-02, -1.6126e-02, 2.4300e-02, ..., 3.1861e-02,\n",
+ " -4.4353e-03, 3.3997e-02],\n",
+ " ...,\n",
+ " [-4.2116e-02, 9.1577e-03, -2.8979e-03, ..., 8.2516e-03,\n",
+ " -2.0367e-02, -2.6846e-02],\n",
+ " [-3.9404e-02, 3.1663e-03, -3.1503e-02, ..., -5.0674e-03,\n",
+ " -1.2334e-02, 2.4472e-02],\n",
+ " [ 4.0532e-02, 2.3001e-02, -4.4442e-02, ..., 4.0665e-02,\n",
+ " -1.5093e-03, -3.1281e-02]], device='cuda:0')),\n",
+ " ('module.l1.bias',\n",
+ " tensor([ 3.0996e-02, -1.0540e-04, -1.4748e-02, 2.5318e-04, 4.0875e-03,\n",
+ " -2.7277e-02, 2.8669e-02, -2.7196e-15, -1.9390e-02, -6.0616e-03,\n",
+ " 2.7244e-02, -7.6942e-16, 7.5572e-03, -4.1416e-02, 3.1923e-02,\n",
+ " -4.3156e-02, -7.3542e-23, -1.2017e-24, -3.8204e-02, -2.5722e-02,\n",
+ " -3.6925e-15, -3.7085e-02, -5.6505e-15, -1.9316e-02, -5.5478e-03,\n",
+ " 4.0829e-02, -7.1897e-04, 4.1314e-02, -1.0092e-02, -5.3813e-04,\n",
+ " 2.4123e-02, -1.0446e-02, -3.5741e-18, -2.5170e-04, -3.2334e-02,\n",
+ " -7.4074e-16, -2.1480e-04, -9.5165e-03, 2.5351e-02, -8.0323e-03,\n",
+ " 2.2315e-02, -1.6049e-03, -9.0869e-03, -3.8037e-02, 3.1691e-02,\n",
+ " 3.4157e-02, -1.7944e-03, -1.9633e-02, 3.0051e-02, -1.3012e-02,\n",
+ " -3.7214e-02, 2.4073e-02, -9.3066e-03, 3.7976e-03, -2.2379e-03,\n",
+ " -3.3308e-39, -1.3419e-04, -2.2215e-02, 6.3669e-03, 7.4496e-03,\n",
+ " -2.0330e-03, -2.2457e-02, -2.8783e-04, -4.1772e-02, 3.2072e-03,\n",
+ " 1.4834e-02, 1.6824e-02, -2.8914e-24, 6.9507e-03, 4.3442e-02,\n",
+ " 3.6441e-02, 2.9667e-02, 3.8556e-02, -6.1635e-04, -4.2663e-02,\n",
+ " 1.3519e-02, 1.9096e-02, 1.7169e-02, 8.0511e-03, -9.5487e-03,\n",
+ " -3.1515e-02, 5.2946e-03, 8.7159e-03, 3.3859e-02, 2.5782e-02,\n",
+ " 3.0781e-02, -2.5431e-21, 2.2940e-03, -3.4296e-02, -3.2956e-04,\n",
+ " -3.8724e-02, -9.5437e-03, 6.4339e-03, 1.9079e-03, -4.8774e-40,\n",
+ " -1.2229e-02, -3.6624e-02, 1.6433e-03, 3.0857e-02, 5.2444e-03,\n",
+ " 2.5248e-02, -2.3934e-04, -5.5150e-22, 2.6404e-02, 1.1247e-02,\n",
+ " 2.0062e-02, -1.3908e-02, 2.9855e-02, -3.4427e-02, -1.1697e-02,\n",
+ " 2.0255e-02, -3.1341e-12, -4.0415e-03, -3.6996e-02, -3.4850e-03,\n",
+ " 5.8550e-03, 2.1957e-02, -8.1380e-30, 1.9003e-02, -3.1575e-02,\n",
+ " -2.5251e-02, 8.6628e-04, -3.8475e-04, -2.3658e-02, -3.6147e-40,\n",
+ " 3.5547e-02, 3.0704e-03, -2.3029e-22, -2.0700e-02, 1.6150e-02,\n",
+       "    ... (remaining state_dict tensor values, including module.l2.weight and module.l2.bias, truncated for readability) ..."
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "\n",
+ "torch.load('papers/tmi2022/feature_extractor/runs/Oct29_16-15-55_xrh1/checkpoints/model.pth')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "931f65c8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import cl as cl\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torchvision.models as models\n",
+ "import torchvision.transforms.functional as VF\n",
+ "from torchvision import transforms\n",
+ "\n",
+ "import sys, argparse, os, glob\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from collections import OrderedDict\n",
+ "from easydict import EasyDict as edict\n",
+ "\n",
+ "\n",
+    "# num_classes mirrors the build_graphs.py default; the weights path is the checkpoint loaded in the previous cell\n",
+    "args = edict({'backbone': 'resnet18',\n",
+    "              'num_classes': 2,\n",
+    "              'weights': 'papers/tmi2022/feature_extractor/runs/Oct29_16-15-55_xrh1/checkpoints/model.pth'})\n",
+ "\n",
+ "\n",
+ "if args.backbone == 'resnet18':\n",
+ " resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)\n",
+ " num_feats = 512\n",
+ "if args.backbone == 'resnet34':\n",
+ " resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)\n",
+ " num_feats = 512\n",
+ "if args.backbone == 'resnet50':\n",
+ " resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)\n",
+ " num_feats = 2048\n",
+ "if args.backbone == 'resnet101':\n",
+ " resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)\n",
+ " num_feats = 2048\n",
+ "for param in resnet.parameters():\n",
+ " param.requires_grad = False\n",
+ "resnet.fc = nn.Identity()\n",
+ "i_classifier = cl.IClassifier(resnet, num_feats, output_class=args.num_classes).cuda()\n",
+ "\n",
+ "# load feature extractor\n",
+ "if args.weights is None:\n",
+ " print('No feature extractor')\n",
+    "    sys.exit(1)\n",
+ "state_dict_weights = torch.load(args.weights)\n",
+ "print(state_dict_weights)\n",
+ "state_dict_init = i_classifier.state_dict()\n",
+ "new_state_dict = OrderedDict()\n",
+ "for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):\n",
+ " name = k_0\n",
+ " new_state_dict[name] = v\n",
+ "i_classifier.load_state_dict(new_state_dict, strict=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/feature_extractor/__init__.py b/feature_extractor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/feature_extractor/__pycache__/__init__.cpython-38.pyc b/feature_extractor/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b556d42af77e9d7b94633e305322f0d57a69d37
Binary files /dev/null and b/feature_extractor/__pycache__/__init__.cpython-38.pyc differ
diff --git a/feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc b/feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b44ae714b0972cd17a1d00c342e8a82cbf7f290
Binary files /dev/null and b/feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc differ
diff --git a/feature_extractor/__pycache__/build_graphs.cpython-38.pyc b/feature_extractor/__pycache__/build_graphs.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a83ffb9324aac47a3725f7017826b06bb41aaf13
Binary files /dev/null and b/feature_extractor/__pycache__/build_graphs.cpython-38.pyc differ
diff --git a/feature_extractor/__pycache__/cl.cpython-38.pyc b/feature_extractor/__pycache__/cl.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e42920b563025ba18624dd86f37a0b02bc8b2848
Binary files /dev/null and b/feature_extractor/__pycache__/cl.cpython-38.pyc differ
diff --git a/feature_extractor/__pycache__/simclr.cpython-36.pyc b/feature_extractor/__pycache__/simclr.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b22ef85b25820aa4f3bb208e1f4f05ae4b4d2040
Binary files /dev/null and b/feature_extractor/__pycache__/simclr.cpython-36.pyc differ
diff --git a/feature_extractor/__pycache__/simclr.cpython-38.pyc b/feature_extractor/__pycache__/simclr.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d71b3c2c74d8b382691a15fc69feebd0bf73b1b3
Binary files /dev/null and b/feature_extractor/__pycache__/simclr.cpython-38.pyc differ
diff --git a/feature_extractor/build_graph_utils.py b/feature_extractor/build_graph_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b7b79b4af4974f364e81153ce78b4215120050e
--- /dev/null
+++ b/feature_extractor/build_graph_utils.py
@@ -0,0 +1,85 @@
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import torchvision.models as models
+import torchvision.transforms.functional as VF
+from torchvision import transforms
+
+import sys, argparse, os, glob
+import pandas as pd
+import numpy as np
+from PIL import Image
+from collections import OrderedDict
+
+class ToPIL(object):
+ def __call__(self, sample):
+ img = sample
+ img = transforms.functional.to_pil_image(img)
+ return img
+
+class BagDataset():
+ def __init__(self, csv_file, transform=None):
+ self.files_list = csv_file
+ self.transform = transform
+ def __len__(self):
+ return len(self.files_list)
+ def __getitem__(self, idx):
+ temp_path = self.files_list[idx]
+ img = os.path.join(temp_path)
+ img = Image.open(img)
+ img = img.resize((224, 224))
+ sample = {'input': img}
+
+ if self.transform:
+ sample = self.transform(sample)
+ return sample
+
+class ToTensor(object):
+ def __call__(self, sample):
+ img = sample['input']
+ img = VF.to_tensor(img)
+ return {'input': img}
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img):
+ for t in self.transforms:
+ img = t(img)
+ return img
+
+def save_coords(txt_file, csv_file_path):
+ for path in csv_file_path:
+ x, y = path.split('/')[-1].split('.')[0].split('_')
+ txt_file.writelines(str(x) + '\t' + str(y) + '\n')
+ txt_file.close()
+
+def adj_matrix(csv_file_path, output, device='cpu'):
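+    # Patch filenames encode grid coordinates as '<x>_<y>.<ext>'; two patches are
+    # connected when both coordinates differ by at most 1 (8-neighbourhood adjacency).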
+ total = len(csv_file_path)
+ adj_s = np.zeros((total, total))
+
+ for i in range(total-1):
+ path_i = csv_file_path[i]
+ x_i, y_i = path_i.split('/')[-1].split('.')[0].split('_')
+ for j in range(i+1, total):
+            # spatial adjacency check
+ path_j = csv_file_path[j]
+ x_j, y_j = path_j.split('/')[-1].split('.')[0].split('_')
+ if abs(int(x_i)-int(x_j)) <=1 and abs(int(y_i)-int(y_j)) <= 1:
+ adj_s[i][j] = 1
+ adj_s[j][i] = 1
+
+ adj_s = torch.from_numpy(adj_s)
+ adj_s = adj_s.to(device)
+
+ return adj_s
+
+def bag_dataset(args, csv_file_path):
+ transformed_dataset = BagDataset(csv_file=csv_file_path,
+ transform=Compose([
+ ToTensor()
+ ]))
+ dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False)
+ return dataloader, len(transformed_dataset)
\ No newline at end of file
diff --git a/feature_extractor/build_graphs.py b/feature_extractor/build_graphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..64620387d1a607b32b7239e18739dc0e80f92567
--- /dev/null
+++ b/feature_extractor/build_graphs.py
@@ -0,0 +1,114 @@
+
+from cl import IClassifier
+from build_graph_utils import *
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import torchvision.models as models
+import torchvision.transforms.functional as VF
+from torchvision import transforms
+
+import sys, argparse, os, glob
+import pandas as pd
+import numpy as np
+from PIL import Image
+from collections import OrderedDict
+
+
+
+def compute_feats(args, bags_list, i_classifier, device, save_path=None, whole_slide_path=None):
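+    # For each bag (one folder of patches per slide), run the frozen embedder on every
+    # patch, then save the node features, patch coordinates and spatial adjacency matrix
+    # consumed by the downstream graph model.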
+ num_bags = len(bags_list)
+ Tensor = torch.FloatTensor
+ for i in range(0, num_bags):
+ feats_list = []
+ if args.magnification == '20x':
+ glob_path = os.path.join(bags_list[i], '*.jpeg')
+ csv_file_path = glob.glob(glob_path)
+            # the line below was in the original version; commented out due to an error with the current version
+            #file_name = bags_list[i].split('/')[-3].split('_')[0]
+
+            file_name = glob_path.split('/')[-3].split('_')[0]
+
+ if args.magnification == '5x' or args.magnification == '10x':
+ csv_file_path = glob.glob(os.path.join(bags_list[i], '*.jpg'))
+
+ dataloader, bag_size = bag_dataset(args, csv_file_path)
+ print('{} files to be processed: {}'.format(len(csv_file_path), file_name))
+
+ if os.path.isdir(os.path.join(save_path, 'simclr_files', file_name)) or len(csv_file_path) < 1:
+            print('already exists')
+ continue
+ with torch.no_grad():
+ for iteration, batch in enumerate(dataloader):
+ patches = batch['input'].float().to(device)
+ feats, classes = i_classifier(patches)
+ #feats = feats.cpu().numpy()
+ feats_list.extend(feats)
+
+ os.makedirs(os.path.join(save_path, 'simclr_files', file_name), exist_ok=True)
+
+ txt_file = open(os.path.join(save_path, 'simclr_files', file_name, 'c_idx.txt'), "w+")
+ save_coords(txt_file, csv_file_path)
+ # save node features
+ output = torch.stack(feats_list, dim=0).to(device)
+ torch.save(output, os.path.join(save_path, 'simclr_files', file_name, 'features.pt'))
+ # save adjacent matrix
+ adj_s = adj_matrix(csv_file_path, output, device=device)
+ torch.save(adj_s, os.path.join(save_path, 'simclr_files', file_name, 'adj_s.pt'))
+
+ print('\r Computed: {}/{}'.format(i+1, num_bags))
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Compute TCGA features from SimCLR embedder')
+ parser.add_argument('--num_classes', default=2, type=int, help='Number of output classes')
+ parser.add_argument('--num_feats', default=512, type=int, help='Feature size')
+ parser.add_argument('--batch_size', default=128, type=int, help='Batch size of dataloader')
+    parser.add_argument('--num_workers', default=0, type=int, help='Number of threads for dataloader')
+ parser.add_argument('--dataset', default=None, type=str, help='path to patches')
+ parser.add_argument('--backbone', default='resnet18', type=str, help='Embedder backbone')
+ parser.add_argument('--magnification', default='20x', type=str, help='Magnification to compute features')
+ parser.add_argument('--weights', default=None, type=str, help='path to the pretrained weights')
+ parser.add_argument('--output', default=None, type=str, help='path to the output graph folder')
+ args = parser.parse_args()
+
+ if args.backbone == 'resnet18':
+ resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 512
+ if args.backbone == 'resnet34':
+ resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 512
+ if args.backbone == 'resnet50':
+ resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 2048
+ if args.backbone == 'resnet101':
+ resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 2048
+ for param in resnet.parameters():
+ param.requires_grad = False
+ resnet.fc = nn.Identity()
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print("Running on:", device)
+ i_classifier = IClassifier(resnet, num_feats, output_class=args.num_classes).to(device)
+
+ # load feature extractor
+ if args.weights is None:
+ print('No feature extractor')
+ return
+ state_dict_weights = torch.load(args.weights)
+ state_dict_init = i_classifier.state_dict()
+ new_state_dict = OrderedDict()
+ for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):
+ if 'features' not in k:
+ continue
+ name = k_0
+ new_state_dict[name] = v
+ i_classifier.load_state_dict(new_state_dict, strict=False)
+
+ os.makedirs(args.output, exist_ok=True)
+ bags_list = glob.glob(args.dataset)
+ print(bags_list)
+ compute_feats(args, bags_list, i_classifier, device, args.output)
+
+if __name__ == '__main__':
+ main()
diff --git a/feature_extractor/cl.py b/feature_extractor/cl.py
new file mode 100644
index 0000000000000000000000000000000000000000..6de9ef291a50dcbe870185a1ec62a63ecbd4f161
--- /dev/null
+++ b/feature_extractor/cl.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+class FCLayer(nn.Module):
+ def __init__(self, in_size, out_size=1):
+ super(FCLayer, self).__init__()
+ self.fc = nn.Sequential(nn.Linear(in_size, out_size))
+ def forward(self, feats):
+ x = self.fc(feats)
+ return feats, x
+
+class IClassifier(nn.Module):
+ def __init__(self, feature_extractor, feature_size, output_class):
+ super(IClassifier, self).__init__()
+
+ self.feature_extractor = feature_extractor
+ self.fc = nn.Linear(feature_size, output_class)
+
+
+ def forward(self, x):
+ device = x.device
+ feats = self.feature_extractor(x) # N x K
+ c = self.fc(feats.view(feats.shape[0], -1)) # N x C
+ return feats.view(feats.shape[0], -1), c
+
+class BClassifier(nn.Module):
+ def __init__(self, input_size, output_class, dropout_v=0.0): # K, L, N
+ super(BClassifier, self).__init__()
+ self.q = nn.Linear(input_size, 128)
+ self.v = nn.Sequential(
+ nn.Dropout(dropout_v),
+ nn.Linear(input_size, input_size)
+ )
+
+        ### 1D convolutional layer that can handle multiple classes (including binary)
+ self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)
+
+ def forward(self, feats, c): # N x K, N x C
+ device = feats.device
+ V = self.v(feats) # N x V, unsorted
+ Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted
+
+ # handle multiple classes without for loop
+ _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
+ m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
+ q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
+ A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
+ A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,
+ B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V
+
+
+# for i in range(c.shape[1]):
+# _, indices = torch.sort(c[:, i], 0, True)
+# feats = torch.index_select(feats, 0, indices) # N x K, sorted
+# q_max = self.q(feats[0].view(1, -1)) # 1 x 1 x Q
+# temp = torch.mm(Q, q_max.view(-1, 1)) / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device))
+# if i == 0:
+# A = F.softmax(temp, 0) # N x 1
+# B = torch.sum(torch.mul(A, V), 0).view(1, -1) # 1 x V
+# else:
+# temp = F.softmax(temp, 0) # N x 1
+# A = torch.cat((A, temp), 1) # N x C
+# B = torch.cat((B, torch.sum(torch.mul(temp, V), 0).view(1, -1)), 0) # C x V -> 1 x C x V
+
+ B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V
+ C = self.fcc(B) # 1 x C x 1
+ C = C.view(1, -1)
+ return C, A, B
+
+class MILNet(nn.Module):
+ def __init__(self, i_classifier, b_classifier):
+ super(MILNet, self).__init__()
+ self.i_classifier = i_classifier
+ self.b_classifier = b_classifier
+
+ def forward(self, x):
+ feats, classes = self.i_classifier(x)
+ prediction_bag, A, B = self.b_classifier(feats, classes)
+
+ return classes, prediction_bag, A, B
+
\ No newline at end of file
diff --git a/feature_extractor/config.yaml b/feature_extractor/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c8f4309e6cbefa7270b1beb7c639d9551b325a8
--- /dev/null
+++ b/feature_extractor/config.yaml
@@ -0,0 +1,23 @@
+batch_size: 256
+epochs: 20
+eval_every_n_epochs: 1
+fine_tune_from: ''
+log_every_n_steps: 25
+weight_decay: 10e-6
+fp16_precision: False
+n_gpu: 2
+gpu_ids: (0,1)
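+# note: gpu_ids and weight_decay are kept as literal strings here and parsed with eval() in simclr.py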
+
+model:
+ out_dim: 512
+ base_model: "resnet18"
+
+dataset:
+ s: 1
+ input_shape: (224,224,3)
+ num_workers: 10
+ valid_size: 0.1
+
+loss:
+ temperature: 0.5
+ use_cosine_similarity: True
diff --git a/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0693e648c7284afcca6e210918d3a23633b446f
Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc differ
diff --git a/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46f9dc9314fbe1d8839cab934ce38e1cbee3428a
Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc differ
diff --git a/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26111ce71915c4cf020bea014d5349d67314036c
Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc differ
diff --git a/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e800e41ae6afff480e1435bb55d62e55c608915b
Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc differ
diff --git a/feature_extractor/data_aug/dataset_wrapper.py b/feature_extractor/data_aug/dataset_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c2ad19f9ee1de487027b9db55516a89531aa484
--- /dev/null
+++ b/feature_extractor/data_aug/dataset_wrapper.py
@@ -0,0 +1,93 @@
+import numpy as np
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+import torchvision.transforms as transforms
+from data_aug.gaussian_blur import GaussianBlur
+from torchvision import datasets
+import pandas as pd
+from PIL import Image
+from skimage import io, img_as_ubyte
+
+np.random.seed(0)
+
+class Dataset():
+ def __init__(self, csv_file, transform=None):
+ lines = []
+ with open(csv_file) as f:
+ for line in f:
+ line = line.rstrip().strip()
+ lines.append(line)
+ self.files_list = lines#pd.read_csv(csv_file)
+ self.transform = transform
+ def __len__(self):
+ return len(self.files_list)
+ def __getitem__(self, idx):
+ temp_path = self.files_list[idx]# self.files_list.iloc[idx, 0]
+ img = Image.open(temp_path)
+ img = transforms.functional.to_tensor(img)
+ if self.transform:
+ sample = self.transform(img)
+ return sample
+
+class ToPIL(object):
+ def __call__(self, sample):
+ img = sample
+ img = transforms.functional.to_pil_image(img)
+ return img
+
+class DataSetWrapper(object):
+
+ def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.valid_size = valid_size
+ self.s = s
+ self.input_shape = eval(input_shape)
+
+ def get_data_loaders(self):
+ data_augment = self._get_simclr_pipeline_transform()
+ train_dataset = Dataset(csv_file='all_patches.csv', transform=SimCLRDataTransform(data_augment))
+ train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
+ return train_loader, valid_loader
+
+ def _get_simclr_pipeline_transform(self):
+ # get a set of data augmentation transformations as described in the SimCLR paper.
+ color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
+ data_transforms = transforms.Compose([ToPIL(),
+ # transforms.RandomResizedCrop(size=self.input_shape[0]),
+ transforms.Resize((self.input_shape[0],self.input_shape[1])),
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomApply([color_jitter], p=0.8),
+ transforms.RandomGrayscale(p=0.2),
+ GaussianBlur(kernel_size=int(0.06 * self.input_shape[0])),
+ transforms.ToTensor()])
+ return data_transforms
+
+ def get_train_validation_data_loaders(self, train_dataset):
+ # obtain training indices that will be used for validation
+ num_train = len(train_dataset)
+ indices = list(range(num_train))
+ np.random.shuffle(indices)
+
+ split = int(np.floor(self.valid_size * num_train))
+ train_idx, valid_idx = indices[split:], indices[:split]
+
+ # define samplers for obtaining training and validation batches
+ train_sampler = SubsetRandomSampler(train_idx)
+ valid_sampler = SubsetRandomSampler(valid_idx)
+
+ train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
+ num_workers=self.num_workers, drop_last=True, shuffle=False)
+ valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
+ num_workers=self.num_workers, drop_last=True)
+ return train_loader, valid_loader
+
+
+class SimCLRDataTransform(object):
+ def __init__(self, transform):
+ self.transform = transform
+
+ def __call__(self, sample):
+ xi = self.transform(sample)
+ xj = self.transform(sample)
+ return xi, xj
diff --git a/feature_extractor/data_aug/gaussian_blur.py b/feature_extractor/data_aug/gaussian_blur.py
new file mode 100644
index 0000000000000000000000000000000000000000..19669769637750ecc021e553483d71da3256174c
--- /dev/null
+++ b/feature_extractor/data_aug/gaussian_blur.py
@@ -0,0 +1,26 @@
+import cv2
+import numpy as np
+
+np.random.seed(0)
+
+
+class GaussianBlur(object):
+ # Implements Gaussian blur as described in the SimCLR paper
+ def __init__(self, kernel_size, min=0.1, max=2.0):
+ self.min = min
+ self.max = max
+        # kernel size is supplied by the caller as a fraction of the image height/width
+ self.kernel_size = kernel_size
+
+ def __call__(self, sample):
+ sample = np.array(sample)
+
+ # blur the image with a 50% chance
+ prob = np.random.random_sample()
+
+ if prob < 0.5:
+# print(self.kernel_size)
+ sigma = (self.max - self.min) * np.random.random_sample() + self.min
+ sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
+
+ return sample
diff --git a/feature_extractor/load_patches.py b/feature_extractor/load_patches.py
new file mode 100644
index 0000000000000000000000000000000000000000..0418cdbc185ef8a2d9c2870062b5ce18bcc347e7
--- /dev/null
+++ b/feature_extractor/load_patches.py
@@ -0,0 +1,37 @@
+
+import os, glob
+import argparse
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_path', type=str)
+ args = parser.parse_args()
+
+ wsi_slides_paths = []
+
+
+    # Collect every extracted patch image: patches are stored two directory levels
+    # below data_path and saved as .jpeg files.
+    def r(dirpath):
+        for path in glob.glob(os.path.join(dirpath, '*', '*', '*.jpeg')):
+            if os.path.isfile(path):
+                wsi_slides_paths.append(path)
+ r(args.data_path)
+ with open('all_patches.csv', 'w') as f:
+ for filepath in wsi_slides_paths:
+ f.write(f'{filepath}\n')
+
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc b/feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f9b816d47e6570a705d9ed13c3962fbc3f04d39
Binary files /dev/null and b/feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc differ
diff --git a/feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc b/feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd661bb4c3737477f5da9b20be4bdfd94d22e595
Binary files /dev/null and b/feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc differ
diff --git a/feature_extractor/loss/nt_xent.py b/feature_extractor/loss/nt_xent.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff2baff1d67613797c333b27be0cd29756f89bbe
--- /dev/null
+++ b/feature_extractor/loss/nt_xent.py
@@ -0,0 +1,65 @@
+import torch
+import numpy as np
+
+
+class NTXentLoss(torch.nn.Module):
+
+ def __init__(self, device, batch_size, temperature, use_cosine_similarity):
+ super(NTXentLoss, self).__init__()
+ self.batch_size = batch_size
+ self.temperature = temperature
+ self.device = device
+ self.softmax = torch.nn.Softmax(dim=-1)
+ self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
+ self.similarity_function = self._get_similarity_function(use_cosine_similarity)
+ self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
+
+ def _get_similarity_function(self, use_cosine_similarity):
+ if use_cosine_similarity:
+ self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
+ return self._cosine_simililarity
+ else:
+ return self._dot_simililarity
+
+ def _get_correlated_mask(self):
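+        # keep, for each of the 2N projections, only the 2N-2 negatives: drop the sample
+        # itself (main diagonal) and its augmented pair (the +/- batch_size diagonals)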
+ diag = np.eye(2 * self.batch_size)
+ l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
+ l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
+ mask = torch.from_numpy((diag + l1 + l2))
+ mask = (1 - mask).type(torch.bool)
+ return mask.to(self.device)
+
+ @staticmethod
+ def _dot_simililarity(x, y):
+ v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
+ # x shape: (N, 1, C)
+ # y shape: (1, C, 2N)
+ # v shape: (N, 2N)
+ return v
+
+ def _cosine_simililarity(self, x, y):
+ # x shape: (N, 1, C)
+ # y shape: (1, 2N, C)
+ # v shape: (N, 2N)
+ v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
+ return v
+
+ def forward(self, zis, zjs):
+ representations = torch.cat([zjs, zis], dim=0)
+
+ similarity_matrix = self.similarity_function(representations, representations)
+
+ # filter out the scores from the positive samples
+ l_pos = torch.diag(similarity_matrix, self.batch_size)
+ r_pos = torch.diag(similarity_matrix, -self.batch_size)
+ positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
+
+ negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
+
+ logits = torch.cat((positives, negatives), dim=1)
+ logits /= self.temperature
+
+ labels = torch.zeros(2 * self.batch_size).to(self.device).long()
+ loss = self.criterion(logits, labels)
+
+ return loss / (2 * self.batch_size)
diff --git a/feature_extractor/models/__init__.py b/feature_extractor/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/feature_extractor/models/__pycache__/__init__.cpython-38.pyc b/feature_extractor/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ed96a932d406396d34df0b7ef0d78679b2ac52f
Binary files /dev/null and b/feature_extractor/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc b/feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcc3d0e9ab9c84d08d1d299105dcce4c10c8f9c1
Binary files /dev/null and b/feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc differ
diff --git a/feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc b/feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb91e49d76dc61d79b70295ce9ff335321500ac7
Binary files /dev/null and b/feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc differ
diff --git a/feature_extractor/models/baseline_encoder.py b/feature_extractor/models/baseline_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b9b931c34d5a91dcefa65d9f838bbd30707009
--- /dev/null
+++ b/feature_extractor/models/baseline_encoder.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+
+class Encoder(nn.Module):
+ def __init__(self, out_dim=64):
+ super(Encoder, self).__init__()
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+ self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+ self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+ self.pool = nn.MaxPool2d(2, 2)
+
+ # projection MLP
+ self.l1 = nn.Linear(64, 64)
+ self.l2 = nn.Linear(64, out_dim)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.pool(x)
+
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = self.pool(x)
+
+ x = self.conv3(x)
+ x = F.relu(x)
+ x = self.pool(x)
+
+ x = self.conv4(x)
+ x = F.relu(x)
+ x = self.pool(x)
+
+ h = torch.mean(x, dim=[2, 3])
+
+ x = self.l1(h)
+ x = F.relu(x)
+ x = self.l2(x)
+
+ return h, x
diff --git a/feature_extractor/models/resnet_simclr.py b/feature_extractor/models/resnet_simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..957d2611229c5a452ccd73a62630e420cf1e2e70
--- /dev/null
+++ b/feature_extractor/models/resnet_simclr.py
@@ -0,0 +1,37 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+
+class ResNetSimCLR(nn.Module):
+
+ def __init__(self, base_model, out_dim):
+ super(ResNetSimCLR, self).__init__()
+ self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d),
+ "resnet50": models.resnet50(pretrained=False)}
+
+ resnet = self._get_basemodel(base_model)
+ num_ftrs = resnet.fc.in_features
+
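+        # keep only the convolutional trunk; the final fc layer is replaced by the projection MLP below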
+ self.features = nn.Sequential(*list(resnet.children())[:-1])
+
+ # projection MLP
+ self.l1 = nn.Linear(num_ftrs, num_ftrs)
+ self.l2 = nn.Linear(num_ftrs, out_dim)
+
+ def _get_basemodel(self, model_name):
+ try:
+ model = self.resnet_dict[model_name]
+ print("Feature extractor:", model_name)
+ return model
+        except KeyError:
+            raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
+
+ def forward(self, x):
+ h = self.features(x)
+ h = h.squeeze()
+
+ x = self.l1(h)
+ x = F.relu(x)
+ x = self.l2(x)
+ return h, x
diff --git a/feature_extractor/run.py b/feature_extractor/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..50d357b15d364b8064f69d5ecc1cca9f670e4987
--- /dev/null
+++ b/feature_extractor/run.py
@@ -0,0 +1,21 @@
+from simclr import SimCLR
+import yaml
+from data_aug.dataset_wrapper import DataSetWrapper
+import os, glob
+import pandas as pd
+import argparse
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--magnification', type=str, default='20x')
+ parser.add_argument('--dest_weights', type=str)
+ args = parser.parse_args()
+ config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
+ dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
+
+ simclr = SimCLR(dataset, config, args)
+ simclr.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/feature_extractor/simclr.py b/feature_extractor/simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..4165108714d9b8f677d2bb7d7de77fc7c11ad151
--- /dev/null
+++ b/feature_extractor/simclr.py
@@ -0,0 +1,165 @@
+import torch
+from models.resnet_simclr import ResNetSimCLR
+from torch.utils.tensorboard import SummaryWriter
+import torch.nn.functional as F
+from loss.nt_xent import NTXentLoss
+import os
+import shutil
+import sys
+
+apex_support = False
+try:
+ sys.path.append('./apex')
+ from apex import amp
+
+ apex_support = True
+except:
+ print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
+ apex_support = False
+
+import numpy as np
+
+torch.manual_seed(0)
+
+
+def _save_config_file(model_checkpoints_folder):
+ if not os.path.exists(model_checkpoints_folder):
+ os.makedirs(model_checkpoints_folder)
+ shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
+
+
+class SimCLR(object):
+
+ def __init__(self, dataset, config, args=None):
+ self.config = config
+ self.device = self._get_device()
+ self.writer = SummaryWriter()
+ self.dataset = dataset
+ self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
+ self.args = args
+ def _get_device(self):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print("Running on:", device)
+ return device
+
+ def _step(self, model, xis, xjs, n_iter):
+
+ # get the representations and the projections
+ ris, zis = model(xis) # [N,C]
+
+ # get the representations and the projections
+ rjs, zjs = model(xjs) # [N,C]
+
+ # normalize projection feature vectors
+ zis = F.normalize(zis, dim=1)
+ zjs = F.normalize(zjs, dim=1)
+
+ loss = self.nt_xent_criterion(zis, zjs)
+ return loss
+
+ def train(self):
+
+ train_loader, valid_loader = self.dataset.get_data_loaders()
+
+ model = ResNetSimCLR(**self.config["model"])# .to(self.device)
+ if self.config['n_gpu'] > 1:
+ model = torch.nn.DataParallel(model, device_ids=eval(self.config['gpu_ids']))
+ model = self._load_pre_trained_weights(model)
+ model = model.to(self.device)
+
+
+ optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=eval(self.config['weight_decay']))
+
+# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
+# last_epoch=-1)
+
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config['epochs'], eta_min=0,
+ last_epoch=-1)
+
+
+ if apex_support and self.config['fp16_precision']:
+ model, optimizer = amp.initialize(model, optimizer,
+ opt_level='O2',
+ keep_batchnorm_fp32=True)
+
+ if self.args is None:
+ model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
+ else:
+ model_checkpoints_folder = self.args.dest_weights#os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH']
+ model_checkpoints_folder = os.path.dirname(model_checkpoints_folder)
+ # save config file
+ _save_config_file(model_checkpoints_folder)
+
+ n_iter = 0
+ valid_n_iter = 0
+ best_valid_loss = np.inf
+
+ for epoch_counter in range(self.config['epochs']):
+ for (xis, xjs) in train_loader:
+ optimizer.zero_grad()
+ xis = xis.to(self.device)
+ xjs = xjs.to(self.device)
+
+ loss = self._step(model, xis, xjs, n_iter)
+
+ if n_iter % self.config['log_every_n_steps'] == 0:
+ self.writer.add_scalar('train_loss', loss, global_step=n_iter)
+ print("[%d/%d] step: %d train_loss: %.3f" % (epoch_counter, self.config['epochs'], n_iter, loss))
+
+ if apex_support and self.config['fp16_precision']:
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss.backward()
+
+ optimizer.step()
+ n_iter += 1
+
+ # validate the model if requested
+ if epoch_counter % self.config['eval_every_n_epochs'] == 0:
+ valid_loss = self._validate(model, valid_loader)
+ print("[%d/%d] val_loss: %.3f" % (epoch_counter, self.config['epochs'], valid_loss))
+ if valid_loss < best_valid_loss:
+ # save the model weights
+ best_valid_loss = valid_loss
+ torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
+ print('saved')
+
+ self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
+ valid_n_iter += 1
+
+ # warmup for the first 10 epochs
+ if epoch_counter >= 10:
+ scheduler.step()
+ self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
+
+ def _load_pre_trained_weights(self, model):
+ try:
+ checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
+ state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
+ model.load_state_dict(state_dict)
+ print("Loaded pre-trained model with success.")
+ except FileNotFoundError:
+ print("Pre-trained weights not found. Training from scratch.")
+
+ return model
+
+ def _validate(self, model, valid_loader):
+
+ # validation steps
+ with torch.no_grad():
+ model.eval()
+
+ valid_loss = 0.0
+ counter = 0
+
+ for (xis, xjs) in valid_loader:
+ xis = xis.to(self.device)
+ xjs = xjs.to(self.device)
+
+ loss = self._step(model, xis, xjs, counter)
+ valid_loss += loss.item()
+ counter += 1
+ valid_loss /= counter
+ model.train()
+ return valid_loss
diff --git a/feature_extractor/viewer.py b/feature_extractor/viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4ca901d07808dc1186efba948af0bd5e763559
--- /dev/null
+++ b/feature_extractor/viewer.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python
+#
+# deepzoom_server - Example web application for serving whole-slide images
+#
+# Copyright (c) 2010-2015 Carnegie Mellon University
+#
+# This library is free software; you can redistribute it and/or modify it
+# under the terms of version 2.1 of the GNU Lesser General Public License
+# as published by the Free Software Foundation.
+#
+# This library is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library; if not, write to the Free Software Foundation,
+# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+
+from io import BytesIO
+from optparse import OptionParser
+import os
+import re
+from unicodedata import normalize
+
+from flask import Flask, abort, make_response, render_template, url_for
+
+if os.name == 'nt':
+ _dll_path = os.getenv('OPENSLIDE_PATH')
+ if _dll_path is not None:
+ if hasattr(os, 'add_dll_directory'):
+ # Python >= 3.8
+ with os.add_dll_directory(_dll_path):
+ import openslide
+ else:
+ # Python < 3.8
+ _orig_path = os.environ.get('PATH', '')
+ os.environ['PATH'] = _orig_path + ';' + _dll_path
+ import openslide
+
+ os.environ['PATH'] = _orig_path
+else:
+ import openslide
+
+from openslide import ImageSlide, open_slide
+from openslide.deepzoom import DeepZoomGenerator
+
+DEEPZOOM_SLIDE = None
+DEEPZOOM_FORMAT = 'jpeg'
+DEEPZOOM_TILE_SIZE = 254
+DEEPZOOM_OVERLAP = 1
+DEEPZOOM_LIMIT_BOUNDS = True
+DEEPZOOM_TILE_QUALITY = 75
+SLIDE_NAME = 'slide'
+
+app = Flask(__name__)
+app.config.from_object(__name__)
+app.config.from_envvar('DEEPZOOM_TILER_SETTINGS', silent=True)
+
+
+@app.before_first_request
+def load_slide():
+ slidefile = app.config['DEEPZOOM_SLIDE']
+ if slidefile is None:
+ raise ValueError('No slide file specified')
+ config_map = {
+ 'DEEPZOOM_TILE_SIZE': 'tile_size',
+ 'DEEPZOOM_OVERLAP': 'overlap',
+ 'DEEPZOOM_LIMIT_BOUNDS': 'limit_bounds',
+ }
+ opts = {v: app.config[k] for k, v in config_map.items()}
+ slide = open_slide(slidefile)
+ app.slides = {SLIDE_NAME: DeepZoomGenerator(slide, **opts)}
+ app.associated_images = []
+ app.slide_properties = slide.properties
+ for name, image in slide.associated_images.items():
+ app.associated_images.append(name)
+ slug = slugify(name)
+ app.slides[slug] = DeepZoomGenerator(ImageSlide(image), **opts)
+ try:
+ mpp_x = slide.properties[openslide.PROPERTY_NAME_MPP_X]
+ mpp_y = slide.properties[openslide.PROPERTY_NAME_MPP_Y]
+ app.slide_mpp = (float(mpp_x) + float(mpp_y)) / 2
+ except (KeyError, ValueError):
+ app.slide_mpp = 0
+
+
+@app.route('/')
+def index():
+ slide_url = url_for('dzi', slug=SLIDE_NAME)
+ associated_urls = {
+ name: url_for('dzi', slug=slugify(name)) for name in app.associated_images
+ }
+ return render_template(
+ 'slide-multipane.html',
+ slide_url=slide_url,
+ associated=associated_urls,
+ properties=app.slide_properties,
+ slide_mpp=app.slide_mpp,
+ )
+
+
+@app.route('/<slug>.dzi')
+def dzi(slug):
+ format = app.config['DEEPZOOM_FORMAT']
+ try:
+ resp = make_response(app.slides[slug].get_dzi(format))
+ resp.mimetype = 'application/xml'
+ return resp
+ except KeyError:
+ # Unknown slug
+ abort(404)
+
+
+@app.route('/<slug>_files/<int:level>/<int:col>_<int:row>.<format>')
+def tile(slug, level, col, row, format):
+ format = format.lower()
+ if format != 'jpeg' and format != 'png':
+ # Not supported by Deep Zoom
+ abort(404)
+ try:
+ tile = app.slides[slug].get_tile(level, (col, row))
+ except KeyError:
+ # Unknown slug
+ abort(404)
+ except ValueError:
+ # Invalid level or coordinates
+ abort(404)
+ buf = BytesIO()
+ tile.save(buf, format, quality=app.config['DEEPZOOM_TILE_QUALITY'])
+ resp = make_response(buf.getvalue())
+ resp.mimetype = 'image/%s' % format
+ return resp
+
+
+def slugify(text):
+ text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode()
+ return re.sub('[^a-z0-9]+', '-', text)
+
+
+if __name__ == '__main__':
+ parser = OptionParser(usage='Usage: %prog [options] [slide]')
+ parser.add_option(
+ '-B',
+ '--ignore-bounds',
+ dest='DEEPZOOM_LIMIT_BOUNDS',
+ default=True,
+ action='store_false',
+ help='display entire scan area',
+ )
+ parser.add_option(
+ '-c', '--config', metavar='FILE', dest='config', help='config file'
+ )
+ parser.add_option(
+ '-d',
+ '--debug',
+ dest='DEBUG',
+ action='store_true',
+ help='run in debugging mode (insecure)',
+ )
+ parser.add_option(
+ '-e',
+ '--overlap',
+ metavar='PIXELS',
+ dest='DEEPZOOM_OVERLAP',
+ type='int',
+ help='overlap of adjacent tiles [1]',
+ )
+ parser.add_option(
+ '-f',
+ '--format',
+ metavar='{jpeg|png}',
+ dest='DEEPZOOM_FORMAT',
+ help='image format for tiles [jpeg]',
+ )
+ parser.add_option(
+ '-l',
+ '--listen',
+ metavar='ADDRESS',
+ dest='host',
+ default='127.0.0.1',
+ help='address to listen on [127.0.0.1]',
+ )
+ parser.add_option(
+ '-p',
+ '--port',
+ metavar='PORT',
+ dest='port',
+ type='int',
+ default=5000,
+ help='port to listen on [5000]',
+ )
+ parser.add_option(
+ '-Q',
+ '--quality',
+ metavar='QUALITY',
+ dest='DEEPZOOM_TILE_QUALITY',
+ type='int',
+ help='JPEG compression quality [75]',
+ )
+ parser.add_option(
+ '-s',
+ '--size',
+ metavar='PIXELS',
+ dest='DEEPZOOM_TILE_SIZE',
+ type='int',
+ help='tile size [254]',
+ )
+
+ (opts, args) = parser.parse_args()
+ # Load config file if specified
+ if opts.config is not None:
+ app.config.from_pyfile(opts.config)
+ # Overwrite only those settings specified on the command line
+ for k in dir(opts):
+ if not k.startswith('_') and getattr(opts, k) is None:
+ delattr(opts, k)
+ app.config.from_object(opts)
+ # Set slide file
+ try:
+ app.config['DEEPZOOM_SLIDE'] = args[0]
+ except IndexError:
+ if app.config['DEEPZOOM_SLIDE'] is None:
+ parser.error('No slide file specified')
+
+ app.run(host=opts.host, port=opts.port, threaded=True)
\ No newline at end of file
diff --git a/helper.py b/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..374c1a7345d298bc315051785be1727b470e454a
--- /dev/null
+++ b/helper.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+from __future__ import absolute_import, division, print_function
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torchvision import transforms
+from utils.metrics import ConfusionMatrix
+from PIL import Image
+import os
+
+# torch.cuda.synchronize()
+# torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.deterministic = True
+
+def collate(batch):
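+    # graphs in the batch have different node counts, so keep them as Python lists here;
+    # padding into dense tensors is done later by preparefeatureLabel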
+ image = [ b['image'] for b in batch ] # w, h
+ label = [ b['label'] for b in batch ]
+ id = [ b['id'] for b in batch ]
+ adj_s = [ b['adj_s'] for b in batch ]
+ return {'image': image, 'label': label, 'id': id, 'adj_s': adj_s}
+
+def preparefeatureLabel(batch_graph, batch_label, batch_adjs, device='cpu'):
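+    # pad every graph in the batch to the largest node count, producing dense node-feature
+    # tensors, padded adjacency matrices and per-graph node masks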
+ batch_size = len(batch_graph)
+ labels = torch.LongTensor(batch_size)
+ max_node_num = 0
+
+ for i in range(batch_size):
+ labels[i] = batch_label[i]
+ max_node_num = max(max_node_num, batch_graph[i].shape[0])
+
+ masks = torch.zeros(batch_size, max_node_num)
+ adjs = torch.zeros(batch_size, max_node_num, max_node_num)
+ batch_node_feat = torch.zeros(batch_size, max_node_num, 512)
+
+ for i in range(batch_size):
+ cur_node_num = batch_graph[i].shape[0]
+ #node attribute feature
+ tmp_node_fea = batch_graph[i]
+ batch_node_feat[i, 0:cur_node_num] = tmp_node_fea
+
+ #adjs
+ adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i]
+
+ #masks
+ masks[i,0:cur_node_num] = 1
+
+ node_feat = batch_node_feat.to(device)
+ labels = labels.to(device)
+ adjs = adjs.to(device)
+ masks = masks.to(device)
+
+ return node_feat, labels, adjs, masks
+
+class Trainer(object):
+ def __init__(self, n_class):
+ self.metrics = ConfusionMatrix(n_class)
+
+ def get_scores(self):
+ acc = self.metrics.get_scores()
+
+ return acc
+
+ def reset_metrics(self):
+ self.metrics.reset()
+
+ def plot_cm(self):
+ self.metrics.plotcm()
+
+ def train(self, sample, model):
+ node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s'])
+ pred,labels,loss = model.forward(node_feat, labels, adjs, masks)
+
+ return pred,labels,loss
+
+class Evaluator(object):
+ def __init__(self, n_class):
+ self.metrics = ConfusionMatrix(n_class)
+
+ def get_scores(self):
+ acc = self.metrics.get_scores()
+
+ return acc
+
+ def reset_metrics(self):
+ self.metrics.reset()
+
+ def plot_cm(self):
+ self.metrics.plotcm()
+
+ def eval_test(self, sample, model, graphcam_flag=False):
+ node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s'])
+ if not graphcam_flag:
+ with torch.no_grad():
+ pred,labels,loss = model.forward(node_feat, labels, adjs, masks)
+ else:
+ torch.set_grad_enabled(True)
+ pred,labels,loss= model.forward(node_feat, labels, adjs, masks, graphcam_flag=graphcam_flag)
+ return pred,labels,loss
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b1949847e2f6ae66707c93c86ecf72ccdb7b445
--- /dev/null
+++ b/main.py
@@ -0,0 +1,169 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+from torchvision import transforms
+
+from utils.dataset import GraphDataset
+from utils.lr_scheduler import LR_Scheduler
+from tensorboardX import SummaryWriter
+from helper import Trainer, Evaluator, collate
+from option import Options
+
+from models.GraphTransformer import Classifier
+from models.weight_init import weight_init
+import pickle
+args = Options().parse()
+
+label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb'))
+
+n_class = len(label_map)
+
+if torch.cuda.is_available():
+    torch.cuda.synchronize()
+torch.backends.cudnn.deterministic = True
+
+data_path = args.data_path
+model_path = args.model_path
+if not os.path.isdir(model_path): os.mkdir(model_path)
+log_path = args.log_path
+if not os.path.isdir(log_path): os.mkdir(log_path)
+task_name = args.task_name
+
+print(task_name)
+###################################
+train = args.train
+test = args.test
+graphcam = args.graphcam
+print("train:", train, "test:", test, "graphcam:", graphcam)
+
+##### Load datasets
+print("preparing datasets and dataloaders......")
+batch_size = args.batch_size
+
+if train:
+ ids_train = open(args.train_set).readlines()
+ dataset_train = GraphDataset(os.path.join(data_path, ""), ids_train, args.dataset_metadata_path)
+ dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True, drop_last=True)
+ total_train_num = len(dataloader_train) * batch_size
+
+ids_val = open(args.val_set).readlines()
+dataset_val = GraphDataset(os.path.join(data_path, ""), ids_val, args.dataset_metadata_path)
+dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True)
+total_val_num = len(dataloader_val) * batch_size
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+##### creating models #############
+print("creating models......")
+
+num_epochs = args.num_epochs
+learning_rate = args.lr
+
+model = Classifier(n_class)
+model = nn.DataParallel(model)
+if args.resume:
+ print('loading model from {}'.format(args.resume))
+ model.load_state_dict(torch.load(args.resume))
+
+if torch.cuda.is_available():
+ model = model.cuda()
+#model.apply(weight_init)
+
+optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 5e-4) # best:5e-4, 4e-3
+scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,100], gamma=0.1) # gamma=0.3 # 30,90,130 # 20,90,130 -> 150
+
+##################################
+
+criterion = nn.CrossEntropyLoss()
+
+if not test:
+ writer = SummaryWriter(log_dir=log_path + task_name)
+ f_log = open(log_path + task_name + ".log", 'w')
+
+trainer = Trainer(n_class)
+evaluator = Evaluator(n_class)
+
+best_pred = 0.0
+for epoch in range(num_epochs):
+ # optimizer.zero_grad()
+ model.train()
+ train_loss = 0.
+ total = 0.
+
+ current_lr = optimizer.param_groups[0]['lr']
+ print('\n=>Epoch %i, learning rate = %.7f, previous best = %.4f' % (epoch+1, current_lr, best_pred))
+
+ if train:
+ for i_batch, sample_batched in enumerate(dataloader_train):
+ scheduler.step(epoch)
+
+ preds,labels,loss = trainer.train(sample_batched, model)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ # accumulate only the scalar value so the computation graph is not kept alive across batches
+ train_loss += loss.item()
+ total += len(labels)
+
+ trainer.metrics.update(labels, preds)
+ if (i_batch + 1) % args.log_interval_local == 0:
+ print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total, total_train_num, train_loss / total, trainer.get_scores()))
+ trainer.plot_cm()
+
+ if not test:
+ print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total_train_num, total_train_num, train_loss / total, trainer.get_scores()))
+ trainer.plot_cm()
+
+
+ if epoch % 1 == 0:
+ with torch.no_grad():
+ model.eval()
+ print("evaluating...")
+
+ total = 0.
+ batch_idx = 0
+
+ for i_batch, sample_batched in enumerate(dataloader_val):
+ preds, labels, _ = evaluator.eval_test(sample_batched, model, graphcam)
+
+ total += len(labels)
+
+ evaluator.metrics.update(labels, preds)
+
+ if (i_batch + 1) % args.log_interval_local == 0:
+ print('[%d/%d] val agg acc: %.3f' % (total, total_val_num, evaluator.get_scores()))
+ evaluator.plot_cm()
+
+ print('[%d/%d] val agg acc: %.3f' % (total_val_num, total_val_num, evaluator.get_scores()))
+ evaluator.plot_cm()
+
+ # torch.cuda.empty_cache()
+
+ val_acc = evaluator.get_scores()
+ if val_acc > best_pred:
+ best_pred = val_acc
+ if not test:
+ print("saving model...")
+ torch.save(model.state_dict(), model_path + task_name + ".pth")
+
+ log = ""
+ log = log + 'epoch [{}/{}] ------ acc: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, trainer.get_scores(), evaluator.get_scores()) + "\n"
+
+ log += "================================\n"
+ print(log)
+ if test: break
+
+ f_log.write(log)
+ f_log.flush()
+
+ writer.add_scalars('accuracy', {'train acc': trainer.get_scores(), 'val acc': evaluator.get_scores()}, epoch+1)
+
+ trainer.reset_metrics()
+ evaluator.reset_metrics()
+
+if not test: f_log.close()
\ No newline at end of file
diff --git a/metadata/label_map.pkl b/metadata/label_map.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..d1fd4d98f76c037472de0a292a6da3586c7736ae
--- /dev/null
+++ b/metadata/label_map.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce5be416a8667c9379502eaf8407e6d07bbae03749085190be630bd3b026eb52
+size 34
diff --git a/models/.gitkeep b/models/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/.gitkeep
@@ -0,0 +1 @@
+
diff --git a/models/GraphTransformer.py b/models/GraphTransformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..34ecdf561e21dace9a68c81623dbddca2f37475b
--- /dev/null
+++ b/models/GraphTransformer.py
@@ -0,0 +1,123 @@
+import sys
+import os
+import torch
+import random
+import numpy as np
+
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+from .ViT import *
+from .gcn import GCNBlock
+
+from torch_geometric.nn import GCNConv, DenseGraphConv, dense_mincut_pool
+from torch.nn import Linear
+class Classifier(nn.Module):
+ def __init__(self, n_class):
+ super(Classifier, self).__init__()
+
+ self.n_class = n_class
+ self.embed_dim = 64
+ self.num_layers = 3
+ self.node_cluster_num = 100
+
+ self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.criterion = nn.CrossEntropyLoss()
+
+ self.bn = 1
+ self.add_self = 1
+ self.normalize_embedding = 1
+ self.conv1 = GCNBlock(512, self.embed_dim, self.bn, self.add_self, self.normalize_embedding, 0., 0)  # 512-dim patch features -> embed_dim
+ self.pool1 = Linear(self.embed_dim, self.node_cluster_num)  # soft assignment of nodes to node_cluster_num clusters
+
+
+ def forward(self,node_feat,labels,adj,mask,is_print=False, graphcam_flag=False, to_file=True):
+ # node_feat, labels = self.PrepareFeatureLabel(batch_graph)
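+ # Forward pass: a GCN layer embeds the patch-graph nodes, a linear layer predicts a soft
+ # assignment of nodes to self.node_cluster_num clusters, dense_mincut_pool coarsens the graph
+ # accordingly, and a small Vision Transformer with a prepended CLS token classifies the pooled
+ # node sequence. GraphCAM tensors are optionally dumped for visualization.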
+ cls_loss=node_feat.new_zeros(self.num_layers)
+ rank_loss=node_feat.new_zeros(self.num_layers-1)
+ X=node_feat
+ p_t=[]
+ pred_logits=0
+ visualize_tools=[]
+ if labels is not None:
+ visualize_tools1=[labels.cpu()]
+ embeds=0
+ concats=[]
+
+ layer_acc=[]
+
+ X=mask.unsqueeze(2)*X
+ X = self.conv1(X, adj, mask)
+ s = self.pool1(X)
+
+
+ graphcam_tensors = {}
+
+ if graphcam_flag:
+ s_matrix = torch.argmax(s[0], dim=1)
+ if to_file:
+ from os import path
+ os.makedirs('graphcam', exist_ok=True)
+ torch.save(s_matrix, 'graphcam/s_matrix.pt')
+ torch.save(s[0], 'graphcam/s_matrix_ori.pt')
+
+ # clear any attention dumps left over from a previous GraphCAM run
+ for att_file in ('graphcam/att_1.pt', 'graphcam/att_2.pt', 'graphcam/att_3.pt'):
+     if path.exists(att_file):
+         os.remove(att_file)
+
+ if not to_file:
+ graphcam_tensors['s_matrix'] = s_matrix
+ graphcam_tensors['s_matrix_ori'] = s[0]
+
+
+ X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask)
+ b, _, _ = X.shape
+ cls_token = self.cls_token.repeat(b, 1, 1)
+ X = torch.cat([cls_token, X], dim=1)
+
+ out = self.transformer(X)
+
+ # prediction is computed regardless of labels, so inference with labels=None still works
+ pred = out.data.max(1)[1]
+
+ loss = None
+ if labels is not None:
+     # cross-entropy loss plus the mincut and orthogonality regularizers from dense_mincut_pool
+     loss = self.criterion(out, labels)
+     loss = loss + mc1 + o1
+
+ if graphcam_flag:
+ #print('GraphCAM enabled')
+ #print(out.shape)
+ p = F.softmax(out, dim=-1)
+ #print(p.shape)
+ if to_file:
+ torch.save(p, 'graphcam/prob.pt')
+ if not to_file:
+ graphcam_tensors['prob'] = p
+ index = np.argmax(out.cpu().data.numpy(), axis=-1)
+
+ for index_ in range(self.n_class):
+ one_hot = np.zeros((1, out.size()[-1]), dtype=np.float32)
+ one_hot[0, index_] = out[0][index_]
+ one_hot_vector = one_hot
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+ one_hot = torch.sum(one_hot.to( 'cuda' if torch.cuda.is_available() else 'cpu') * out) #!!!!!!!!!!!!!!!!!!!!out-->p
+ self.transformer.zero_grad()
+ one_hot.backward(retain_graph=True)
+
+ kwargs = {"alpha": 1}
+ cam = self.transformer.relprop(torch.tensor(one_hot_vector).to(X.device), method="transformer_attribution", is_ablation=False,
+ start_layer=0, **kwargs)
+ if to_file:
+ torch.save(cam, 'graphcam/cam_{}.pt'.format(index_))
+ if not to_file:
+ graphcam_tensors[f'cam_{index_}'] = cam
+
+ if not to_file:
+ return pred,labels,loss, graphcam_tensors
+ return pred,labels,loss
diff --git a/models/ViT.py b/models/ViT.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e07347410897293d32e697d426531b521cc38b5
--- /dev/null
+++ b/models/ViT.py
@@ -0,0 +1,415 @@
+""" Vision Transformer (ViT) in PyTorch
+"""
+import torch
+import torch.nn as nn
+from einops import rearrange
+from .layers import *
+import math
+import warnings
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic',
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models
+ 'vit_small_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+ ),
+ 'vit_base_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ ),
+ 'vit_large_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+}
+
+def compute_rollout_attention(all_layer_matrices, start_layer=0):
+ # adding residual consideration
+ num_tokens = all_layer_matrices[0].shape[1]
+ batch_size = all_layer_matrices[0].shape[0]
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
+ # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
+ # for i in range(len(all_layer_matrices))]
+ joint_attention = all_layer_matrices[start_layer]
+ for i in range(start_layer+1, len(all_layer_matrices)):
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
+ return joint_attention
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = Linear(in_features, hidden_features)
+ self.act = GELU()
+ self.fc2 = Linear(hidden_features, out_features)
+ self.drop = Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+ def relprop(self, cam, **kwargs):
+ cam = self.drop.relprop(cam, **kwargs)
+ cam = self.fc2.relprop(cam, **kwargs)
+ cam = self.act.relprop(cam, **kwargs)
+ cam = self.fc1.relprop(cam, **kwargs)
+ return cam
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = head_dim ** -0.5
+
+ # A = Q*K^T
+ self.matmul1 = einsum('bhid,bhjd->bhij')
+ # attn = A*V
+ self.matmul2 = einsum('bhij,bhjd->bhid')
+
+ self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = Dropout(attn_drop)
+ self.proj = Linear(dim, dim)
+ self.proj_drop = Dropout(proj_drop)
+ self.softmax = Softmax(dim=-1)
+
+ self.attn_cam = None
+ self.attn = None
+ self.v = None
+ self.v_cam = None
+ self.attn_gradients = None
+
+ def get_attn(self):
+ return self.attn
+
+ def save_attn(self, attn):
+ self.attn = attn
+
+ def save_attn_cam(self, cam):
+ self.attn_cam = cam
+
+ def get_attn_cam(self):
+ return self.attn_cam
+
+ def get_v(self):
+ return self.v
+
+ def save_v(self, v):
+ self.v = v
+
+ def save_v_cam(self, cam):
+ self.v_cam = cam
+
+ def get_v_cam(self):
+ return self.v_cam
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def forward(self, x):
+ b, n, _, h = *x.shape, self.num_heads
+ qkv = self.qkv(x)
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
+
+ self.save_v(v)
+
+ dots = self.matmul1([q, k]) * self.scale
+
+ attn = self.softmax(dots)
+ attn = self.attn_drop(attn)
+
+ # Get attention
+ if False:
+ from os import path
+ if not path.exists('att_1.pt'):
+ torch.save(attn, 'att_1.pt')
+ elif not path.exists('att_2.pt'):
+ torch.save(attn, 'att_2.pt')
+ else:
+ torch.save(attn, 'att_3.pt')
+
+ #comment in training
+ if x.requires_grad:
+ self.save_attn(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ out = self.matmul2([attn, v])
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ out = self.proj(out)
+ out = self.proj_drop(out)
+ return out
+
+ def relprop(self, cam, **kwargs):
+ cam = self.proj_drop.relprop(cam, **kwargs)
+ cam = self.proj.relprop(cam, **kwargs)
+ cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
+
+ # attn = A*V
+ (cam1, cam_v)= self.matmul2.relprop(cam, **kwargs)
+ cam1 /= 2
+ cam_v /= 2
+
+ self.save_v_cam(cam_v)
+ self.save_attn_cam(cam1)
+
+ cam1 = self.attn_drop.relprop(cam1, **kwargs)
+ cam1 = self.softmax.relprop(cam1, **kwargs)
+
+ # A = Q*K^T
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
+ cam_q /= 2
+ cam_k /= 2
+
+ cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
+
+ return self.qkv.relprop(cam_qkv, **kwargs)
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
+ super().__init__()
+ self.norm1 = LayerNorm(dim, eps=1e-6)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.norm2 = LayerNorm(dim, eps=1e-6)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
+
+ self.add1 = Add()
+ self.add2 = Add()
+ self.clone1 = Clone()
+ self.clone2 = Clone()
+
+ def forward(self, x):
+ x1, x2 = self.clone1(x, 2)
+ x = self.add1([x1, self.attn(self.norm1(x2))])
+ x1, x2 = self.clone2(x, 2)
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
+ return x
+
+ def relprop(self, cam, **kwargs):
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
+ cam2 = self.mlp.relprop(cam2, **kwargs)
+ cam2 = self.norm2.relprop(cam2, **kwargs)
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
+
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
+ cam2 = self.attn.relprop(cam2, **kwargs)
+ cam2 = self.norm1.relprop(cam2, **kwargs)
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
+ return cam
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(self, num_classes=2, embed_dim=64, depth=3,
+ num_heads=8, mlp_ratio=2., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate)
+ for i in range(depth)])
+
+ self.norm = LayerNorm(embed_dim)
+ if mlp_head:
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
+ else:
+ # with a single Linear layer as head, the param count within rounding of paper
+ self.head = Linear(embed_dim, num_classes)
+
+ #self.apply(self._init_weights)
+
+ self.pool = IndexSelect()
+ self.add = Add()
+
+ self.inp_grad = None
+
+ def save_inp_grad(self,grad):
+ self.inp_grad = grad
+
+ def get_inp_grad(self):
+ return self.inp_grad
+
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @property
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x):
+ if x.requires_grad:
+ x.register_hook(self.save_inp_grad) #comment it in train
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
+ x = x.squeeze(1)
+ x = self.head(x)
+ return x
+
+ def relprop(self, cam=None,method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
+ # print(kwargs)
+ # print("conservation 1", cam.sum())
+ cam = self.head.relprop(cam, **kwargs)
+ cam = cam.unsqueeze(1)
+ cam = self.pool.relprop(cam, **kwargs)
+ cam = self.norm.relprop(cam, **kwargs)
+ for blk in reversed(self.blocks):
+ cam = blk.relprop(cam, **kwargs)
+
+ # print("conservation 2", cam.sum())
+ # print("min", cam.min())
+
+ if method == "full":
+ (cam, _) = self.add.relprop(cam, **kwargs)
+ cam = cam[:, 1:]
+ cam = self.patch_embed.relprop(cam, **kwargs)
+ # sum on channels
+ cam = cam.sum(dim=1)
+ return cam
+
+ elif method == "rollout":
+ # cam rollout
+ attn_cams = []
+ for blk in self.blocks:
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
+ attn_cams.append(avg_heads)
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
+ cam = cam[:, 0, 1:]
+ return cam
+
+ # our method, method name grad is legacy
+ elif method == "transformer_attribution" or method == "grad":
+ cams = []
+ for blk in self.blocks:
+ grad = blk.attn.get_attn_gradients()
+ cam = blk.attn.get_attn_cam()
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
+ cam = grad * cam
+ cam = cam.clamp(min=0).mean(dim=0)
+ cams.append(cam.unsqueeze(0))
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
+ cam = rollout[:, 0, 1:]
+ return cam
+
+ elif method == "last_layer":
+ cam = self.blocks[-1].attn.get_attn_cam()
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+ if is_ablation:
+ grad = self.blocks[-1].attn.get_attn_gradients()
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
+ cam = grad * cam
+ cam = cam.clamp(min=0).mean(dim=0)
+ cam = cam[0, 1:]
+ return cam
+
+ elif method == "last_layer_attn":
+ cam = self.blocks[-1].attn.get_attn()
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+ cam = cam.clamp(min=0).mean(dim=0)
+ cam = cam[0, 1:]
+ return cam
+
+ elif method == "second_layer":
+ cam = self.blocks[1].attn.get_attn_cam()
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+ if is_ablation:
+ grad = self.blocks[1].attn.get_attn_gradients()
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
+ cam = grad * cam
+ cam = cam.clamp(min=0).mean(dim=0)
+ cam = cam[0, 1:]
+ return cam
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/__pycache__/GraphTransformer.cpython-38.pyc b/models/__pycache__/GraphTransformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b3d5ab870b19f73a63551535ccbf5902969cfd3
Binary files /dev/null and b/models/__pycache__/GraphTransformer.cpython-38.pyc differ
diff --git a/models/__pycache__/ViT.cpython-38.pyc b/models/__pycache__/ViT.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcfb6a6c7807778d0c1b2e2aed83292c8583cb50
Binary files /dev/null and b/models/__pycache__/ViT.cpython-38.pyc differ
diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..759c53064228f849d841cb166a83d30ba0ff1580
Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/__pycache__/gcn.cpython-38.pyc b/models/__pycache__/gcn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cb6cff43d3474ff6fd1af86eed4198d6bff692f
Binary files /dev/null and b/models/__pycache__/gcn.cpython-38.pyc differ
diff --git a/models/__pycache__/layers.cpython-38.pyc b/models/__pycache__/layers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2b4eddb6fc10f1acab2da50966ee2f016ec5ad5
Binary files /dev/null and b/models/__pycache__/layers.cpython-38.pyc differ
diff --git a/models/__pycache__/weight_init.cpython-38.pyc b/models/__pycache__/weight_init.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..410a7e0d7f64a4cbd2b1fc31df416e65fd63e1b5
Binary files /dev/null and b/models/__pycache__/weight_init.cpython-38.pyc differ
diff --git a/models/gcn.py b/models/gcn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d52daa8171c0dadfb06b6a8dde411ef778188e6
--- /dev/null
+++ b/models/gcn.py
@@ -0,0 +1,420 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+import torch.nn.functional as F
+import math
+
+import numpy as np
+
+torch.set_printoptions(precision=2,threshold=float('inf'))
+
+class AGCNBlock(nn.Module):
+ def __init__(self,input_dim,hidden_dim,gcn_layer=2,dropout=0.0,relu=0):
+ super(AGCNBlock,self).__init__()
+ if dropout > 0.001:
+ self.dropout_layer = nn.Dropout(p=dropout)
+ self.sort = 'sort'
+ self.model='agcn'
+ self.gcns=nn.ModuleList()
+ self.bn = 0
+ self.add_self = 1
+ self.normalize_embedding = 1
+ self.gcns.append(GCNBlock(input_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
+ self.pool = 'mean'
+ self.tau = 1.
+ self.lamda = 1.
+
+ for i in range(gcn_layer-1):
+ if i==gcn_layer-2 and (not 1):
+ self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,0))
+ else:
+ self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
+
+ if self.model=='diffpool':
+ self.pool_gcns=nn.ModuleList()
+ tmp=input_dim
+ self.diffpool_k=200
+ for i in range(3):
+ self.pool_gcns.append(GCNBlock(tmp,200,0,0,0,dropout,relu))
+ tmp=200
+
+ self.w_a=nn.Parameter(torch.zeros(1,hidden_dim,1))
+ self.w_b=nn.Parameter(torch.zeros(1,hidden_dim,1))
+ torch.nn.init.normal_(self.w_a)
+ torch.nn.init.uniform_(self.w_b,-1,1)
+
+ self.pass_dim=hidden_dim
+
+ if self.pool=='mean':
+ self.pool=self.mean_pool
+ elif self.pool=='max':
+ self.pool=self.max_pool
+ elif self.pool=='sum':
+ self.pool=self.sum_pool
+
+ self.softmax='global'
+ if self.softmax=='gcn':
+ self.att_gcn=GCNBlock(2,1,0,0,dropout,relu)
+ self.khop=1
+ self.adj_norm='none'
+
+ self.filt_percent=0.25 #default 0.5
+ self.eps=1e-10
+
+ self.tau_config=1
+ if 1==-1.:
+ self.tau=nn.Parameter(torch.tensor(1),requires_grad=False)
+ elif 1==-2.:
+ self.tau_fc=nn.Linear(hidden_dim,1)
+ torch.nn.init.constant_(self.tau_fc.bias,1)
+ torch.nn.init.xavier_normal_(self.tau_fc.weight.t())
+ else:
+ self.tau=nn.Parameter(torch.tensor(self.tau))
+ self.lamda1=nn.Parameter(torch.tensor(self.lamda))
+ self.lamda2=nn.Parameter(torch.tensor(self.lamda))
+
+ self.att_norm=0
+
+ self.dnorm=0
+ self.dnorm_coe=1
+
+ self.att_out=0
+ self.single_att=0
+
+
+ def forward(self,X,adj,mask,is_print=False):
+ '''
+ input:
+ X: node input features , [batch,node_num,input_dim],dtype=float
+ adj: adj matrix, [batch,node_num,node_num], dtype=float
+ mask: mask for nodes, [batch,node_num]
+ outputs:
+ out:unormalized classification prob, [batch,hidden_dim]
+ H: batch of node hidden features, [batch,node_num,pass_dim]
+ new_adj: pooled new adj matrix, [batch, k_max, k_max]
+ new_mask: [batch, k_max]
+ '''
+ hidden=X
+ #adj = adj.float()
+ # print('input size:')
+ # print(hidden.shape)
+
+ is_print1=is_print2=is_print
+ if adj.shape[-1]>100:
+ is_print1=False
+
+ for gcn in self.gcns:
+ hidden=gcn(hidden,adj,mask)
+ # print('gcn:')
+ # print(hidden.shape)
+ # print('mask:')
+ # print(mask.unsqueeze(2).shape)
+ # print(mask.sum(dim=1))
+
+ hidden=mask.unsqueeze(2)*hidden
+ # print(hidden[0][0])
+ # print(hidden[0][-1])
+
+ if self.model=='unet':
+ att=torch.matmul(hidden,self.w_a).squeeze()
+ att=att/torch.sqrt((self.w_a.squeeze(2)**2).sum(dim=1,keepdim=True))
+ elif self.model=='agcn':
+ if self.softmax=='global' or self.softmax=='mix':
+ if False:
+ dgree_w = torch.sum(adj, dim=2) / torch.sum(adj, dim=2).max(1, keepdim=True)[0]
+ att_a=torch.matmul(hidden,self.w_a).squeeze()*dgree_w+(mask-1)*1e10
+ else:
+ att_a=torch.matmul(hidden,self.w_a).squeeze()+(mask-1)*1e10
+ # print(att_a[0][:10])
+ # print(att_a[0][-10:-1])
+ att_a_1=att_a=torch.nn.functional.softmax(att_a,dim=1)
+ # print(att_a[0][:10])
+ # print(att_a[0][-10:-1])
+
+ if self.dnorm:
+ scale=mask.sum(dim=1,keepdim=True)/self.dnorm_coe
+ att_a=scale*att_a
+ if self.softmax=='neibor' or self.softmax=='mix':
+ att_b=torch.matmul(hidden,self.w_b).squeeze()+(mask-1)*1e10
+ att_b_max,_=att_b.max(dim=1,keepdim=True)
+ if self.tau_config!=-2:
+ att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau))
+ else:
+ att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau_fc(self.pool(hidden,mask))))
+ denom=att_b.unsqueeze(2)
+ for _ in range(self.khop):
+ denom=torch.matmul(adj,denom)
+ denom=denom.squeeze()+self.eps
+ att_b=(att_b*torch.diagonal(adj,0,1,2))/denom
+ if self.dnorm:
+ if self.adj_norm=='diag':
+ diag_scale=mask/(torch.diagonal(adj,0,1,2)+self.eps)
+ elif self.adj_norm=='none':
+ diag_scale=adj.sum(dim=1)
+ att_b=att_b*diag_scale
+ att_b=att_b*mask
+
+ if self.softmax=='global':
+ att=att_a
+ elif self.softmax=='neibor' or self.softmax=='hardnei':
+ att=att_b
+ elif self.softmax=='mix':
+ att=att_a*torch.abs(self.lamda1)+att_b*torch.abs(self.lamda2)
+ # print('att:')
+ # print(att.shape)
+ Z=hidden
+
+ if self.model=='unet':
+ Z=torch.tanh(att.unsqueeze(2))*Z
+ elif self.model=='agcn':
+ if self.single_att:
+ Z=Z
+ else:
+ Z=att.unsqueeze(2)*Z
+ # print('Z shape')
+ # print(Z.shape)
+ k_max=int(math.ceil(self.filt_percent*adj.shape[-1]))
+ # print('k_max')
+ # print(k_max)
+ if self.model=='diffpool':
+ k_max=min(k_max,self.diffpool_k)
+
+ k_list=[int(math.ceil(self.filt_percent*x)) for x in mask.sum(dim=1).tolist()]
+ # print('k_list')
+ # print(k_list)
+ if self.model!='diffpool':
+ if self.sort=='sample':
+ att_samp = att * mask
+ att_samp = (att_samp/att_samp.sum(1)).detach().cpu().numpy()
+ top_index = ()
+ for i in range(att.size(0)):
+ top_index = (torch.LongTensor(np.random.choice(att_samp.size(1), k_max, att_samp[i])) ,)
+ top_index = torch.stack(top_index,1)
+ elif self.sort=='random_sample':
+ top_index = torch.LongTensor(att.size(0), k_max)*0
+ for i in range(att.size(0)):
+ top_index[i,0:k_list[i]] = torch.randperm(int(mask[i].sum().item()))[0:k_list[i]]
+ else: #sort
+ _,top_index=torch.topk(att,k_max,dim=1)
+ # print('top_index')
+ # print(top_index)
+ # print(len(top_index[0]))
+ new_mask=X.new_zeros(X.shape[0],k_max)
+ # print('new_mask')
+ # print(new_mask.shape)
+ visualize_tools=None
+ if self.model=='unet':
+ for i,k in enumerate(k_list):
+ for j in range(int(k),k_max):
+ top_index[i][j]=adj.shape[-1]-1
+ new_mask[i][j]=-1.
+ new_mask=new_mask+1
+ top_index,_=torch.sort(top_index,dim=1)
+ assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
+ for i,x in enumerate(top_index):
+ assign_m[i]=torch.index_select(adj[i],0,x)
+ new_adj=X.new_zeros(X.shape[0],k_max,k_max)
+ H=Z.new_zeros(Z.shape[0],k_max,Z.shape[-1])
+ for i,x in enumerate(top_index):
+ new_adj[i]=torch.index_select(assign_m[i],1,x)
+ H[i]=torch.index_select(Z[i],0,x)
+
+ elif self.model=='agcn':
+ assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
+ # print('assign_m.shape')
+ # print(assign_m.shape)
+ for i,k in enumerate(k_list):
+ #print('top_index[i][j]')
+ for j in range(int(k)):
+ #print(str(top_index[i][j].item())+' ', end='')
+ assign_m[i][j]=adj[i][top_index[i][j]]
+ #print(assign_m[i][j])
+ new_mask[i][j]=1.
+
+ assign_m=assign_m/(assign_m.sum(dim=1,keepdim=True)+self.eps)
+ H=torch.matmul(assign_m,Z)
+ # print('H')
+ # print(H.shape)
+ new_adj=torch.matmul(torch.matmul(assign_m,adj),torch.transpose(assign_m,1,2))
+ # print(torch.matmul(assign_m,adj).shape)
+ # print('new_adj:')
+ # print(new_adj.shape)
+
+ elif self.model=='diffpool':
+ hidden1=X
+ for gcn in self.pool_gcns:
+ hidden1=gcn(hidden1,adj,mask)
+ assign_m=X.new_ones(X.shape[0],X.shape[1],k_max)*(-100000000.)
+ for i,x in enumerate(hidden1):
+ k=min(k_list[i],k_max)
+ assign_m[i,:,0:k]=hidden1[i,:,0:k]
+ for j in range(int(k)):
+ new_mask[i][j]=1.
+
+ assign_m=torch.nn.functional.softmax(assign_m,dim=2)*mask.unsqueeze(2)
+ assign_m_t=torch.transpose(assign_m,1,2)
+ new_adj=torch.matmul(torch.matmul(assign_m_t,adj),assign_m)
+ H=torch.matmul(assign_m_t,Z)
+ # print('pool')
+ if self.att_out and self.model=='agcn':
+ if self.softmax=='global':
+ out=self.pool(att_a_1.unsqueeze(2)*hidden,mask)
+ elif self.softmax=='neibor':
+ att_b_sum=att_b.sum(dim=1,keepdim=True)
+ out=self.pool((att_b/(att_b_sum+self.eps)).unsqueeze(2)*hidden,mask)
+ else:
+ # print('hidden.shape')
+ # print(hidden.shape)
+ out=self.pool(hidden,mask)
+ # print('out shape')
+ # print(out.shape)
+
+ if self.adj_norm=='tanh' or self.adj_norm=='mix':
+ new_adj=torch.tanh(new_adj)
+ elif self.adj_norm=='diag' or self.adj_norm=='mix':
+ diag_elem=torch.pow(new_adj.sum(dim=2)+self.eps,-0.5)
+ diag=new_adj.new_zeros(new_adj.shape)
+ for i,x in enumerate(diag_elem):
+ diag[i]=torch.diagflat(x)
+ new_adj=torch.matmul(torch.matmul(diag,new_adj),diag)
+
+ visualize_tools=[]
+ '''
+ if (not self.training) and is_print1:
+ print('**********************************')
+ print('node_feat:',X.type(),X.shape)
+ print(X)
+ if self.model!='diffpool':
+ print('**********************************')
+ print('att:',att.type(),att.shape)
+ print(att)
+ print('**********************************')
+ print('top_index:',top_index.type(),top_index.shape)
+ print(top_index)
+ print('**********************************')
+ print('adj:',adj.type(),adj.shape)
+ print(adj)
+ print('**********************************')
+ print('assign_m:',assign_m.type(),assign_m.shape)
+ print(assign_m)
+ print('**********************************')
+ print('new_adj:',new_adj.type(),new_adj.shape)
+ print(new_adj)
+ print('**********************************')
+ print('new_mask:',new_mask.type(),new_mask.shape)
+ print(new_mask)
+ '''
+ #visualization
+ from os import path
+ if not path.exists('att_1.pt'):
+ torch.save(att[0], 'att_1.pt')
+ torch.save(top_index[0], 'att_ind1.pt')
+ elif not path.exists('att_2.pt'):
+ torch.save(att[0], 'att_2.pt')
+ torch.save(top_index[0], 'att_ind2.pt')
+ else:
+ torch.save(att[0], 'att_3.pt')
+ torch.save(top_index[0], 'att_ind3.pt')
+
+ if (not self.training) and is_print2:
+ if self.model!='diffpool':
+ visualize_tools.append(att[0])
+ visualize_tools.append(top_index[0])
+ visualize_tools.append(new_adj[0])
+ visualize_tools.append(new_mask.sum())
+ # print('**********************************')
+ return out,H,new_adj,new_mask,visualize_tools
+
+ def mean_pool(self,x,mask):
+ return x.sum(dim=1)/(self.eps+mask.sum(dim=1,keepdim=True))
+
+ def sum_pool(self,x,mask):
+ return x.sum(dim=1)
+
+ @staticmethod
+ def max_pool(x,mask):
+ #output: [batch,x.shape[2]]
+ m=(mask-1)*1e10
+ r,_=(x+m.unsqueeze(2)).max(dim=1)
+ return r
+# GCN basic operation
+class GCNBlock(nn.Module):
+ def __init__(self, input_dim, output_dim, bn=0,add_self=0, normalize_embedding=0,
+ dropout=0.0,relu=0, bias=True):
+ super(GCNBlock,self).__init__()
+ self.add_self = add_self
+ self.dropout = dropout
+ self.relu=relu
+ self.bn=bn
+ if dropout > 0.001:
+ self.dropout_layer = nn.Dropout(p=dropout)
+ if self.bn:
+ self.bn_layer = torch.nn.BatchNorm1d(output_dim)
+
+ self.normalize_embedding = normalize_embedding
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+
+ self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
+ torch.nn.init.xavier_normal_(self.weight)
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
+ else:
+ self.bias = None
+
+ def forward(self, x, adj, mask):
+ y = torch.matmul(adj, x)
+ if self.add_self:
+ y += x
+ y = torch.matmul(y,self.weight)
+ if self.bias is not None:
+ y = y + self.bias
+ if self.normalize_embedding:
+ y = F.normalize(y, p=2, dim=2)
+ if self.bn:
+ index=mask.sum(dim=1).long().tolist()
+ bn_tensor_bf=mask.new_zeros((sum(index),y.shape[2]))
+ bn_tensor_af=mask.new_zeros(*y.shape)
+ start_index=[]
+ ssum=0
+ for i in range(x.shape[0]):
+ start_index.append(ssum)
+ ssum+=index[i]
+ start_index.append(ssum)
+ for i in range(x.shape[0]):
+ bn_tensor_bf[start_index[i]:start_index[i+1]]=y[i,0:index[i]]
+ bn_tensor_bf=self.bn_layer(bn_tensor_bf)
+ for i in range(x.shape[0]):
+ bn_tensor_af[i,0:index[i]]=bn_tensor_bf[start_index[i]:start_index[i+1]]
+ y=bn_tensor_af
+ if self.dropout > 0.001:
+ y = self.dropout_layer(y)
+ if self.relu=='relu':
+     y=torch.nn.functional.relu(y)
+ elif self.relu=='lrelu':
+ y=torch.nn.functional.leaky_relu(y,0.1)
+ return y
+
+#experimental function, untested
+class masked_batchnorm(nn.Module):
+ def __init__(self,feat_dim,epsilon=1e-10):
+ super().__init__()
+ self.alpha=nn.Parameter(torch.ones(feat_dim))
+ self.beta=nn.Parameter(torch.zeros(feat_dim))
+ self.eps=epsilon
+
+ def forward(self,x,mask):
+ '''
+ x: node feat, [batch,node_num,feat_dim]
+ mask: [batch,node_num]
+ '''
+ mask1 = mask.unsqueeze(2)
+ mask_sum = mask.sum()
+ mean = x.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
+ temp = (x - mean)**2
+ temp = temp*mask1
+ var = temp.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
+ rstd = torch.rsqrt(var+self.eps)
+ x=(x-mean)*rstd
+ return ((x*self.alpha) + self.beta)*mask1
\ No newline at end of file
diff --git a/models/layers.py b/models/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..66703b11cdd30c0bba249dc7a376f3638a14a253
--- /dev/null
+++ b/models/layers.py
@@ -0,0 +1,280 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
+ 'LayerNorm', 'AddEye']
+
+
+def safe_divide(a, b):
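+ # Division a / b used by the relevance-propagation (relprop) rules below: the denominator is
+ # padded with a small epsilon so nothing divides by zero, and entries where b == 0 are zeroed
+ # out in the result.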
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
+ den = den + den.eq(0).type(den.type()) * 1e-9
+ return a / den * b.ne(0).type(b.type())
+
+
+def forward_hook(self, input, output):
+ if type(input[0]) in (list, tuple):
+ self.X = []
+ for i in input[0]:
+ x = i.detach()
+ x.requires_grad = True
+ self.X.append(x)
+ else:
+ self.X = input[0].detach()
+ self.X.requires_grad = True
+
+ self.Y = output
+
+
+def backward_hook(self, grad_input, grad_output):
+ self.grad_input = grad_input
+ self.grad_output = grad_output
+
+
+class RelProp(nn.Module):
+ def __init__(self):
+ super(RelProp, self).__init__()
+ # if not self.training:
+ self.register_forward_hook(forward_hook)
+
+ def gradprop(self, Z, X, S):
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
+ return C
+
+ def relprop(self, R, alpha):
+ return R
+
+class RelPropSimple(RelProp):
+ def relprop(self, R, alpha):
+ Z = self.forward(self.X)
+ S = safe_divide(R, Z)
+ C = self.gradprop(Z, self.X, S)
+
+ if torch.is_tensor(self.X) == False:
+ outputs = []
+ outputs.append(self.X[0] * C[0])
+ outputs.append(self.X[1] * C[1])
+ else:
+ outputs = self.X * (C[0])
+ return outputs
+
+class AddEye(RelPropSimple):
+ # input of shape B, C, seq_len, seq_len
+ def forward(self, input):
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
+
+class ReLU(nn.ReLU, RelProp):
+ pass
+
+class GELU(nn.GELU, RelProp):
+ pass
+
+class Softmax(nn.Softmax, RelProp):
+ pass
+
+class LayerNorm(nn.LayerNorm, RelProp):
+ pass
+
+class Dropout(nn.Dropout, RelProp):
+ pass
+
+
+class MaxPool2d(nn.MaxPool2d, RelPropSimple):
+ pass
+
+
+class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
+ pass
+
+
+class AvgPool2d(nn.AvgPool2d, RelPropSimple):
+ pass
+
+
+class Add(RelPropSimple):
+ def forward(self, inputs):
+ return torch.add(*inputs)
+
+ def relprop(self, R, alpha):
+ Z = self.forward(self.X)
+ S = safe_divide(R, Z)
+ C = self.gradprop(Z, self.X, S)
+
+ a = self.X[0] * C[0]
+ b = self.X[1] * C[1]
+
+ a_sum = a.sum()
+ b_sum = b.sum()
+
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
+
+ a = a * safe_divide(a_fact, a.sum())
+ b = b * safe_divide(b_fact, b.sum())
+
+ outputs = [a, b]
+
+ return outputs
+
+class einsum(RelPropSimple):
+ def __init__(self, equation):
+ super().__init__()
+ self.equation = equation
+ def forward(self, *operands):
+ return torch.einsum(self.equation, *operands)
+
+class IndexSelect(RelProp):
+ def forward(self, inputs, dim, indices):
+ self.__setattr__('dim', dim)
+ self.__setattr__('indices', indices)
+
+ return torch.index_select(inputs, dim, indices)
+
+ def relprop(self, R, alpha):
+ Z = self.forward(self.X, self.dim, self.indices)
+ S = safe_divide(R, Z)
+ C = self.gradprop(Z, self.X, S)
+
+ if torch.is_tensor(self.X) == False:
+ outputs = []
+ outputs.append(self.X[0] * C[0])
+ outputs.append(self.X[1] * C[1])
+ else:
+ outputs = self.X * (C[0])
+ return outputs
+
+
+
+class Clone(RelProp):
+ def forward(self, input, num):
+ self.__setattr__('num', num)
+ outputs = []
+ for _ in range(num):
+ outputs.append(input)
+
+ return outputs
+
+ def relprop(self, R, alpha):
+ Z = []
+ for _ in range(self.num):
+ Z.append(self.X)
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
+ C = self.gradprop(Z, self.X, S)[0]
+
+ R = self.X * C
+
+ return R
+
+class Cat(RelProp):
+ def forward(self, inputs, dim):
+ self.__setattr__('dim', dim)
+ return torch.cat(inputs, dim)
+
+ def relprop(self, R, alpha):
+ Z = self.forward(self.X, self.dim)
+ S = safe_divide(R, Z)
+ C = self.gradprop(Z, self.X, S)
+
+ outputs = []
+ for x, c in zip(self.X, C):
+ outputs.append(x * c)
+
+ return outputs
+
+
+class Sequential(nn.Sequential):
+ def relprop(self, R, alpha):
+ for m in reversed(self._modules.values()):
+ R = m.relprop(R, alpha)
+ return R
+
+class BatchNorm2d(nn.BatchNorm2d, RelProp):
+ def relprop(self, R, alpha):
+ X = self.X
+ beta = 1 - alpha
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
+ Z = X * weight + 1e-9
+ S = R / Z
+ Ca = S * weight
+ R = self.X * (Ca)
+ return R
+
+
+class Linear(nn.Linear, RelProp):
+ def relprop(self, R, alpha):
+ beta = alpha - 1
+ pw = torch.clamp(self.weight, min=0)
+ nw = torch.clamp(self.weight, max=0)
+ px = torch.clamp(self.X, min=0)
+ nx = torch.clamp(self.X, max=0)
+
+ def f(w1, w2, x1, x2):
+ Z1 = F.linear(x1, w1)
+ Z2 = F.linear(x2, w2)
+ S1 = safe_divide(R, Z1 + Z2)
+ S2 = safe_divide(R, Z1 + Z2)
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
+
+ return C1 + C2
+
+ activator_relevances = f(pw, nw, px, nx)
+ inhibitor_relevances = f(nw, pw, px, nx)
+
+ R = alpha * activator_relevances - beta * inhibitor_relevances
+
+ return R
+
+
+class Conv2d(nn.Conv2d, RelProp):
+ def gradprop2(self, DY, weight):
+ Z = self.forward(self.X)
+
+ output_padding = self.X.size()[2] - (
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
+
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
+
+ def relprop(self, R, alpha):
+ if self.X.shape[1] == 3:
+ pw = torch.clamp(self.weight, min=0)
+ nw = torch.clamp(self.weight, max=0)
+ X = self.X
+ L = self.X * 0 + \
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
+ keepdim=True)[0]
+ H = self.X * 0 + \
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
+ keepdim=True)[0]
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
+
+ S = R / Za
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
+ R = C
+ else:
+ beta = alpha - 1
+ pw = torch.clamp(self.weight, min=0)
+ nw = torch.clamp(self.weight, max=0)
+ px = torch.clamp(self.X, min=0)
+ nx = torch.clamp(self.X, max=0)
+
+ def f(w1, w2, x1, x2):
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
+ S1 = safe_divide(R, Z1)
+ S2 = safe_divide(R, Z2)
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
+ return C1 + C2
+
+ activator_relevances = f(pw, nw, px, nx)
+ inhibitor_relevances = f(nw, pw, px, nx)
+
+ R = alpha * activator_relevances - beta * inhibitor_relevances
+ return R
\ No newline at end of file
diff --git a/models/weight_init.py b/models/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa71c04254105ef5ba0c89bb730270328fc49bb1
--- /dev/null
+++ b/models/weight_init.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+# -*- coding:UTF-8 -*-
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+
+def weight_init(m):
+ '''
+ Usage:
+ model = Model()
+ model.apply(weight_init)
+ '''
+ if isinstance(m, nn.Conv1d):
+ init.normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.Conv2d):
+ init.xavier_normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.Conv3d):
+ init.xavier_normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.ConvTranspose1d):
+ init.normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.ConvTranspose2d):
+ init.xavier_normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.ConvTranspose3d):
+ init.xavier_normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.BatchNorm1d):
+ init.normal_(m.weight.data, mean=1, std=0.02)
+ init.constant_(m.bias.data, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ init.normal_(m.weight.data, mean=1, std=0.02)
+ init.constant_(m.bias.data, 0)
+ elif isinstance(m, nn.BatchNorm3d):
+ init.normal_(m.weight.data, mean=1, std=0.02)
+ init.constant_(m.bias.data, 0)
+ elif isinstance(m, nn.Linear):
+ init.xavier_normal_(m.weight.data)
+ init.normal_(m.bias.data)
+ elif isinstance(m, nn.LSTM):
+ for param in m.parameters():
+ if len(param.shape) >= 2:
+ init.orthogonal_(param.data)
+ else:
+ init.normal_(param.data)
+ elif isinstance(m, nn.LSTMCell):
+ for param in m.parameters():
+ if len(param.shape) >= 2:
+ init.orthogonal_(param.data)
+ else:
+ init.normal_(param.data)
+ elif isinstance(m, nn.GRU):
+ for param in m.parameters():
+ if len(param.shape) >= 2:
+ init.orthogonal_(param.data)
+ else:
+ init.normal_(param.data)
+ elif isinstance(m, nn.GRUCell):
+ for param in m.parameters():
+ if len(param.shape) >= 2:
+ init.orthogonal_(param.data)
+ else:
+ init.normal_(param.data)
+
+
+if __name__ == '__main__':
+ pass
\ No newline at end of file
diff --git a/option.py b/option.py
new file mode 100644
index 0000000000000000000000000000000000000000..35af9d6c88465d28aaff5e5472c439b554b3c728
--- /dev/null
+++ b/option.py
@@ -0,0 +1,41 @@
+###########################################################################
+# Created by: YI ZHENG
+# Email: yizheng@bu.edu
+# Copyright (c) 2020
+###########################################################################
+
+import os
+import argparse
+import torch
+
+class Options():
+ def __init__(self):
+ parser = argparse.ArgumentParser(description='PyTorch Classification')
+ parser.add_argument('--data_path', type=str, help='path to dataset where images store')
+ parser.add_argument('--train_set', type=str, help='path to the file listing the graph ids of the training split')
+ parser.add_argument('--val_set', type=str, help='path to the file listing the graph ids of the validation split')
+ parser.add_argument('--model_path', type=str, help='path to trained model')
+ parser.add_argument('--log_path', type=str, help='path to log files')
+ parser.add_argument('--task_name', type=str, help='task name for naming saved model files and log files')
+ parser.add_argument('--train', action='store_true', default=False, help='train only')
+ parser.add_argument('--test', action='store_true', default=False, help='test only')
+ parser.add_argument('--batch_size', type=int, default=6, help='batch size for origin global image (without downsampling)')
+ parser.add_argument('--log_interval_local', type=int, default=10, help='number of batches between metric/confusion-matrix log prints')
+ parser.add_argument('--resume', type=str, default="", help='path for model')
+ parser.add_argument('--graphcam', action='store_true', default=False, help='GraphCAM')
+ parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on')
+
+
+ # the parser
+ self.parser = parser
+
+ def parse(self):
+ args = self.parser.parse_args()
+ # default settings for epochs and lr
+
+ args.num_epochs = 120
+ args.lr = 1e-3
+
+ if args.test:
+ args.num_epochs = 1
+ return args
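+
+# Example invocation (the paths below are placeholders, adjust them to your dataset layout):
+#   python main.py --data_path graphs/ --train_set train_set.txt --val_set val_set.txt \
+#       --model_path saved_models/ --log_path runs/ --task_name GraphCAM \
+#       --dataset_metadata_path metadata/ --train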
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..717b1fa2b0f030c5b959f0ad5ae718867f2b564f
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,3 @@
+openslide-tools
+python3-openslide
+python3-opencv
\ No newline at end of file
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..a16eea20541f1975db22befd5c8490786768b7af
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,305 @@
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+from torchvision import transforms
+
+import torchvision.models as models
+from feature_extractor import cl
+from models.GraphTransformer import Classifier
+from models.weight_init import weight_init
+from feature_extractor.build_graph_utils import ToTensor, Compose, bag_dataset, adj_matrix
+import torchvision.transforms.functional as VF
+from src.vis_graphcam import show_cam_on_image,cam_to_mask
+from easydict import EasyDict as edict
+from slide_tiling import save_tiles
+import pickle
+from collections import OrderedDict
+import glob
+import openslide
+import skimage.transform
+import cv2
+
+
+class Predictor:
+
+ def __init__(self):
+ self.classdict = pickle.load(open(os.environ['CLASS_METADATA'], 'rb' ))
+ self.label_map_inv = dict()
+ for label_name, label_id in self.classdict.items():
+ self.label_map_inv[label_id] = label_name
+
+ iclf_weights = os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH']
+ graph_transformer_weights = os.environ['GT_WEIGHT_PATH']
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ self.__init_iclf(iclf_weights, backbone='resnet18')
+ self.__init_graph_transformer(graph_transformer_weights)
+
+ def predict(self, slide_path):
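+ # End-to-end inference for a single whole-slide image: tile the slide, extract per-patch
+ # features to build a graph, classify it with the graph transformer (GraphCAM enabled), and
+ # render per-class activation overlays on a thumbnail of the slide.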
+
+ # get tiles for a given WSI slide
+ save_tiles(slide_path)
+
+ filename = os.path.basename(slide_path)
+ FILEID = filename.rsplit('.', maxsplit=1)[0]
+ patches_glob_path = os.path.join(os.environ['PATCHES_DIR'], f'{FILEID}_files', '*', '*.jpeg')
+ patches_paths = glob.glob(patches_glob_path)
+
+ sample = self.iclf_predict(patches_paths)
+
+
+ torch.set_grad_enabled(True)
+ node_feat, adjs, masks = Predictor.preparefeatureLabel(sample['image'], sample['adj_s'], self.device)
+ pred,labels,loss,graphcam_tensors = self.model.forward(node_feat=node_feat, labels=None, adj=adjs, mask=masks, graphcam_flag=True, to_file=False)
+
+ patches_coords = sample['c_idx'][0]
+ viz_dict = self.get_graphcams(graphcam_tensors, patches_coords, slide_path, FILEID)
+ return self.label_map_inv[pred.item()], viz_dict
+
+ def iclf_predict(self, patches_paths):
+ feats_list = []
+
+ batch_size = 128
+ num_workers = 0
+ args = edict({'batch_size':batch_size, 'num_workers':num_workers} )
+ dataloader, bag_size = bag_dataset(args, patches_paths)
+
+ with torch.no_grad():
+ for iteration, batch in enumerate(dataloader):
+ patches = batch['input'].float().to(self.device)
+ feats, classes = self.i_classifier(patches)
+ #feats = feats.cpu().numpy()
+ feats_list.extend(feats)
+ output = torch.stack(feats_list, dim=0).to(self.device)
+ # save adjacent matrix
+ adj_s = adj_matrix(patches_paths, output)
+
+
+ patch_infos = []
+ for path in patches_paths:
+ x, y = path.split('/')[-1].split('.')[0].split('_')
+ patch_infos.append((x,y))
+
+ preds = {'image': [output],
+ 'adj_s': [adj_s],
+ 'c_idx': [patch_infos]}
+ return preds
+
+
+
+ def get_graphcams(self, graphcam_tensors, patches_coords, slide_path, FILEID):
+ label_map = self.classdict
+ label_name_from_id = self.label_map_inv
+
+ n_class = len(label_map)
+
+ p = graphcam_tensors['prob'].cpu().detach().numpy()[0]
+ ori = openslide.OpenSlide(slide_path)
+ width, height = ori.dimensions
+
+ REDUCTION_FACTOR = 20
+ w, h = int(width/512), int(height/512)
+ w_r, h_r = int(width/REDUCTION_FACTOR), int(height/REDUCTION_FACTOR)
+ resized_img = ori.get_thumbnail((width,height))#ori.get_thumbnail((w_r,h_r))
+ resized_img = resized_img.resize((w_r,h_r))
+ ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
+ #print('ratios ', ratio_w, ratio_h)
+ w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
+
+ patches = []
+ xmax, ymax = 0, 0
+ for patch_coords in patches_coords:
+ x, y = patch_coords
+ if xmax < int(x): xmax = int(x)
+ if ymax < int(y): ymax = int(y)
+ patches.append('{}_{}.jpeg'.format(x,y))
+
+
+
+ output_img = np.asarray(resized_img)[:,:,::-1].copy()
+ #-----------------------------------------------------------------------------------------------------#
+ # GraphCAM
+ #print('visulize GraphCAM')
+ assign_matrix = graphcam_tensors['s_matrix_ori']
+ m = nn.Softmax(dim=1)
+ assign_matrix = m(assign_matrix)
+
+ # Thresholding for better visualization
+ p = np.clip(p, 0.4, 1)
+
+
+
+ output_img_copy =np.copy(output_img)
+ gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min())
+ cam_matrices = []
+ masks = []
+ visualizations = []
+ viz_dict = dict()
+
+ SAMPLE_VIZ_DIR = os.path.join(os.environ['GRAPHCAM_DIR'],
+ FILEID)
+ os.makedirs(SAMPLE_VIZ_DIR, exist_ok=True)
+
+ for class_i in range(n_class):
+
+ # Load graphcam for each class
+ cam_matrix = graphcam_tensors[f'cam_{class_i}']
+ cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1,0))
+ cam_matrix = cam_matrix.cpu()
+
+ # Normalize the graphcam
+ cam_matrix = (cam_matrix - cam_matrix.min()) / (cam_matrix.max() - cam_matrix.min())
+ cam_matrix = cam_matrix.detach().numpy()
+ cam_matrix = p[class_i] * cam_matrix
+ cam_matrix = np.clip(cam_matrix, 0, 1)
+
+
+ mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s)
+
+ vis = show_cam_on_image(image_transformer_attribution, mask)
+ vis = np.uint8(255 * vis)
+
+ cam_matrices.append(cam_matrix)
+ masks.append(mask)
+ visualizations.append(vis)
+ viz_dict['{}'.format(label_name_from_id[class_i]) ] = vis
+ cv2.imwrite(os.path.join(
+ SAMPLE_VIZ_DIR,
+ '{}_all_types_cam_{}.png'.format(FILEID, label_name_from_id[class_i] )
+ ), vis)
+ h, w, _ = output_img.shape
+ if h > w:
+ vis_merge = cv2.hconcat([output_img] + visualizations)
+ else:
+ vis_merge = cv2.vconcat([output_img] + visualizations)
+
+
+ cv2.imwrite(os.path.join(
+ SAMPLE_VIZ_DIR,
+ '{}_all_types_cam_all.png'.format(FILEID)),
+ vis_merge)
+ viz_dict['ALL'] = vis_merge
+ cv2.imwrite(os.path.join(
+ SAMPLE_VIZ_DIR,
+ '{}_all_types_ori.png'.format(FILEID )
+ ),
+ output_img)
+ viz_dict['ORI'] = output_img
+ return viz_dict
+
+
+
+
+
+
+ @staticmethod
+ def preparefeatureLabel(batch_graph, batch_adjs, device='cpu'):
+ batch_size = len(batch_graph)
+ max_node_num = 0
+
+ for i in range(batch_size):
+ max_node_num = max(max_node_num, batch_graph[i].shape[0])
+
+ masks = torch.zeros(batch_size, max_node_num)
+ adjs = torch.zeros(batch_size, max_node_num, max_node_num)
+ batch_node_feat = torch.zeros(batch_size, max_node_num, 512)
+
+ for i in range(batch_size):
+ cur_node_num = batch_graph[i].shape[0]
+ #node attribute feature
+ tmp_node_fea = batch_graph[i]
+ batch_node_feat[i, 0:cur_node_num] = tmp_node_fea
+
+ #adjs
+ adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i]
+
+ #masks
+ masks[i,0:cur_node_num] = 1
+
+ node_feat = batch_node_feat.to(device)
+ adjs = adjs.to(device)
+ masks = masks.to(device)
+
+ return node_feat, adjs, masks
+
+ def __init_graph_transformer(self, graph_transformer_weights):
+ n_class = len(self.classdict)
+ model = Classifier(n_class)
+ model = nn.DataParallel(model)
+ model.load_state_dict(torch.load(graph_transformer_weights,
+ map_location=torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) ))
+ if torch.cuda.is_available():
+ model = model.cuda()
+ self.model = model
+
+
+ def __init_iclf(self, iclf_weights, backbone='resnet18'):
+ if backbone == 'resnet18':
+ resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 512
+ if backbone == 'resnet34':
+ resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 512
+ if backbone == 'resnet50':
+ resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 2048
+ if backbone == 'resnet101':
+ resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)
+ num_feats = 2048
+ for param in resnet.parameters():
+ param.requires_grad = False
+ resnet.fc = nn.Identity()
+ i_classifier = cl.IClassifier(resnet, num_feats, output_class=2).to(self.device)
+
+ # load feature extractor
+
+ state_dict_weights = torch.load(iclf_weights, map_location=torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ))
+ state_dict_init = i_classifier.state_dict()
+ new_state_dict = OrderedDict()
+ for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):
+ if 'features' not in k:
+ continue
+ name = k_0
+ new_state_dict[name] = v
+ i_classifier.load_state_dict(new_state_dict, strict=False)
+
+ self.i_classifier = i_classifier
+
+
+
+
+
+#0 load metadata dictionary for class names
+#1 TILE THE IMAGE
+#2 FEED IT TO FEATURE EXTRACTOR
+#3 PRODUCE GRAPH
+#4 predict graphcams
+import subprocess
+import argparse
+import os
+import shutil
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='PyTorch Classification')
+ parser.add_argument('--slide_path', type=str, help='path to the WSI slide')
+ args = parser.parse_args()
+ predictor = Predictor()
+
+ predicted_class, viz_dict = predictor.predict(args.slide_path)
+ print('Class prediction is: ', predicted_class)
+
+
+
+
+
+
+
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..70a40c61413e97f28cc98bf3532f0b6f79eeb599
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,29 @@
+streamlit
+
+#-f https://download.pytorch.org/whl/cpu/torch_stable.html
+#-f https://data.pyg.org/whl/torch-1.7.1+cpu.html
+#torch==1.7.1+cpu
+
+#-f https://download.pytorch.org/whl/torch_stable.html
+#-f https://data.pyg.org/whl/torch-1.10.0+cu113.html
+#torch==1.10.0+cu113
+
+-f https://download.pytorch.org/whl/cpu/torch_stable.html
+torch
+torchvision
+#torch-scatter
+#torch-sparse
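+# torch-scatter / torch-sparse are commented out here: they require the PyG wheel index
+# matching the installed torch build (see the -f URLs above) and are installed separately.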
+
+einops
+streamlit-option-menu
+numpy
+pandas
+scikit-image
+opencv-python
+PyYAML
+tqdm
+scipy
+imageio
+easydict
+openslide-python
+pydicom
\ No newline at end of file
diff --git a/set_env.sh b/set_env.sh
new file mode 100755
index 0000000000000000000000000000000000000000..06021e440e5f23a62c786e5b3f85143f9ab345f9
--- /dev/null
+++ b/set_env.sh
@@ -0,0 +1,20 @@
+
+# environment variables for model training
+
+
+
+# environment variables for the inference api
+export DATA_DIR=queries
+export PATCHES_DIR=${DATA_DIR}/patches
+export SLIDES_DIR=${DATA_DIR}/slides
+export GRAPHCAM_DIR=${DATA_DIR}/graphcam_plots
+mkdir -p ${GRAPHCAM_DIR}
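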
+
+
+# manually put the metadata in the metadata folder
+export CLASS_METADATA='metadata/label_map.pkl'
+
+# manually put the desired weights in the weights folder
+export WEIGHTS_PATH='weights'
+export FEATURE_EXTRACTOR_WEIGHT_PATH=${WEIGHTS_PATH}/feature_extractor/model.pth
+export GT_WEIGHT_PATH=${WEIGHTS_PATH}/graph_transformer/GraphCAM.pth
\ No newline at end of file
diff --git a/slide_tiling.py b/slide_tiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..8736ce0d43bab3cc3825ae52747cd7fbe433a15f
--- /dev/null
+++ b/slide_tiling.py
@@ -0,0 +1,41 @@
+import subprocess
+import argparse
+import os
+import shutil
+
+
+def save_tiles(slide_path):
+
+ filename = os.path.basename(slide_path)
+ FILEID = filename.rsplit('.', maxsplit=1)[0]
+ PATCHES_DIR = os.environ['PATCHES_DIR']
+ SLIDES_DIR = os.environ['SLIDES_DIR']
+ os.makedirs(PATCHES_DIR, exist_ok=True)
+ os.makedirs(SLIDES_DIR, exist_ok=True)
+ shutil.copy(slide_path, SLIDES_DIR)
+
+ INPUT_PATH = os.path.join(SLIDES_DIR, filename)
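+ # Tile the slide with src/tile_WSI.py: 512-px tiles (-s), no overlap (-e 0), 16 workers (-j),
+ # at most 50% background per tile (-B), tiling only the 20x magnification level (-M 20),
+ # writing the DeepZoom output under PATCHES_DIR (-o).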
+ CMD = ['python3', 'src/tile_WSI.py', '-s', '512', '-e', '0', '-j', '16', '-B', '50', '-M', '20', '-o', PATCHES_DIR, INPUT_PATH]
+ subprocess.call(CMD)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='PyTorch Classification')
+ parser.add_argument('--slide_path', type=str, help='path to the WSI slide')
+ args = parser.parse_args()
+
+
+ filename = os.path.basename(args.slide_path)
+ FILEID = filename.rsplit('.', maxsplit=1)[0]
+ PATCHES_DIR = os.environ['PATCHES_DIR']
+ SLIDES_DIR = os.environ['SLIDES_DIR']
+ os.makedirs(PATCHES_DIR, exist_ok=True)
+ os.makedirs(SLIDES_DIR, exist_ok=True)
+ shutil.move(args.slide_path, SLIDES_DIR)
+
+ INPUT_PATH = os.path.join(SLIDES_DIR, filename)
+
+
+ CMD = ['python3', 'src/tile_WSI.py', '-s', '512', '-e', '0', '-j', '16', '-B', '50', '-M', '20', '-o', PATCHES_DIR, INPUT_PATH]
+
+ subprocess.call(CMD)
+
diff --git a/src/__pycache__/vis_graphcam.cpython-38.pyc b/src/__pycache__/vis_graphcam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1c8c10ff56515f819d2744de9d24a3d11a0ce64
Binary files /dev/null and b/src/__pycache__/vis_graphcam.cpython-38.pyc differ
diff --git a/src/tile_WSI.py b/src/tile_WSI.py
new file mode 100644
index 0000000000000000000000000000000000000000..47433c954a4252b4749f89cec7e5284963596a8b
--- /dev/null
+++ b/src/tile_WSI.py
@@ -0,0 +1,980 @@
+'''
+ File name: tile_WSI.py
+ Date created: March/2021
+ Source:
+ Tiling code inspired from
+ https://github.com/openslide/openslide-python/blob/master/examples/deepzoom/deepzoom_tile.py
+
+ The code has been extensively modified
+ Objective:
+ Tile svs, jpg or dcm images, with the option of rejecting some tiles based on xml or jpg masks
+ Be careful:
+ May overload the node - memory issues are possible if the node is shared with other jobs.
+'''
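+# Example invocation (paths are illustrative), matching the flags used by slide_tiling.py:
+#   python3 src/tile_WSI.py -s 512 -e 0 -j 16 -B 50 -M 20 -o <patches_dir> '<slides_dir>/*.svs'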
+
+from __future__ import print_function
+import json
+import openslide
+from openslide import open_slide, ImageSlide
+from openslide.deepzoom import DeepZoomGenerator
+from optparse import OptionParser
+import re
+import shutil
+from unicodedata import normalize
+import numpy as np
+import scipy.misc
+import subprocess
+from glob import glob
+from multiprocessing import Process, JoinableQueue
+import time
+import os
+import sys
+try:
+ import pydicom as dicom
+except ImportError:
+ import dicom
+# from scipy.misc import imsave
+from imageio import imwrite as imsave
+# from scipy.misc import imread
+from imageio import imread
+# from scipy.misc import imresize
+
+from xml.dom import minidom
+from PIL import Image, ImageDraw, ImageCms
+from skimage import color, io
+from tqdm import tqdm
+Image.MAX_IMAGE_PIXELS = None
+
+
+VIEWER_SLIDE_NAME = 'slide'
+
+
+class TileWorker(Process):
+ """A child process that generates and writes tiles."""
+
+ def __init__(self, queue, slidepath, tile_size, overlap, limit_bounds,quality, _Bkg, _ROIpc):
+ Process.__init__(self, name='TileWorker')
+ self.daemon = True
+ self._queue = queue
+ self._slidepath = slidepath
+ self._tile_size = tile_size
+ self._overlap = overlap
+ self._limit_bounds = limit_bounds
+ self._quality = quality
+ self._slide = None
+ self._Bkg = _Bkg
+ self._ROIpc = _ROIpc
+
+ def RGB_to_lab(self, tile):
+ # srgb_p = ImageCms.createProfile("sRGB")
+ # lab_p = ImageCms.createProfile("LAB")
+ # rgb2lab = ImageCms.buildTransformFromOpenProfiles(srgb_p, lab_p, "RGB", "LAB")
+ # Lab = ImageCms.applyTransform(tile, rgb2lab)
+ # Lab = np.array(Lab)
+ # Lab = Lab.astype('float')
+ # Lab[:,:,0] = Lab[:,:,0] / 2.55
+ # Lab[:,:,1] = Lab[:,:,1] - 128
+ # Lab[:,:,2] = Lab[:,:,2] - 128
+ print("RGB to Lab")
+ Lab = color.rgb2lab(tile)
+ return Lab
+
+ def Lab_to_RGB(self,Lab):
+ # srgb_p = ImageCms.createProfile("sRGB")
+ # lab_p = ImageCms.createProfile("LAB")
+ # lab2rgb = ImageCms.buildTransformFromOpenProfiles(srgb_p, lab_p, "LAB", "RGB")
+ # Lab[:,:,0] = Lab[:,:,0] * 2.55
+ # Lab[:,:,1] = Lab[:,:,1] + 128
+ # Lab[:,:,2] = Lab[:,:,2] + 128
+ # newtile = ImageCms.applyTransform(Lab, lab2rgb)
+ print("Lab to RGB")
+ newtile = (color.lab2rgb(Lab) * 255).astype(np.uint8)
+ return newtile
+
+
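+ # normalize_tile: simple colour normalisation in Lab space - shift/scale each channel so its
+ # mean/std match the target vector NormVec (first three values = target means, last three =
+ # target stds, as passed via the -N option), clamping to the valid Lab range.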
+ def normalize_tile(self, tile, NormVec):
+ Lab = self.RGB_to_lab(tile)
+ TileMean = [0,0,0]
+ TileStd = [1,1,1]
+ newMean = NormVec[0:3]
+ newStd = NormVec[3:6]
+ for i in range(3):
+ TileMean[i] = np.mean(Lab[:,:,i])
+ TileStd[i] = np.std(Lab[:,:,i])
+ # print("mean/std chanel " + str(i) + ": " + str(TileMean[i]) + " / " + str(TileStd[i]))
+ tmp = ((Lab[:,:,i] - TileMean[i]) * (newStd[i] / TileStd[i])) + newMean[i]
+ if i == 0:
+ tmp[tmp<0] = 0
+ tmp[tmp>100] = 100
+ Lab[:,:,i] = tmp
+ else:
+ tmp[tmp<-128] = -128
+ tmp[tmp>127] = 127
+ Lab[:,:,i] = tmp
+ tile = self.Lab_to_RGB(Lab)
+ return tile
+
+ def run(self):
+ self._slide = open_slide(self._slidepath)
+ last_associated = None
+ dz = self._get_dz()
+ while True:
+ data = self._queue.get()
+ if data is None:
+ self._queue.task_done()
+ break
+ #associated, level, address, outfile = data
+ associated, level, address, outfile, format, outfile_bw, PercentMasked, SaveMasks, TileMask, Normalize = data
+ if last_associated != associated:
+ dz = self._get_dz(associated)
+ last_associated = associated
+ #try:
+ if True:
+ try:
+ tile = dz.get_tile(level, address)
+ # A single tile is being read
+ #check the percentage of the image with "information". Should be above 50%
+ gray = tile.convert('L')
+ bw = gray.point(lambda x: 0 if x<220 else 1, 'F')
+ arr = np.array(np.asarray(bw))
+ avgBkg = np.average(bw)
+ bw = gray.point(lambda x: 0 if x<220 else 1, '1')
+ # check if the image is mostly background
+
+ #print("res: " + outfile + " is " + str(avgBkg))
+ if avgBkg <= (self._Bkg / 100.0):
+ # print("PercentMasked: %.6f, %.6f" % (PercentMasked, self._ROIpc / 100.0) )
+ # if an Aperio selection was made, check if is within the selected region
+ if PercentMasked >= (self._ROIpc / 100.0):
+
+ if Normalize != '':
+ print("normalize " + str(outfile))
+ # arrtile = np.array(tile)
+ tile = Image.fromarray(self.normalize_tile(tile, Normalize).astype('uint8'),'RGB')
+
+ tile.save(outfile, quality=self._quality)
+ if bool(SaveMasks)==True:
+ height = TileMask.shape[0]
+ width = TileMask.shape[1]
+ TileMaskO = np.zeros((height,width,3), 'uint8')
+ maxVal = float(TileMask.max())
+ TileMaskO[...,0] = (TileMask[:,:].astype(float) / maxVal * 255.0).astype(int)
+ TileMaskO[...,1] = (TileMask[:,:].astype(float) / maxVal * 255.0).astype(int)
+ TileMaskO[...,2] = (TileMask[:,:].astype(float) / maxVal * 255.0).astype(int)
+ TileMaskO = np.array(Image.fromarray(TileMaskO).resize((arr.shape[1], arr.shape[0])))
+ # TileMaskO = imresize(TileMaskO, (arr.shape[0], arr.shape[1],3))
+ TileMaskO[TileMaskO<10] = 0
+ TileMaskO[TileMaskO>=10] = 255
+ imsave(outfile_bw,TileMaskO) #(outfile_bw, quality=self._quality)
+
+ #print("%s good: %f" %(outfile, avgBkg))
+ #elif level>5:
+ # tile.save(outfile, quality=self._quality)
+ #print("%s empty: %f" %(outfile, avgBkg))
+ self._queue.task_done()
+ except Exception as e:
+ # print(level, address)
+ print("image %s failed at dz.get_tile for level %f" % (self._slidepath, level))
+ # e = sys.exc_info()[0]
+ print(e)
+ self._queue.task_done()
+
+ def _get_dz(self, associated=None):
+ if associated is not None:
+ image = ImageSlide(self._slide.associated_images[associated])
+ else:
+ image = self._slide
+ return DeepZoomGenerator(image, self._tile_size, self._overlap, limit_bounds=self._limit_bounds)
+
+
+class DeepZoomImageTiler(object):
+ """Handles generation of tiles and metadata for a single image."""
+
+ def __init__(self, dz, basename, format, associated, queue, slide, basenameJPG, xmlfile, mask_type, xmlLabel, ROIpc, ImgExtension, SaveMasks, Mag, normalize, Fieldxml):
+ self._dz = dz
+ self._basename = basename
+ self._basenameJPG = basenameJPG
+ self._format = format
+ self._associated = associated
+ self._queue = queue
+ self._processed = 0
+ self._slide = slide
+ self._xmlfile = xmlfile
+ self._mask_type = mask_type
+ self._xmlLabel = xmlLabel
+ self._ROIpc = ROIpc
+ self._ImgExtension = ImgExtension
+ self._SaveMasks = SaveMasks
+ self._Mag = Mag
+ self._normalize = normalize
+ self._Fieldxml = Fieldxml
+
+ def run(self):
+ self._write_tiles()
+ self._write_dzi()
+
+ def _write_tiles(self):
+ ########################################3
+ # nc_added
+ #level = self._dz.level_count-1
+ Magnification = 20
+ tol = 2
+ #get slide dimensions, zoom levels, and objective information
+ Factors = self._slide.level_downsamples
+ try:
+ Objective = float(self._slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER])
+ # print(self._basename + " - Obj information found")
+ except:
+ print(self._basename + " - No Obj information found")
+ print(self._ImgExtension)
+ if ("jpg" in self._ImgExtension) | ("dcm" in self._ImgExtension) | ("tif" in self._ImgExtension):
+ #Objective = self._ROIpc
+ Objective = 1.
+ Magnification = Objective
+ print("input is jpg - will be tiled as such with %f" % Objective)
+ else:
+ return
+ #calculate magnifications
+ Available = tuple(Objective / x for x in Factors)
+ #find highest magnification greater than or equal to 'Desired'
+ Mismatch = tuple(x-Magnification for x in Available)
+ AbsMismatch = tuple(abs(x) for x in Mismatch)
+ if len(AbsMismatch) < 1:
+ print(self._basename + " - Objective field empty!")
+ return
+ '''
+ if(min(AbsMismatch) <= tol):
+ Level = int(AbsMismatch.index(min(AbsMismatch)))
+ Factor = 1
+ else: #pick next highest level, downsample
+ Level = int(max([i for (i, val) in enumerate(Mismatch) if val > 0]))
+ Factor = Magnification / Available[Level]
+ # end added
+ '''
+ xml_valid = False
+ # a dir was provided for xml files
+
+ '''
+ ImgID = os.path.basename(self._basename)
+ Nbr_of_masks = 0
+ if self._xmlfile != '':
+ xmldir = os.path.join(self._xmlfile, ImgID + '.xml')
+ print("xml:")
+ print(xmldir)
+ if os.path.isfile(xmldir):
+ xml_labels, xml_valid = self.xml_read_labels(xmldir)
+ Nbr_of_masks = len(xml_labels)
+ else:
+ print("No xml file found for slide %s.svs (expected: %s). Directory or xml file does not exist" % (ImgID, xmldir) )
+ return
+ else:
+ Nbr_of_masks = 1
+ '''
+
+ if True:
+ #if self._xmlfile != '' && :
+ # print(self._xmlfile, self._ImgExtension)
+ ImgID = os.path.basename(self._basename)
+ xmldir = os.path.join(self._xmlfile, ImgID + '.xml')
+ # print("xml:")
+ # print(xmldir)
+ if (self._xmlfile != '') & (self._ImgExtension != 'jpg') & (self._ImgExtension != 'dcm'):
+ # print("read xml file...")
+ mask, xml_valid, Img_Fact = self.xml_read(xmldir, self._xmlLabel, self._Fieldxml)
+ if xml_valid == False:
+ print("Error: xml %s file cannot be read properly - please check format" % xmldir)
+ return
+ elif (self._xmlfile != '') & (self._ImgExtension == 'dcm'):
+ # print("check mask for dcm")
+ mask, xml_valid, Img_Fact = self.jpg_mask_read(xmldir)
+ # mask <-- read mask
+ # Img_Fact <-- 1
+ # xml_valid <-- True if mask file exists.
+ if xml_valid == False:
+ print("Error: xml %s file cannot be read properly - please check format" % xmldir)
+ return
+
+ # print("current directory: %s" % self._basename)
+
+ #return
+ #print(self._dz.level_count)
+
+ for level in range(self._dz.level_count-1,-1,-1):
+ ThisMag = Available[0]/pow(2,self._dz.level_count-(level+1))
+ if self._Mag > 0:
+ if ThisMag != self._Mag:
+ continue
+ ########################################
+ #tiledir = os.path.join("%s_files" % self._basename, str(level))
+ tiledir = os.path.join("%s_files" % self._basename, str(ThisMag))
+ if not os.path.exists(tiledir):
+ os.makedirs(tiledir)
+ cols, rows = self._dz.level_tiles[level]
+ if xml_valid:
+ # print("xml valid")
+ '''# If xml file is used, check for each tile what are their corresponding coordinate in the base image
+ IndX_orig, IndY_orig = self._dz.level_tiles[-1]
+ CurrentLevel_ReductionFactor = (Img_Fact * float(self._dz.level_dimensions[-1][0]) / float(self._dz.level_dimensions[level][0]))
+ startIndX_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(cols)]
+ print("***********")
+ endIndX_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(cols)]
+ endIndX_current_level_conv.append(self._dz.level_dimensions[level][0])
+ endIndX_current_level_conv.pop(0)
+
+ startIndY_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(rows)]
+ #endIndX_current_level_conv = [i * CurrentLevel_ReductionFactor - 1 for i in range(rows)]
+ endIndY_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(rows)]
+ endIndY_current_level_conv.append(self._dz.level_dimensions[level][1])
+ endIndY_current_level_conv.pop(0)
+ '''
+ #startIndY_current_level_conv = []
+ #endIndY_current_level_conv = []
+ #startIndX_current_level_conv = []
+ #endIndX_current_level_conv = []
+
+ #for row in range(rows):
+ # for col in range(cols):
+ # Dlocation, Dlevel, Dsize = self._dz.get_tile_coordinates(level,(col, row))
+ # Ddimension = self._dz.get_tile_dimensions(level,(col, row))
+ # startIndY_current_level_conv.append(int((Dlocation[1]) / Img_Fact))
+ # endIndY_current_level_conv.append(int((Dlocation[1] + Ddimension[1]) / Img_Fact))
+ # startIndX_current_level_conv.append(int((Dlocation[0]) / Img_Fact))
+ # endIndX_current_level_conv.append(int((Dlocation[0] + Ddimension[0]) / Img_Fact))
+ # print(Dlocation, Ddimension, int((Dlocation[1]) / Img_Fact), int((Dlocation[1] + Ddimension[1]) / Img_Fact), int((Dlocation[0]) / Img_Fact), int((Dlocation[0] + Ddimension[0]) / Img_Fact))
+ for row in range(rows):
+ for col in range(cols):
+ InsertBaseName = False
+ if InsertBaseName:
+ tilename = os.path.join(tiledir, '%s_%d_%d.%s' % (
+ self._basenameJPG, col, row, self._format))
+ tilename_bw = os.path.join(tiledir, '%s_%d_%d_mask.%s' % (
+ self._basenameJPG, col, row, self._format))
+ else:
+ tilename = os.path.join(tiledir, '%d_%d.%s' % (
+ col, row, self._format))
+ tilename_bw = os.path.join(tiledir, '%d_%d_mask.%s' % (
+ col, row, self._format))
+ if xml_valid:
+ # compute percentage of tile in mask
+ # print(row, col)
+ # print(startIndX_current_level_conv[col])
+ # print(endIndX_current_level_conv[col])
+ # print(startIndY_current_level_conv[row])
+ # print(endIndY_current_level_conv[row])
+ # print(mask.shape)
+ # print(mask[startIndX_current_level_conv[col]:endIndX_current_level_conv[col], startIndY_current_level_conv[row]:endIndY_current_level_conv[row]])
+ # TileMask = mask[startIndY_current_level_conv[row]:endIndY_current_level_conv[row], startIndX_current_level_conv[col]:endIndX_current_level_conv[col]]
+ # PercentMasked = mask[startIndY_current_level_conv[row]:endIndY_current_level_conv[row], startIndX_current_level_conv[col]:endIndX_current_level_conv[col]].mean()
+ # print(startIndY_current_level_conv[row], endIndY_current_level_conv[row], startIndX_current_level_conv[col], endIndX_current_level_conv[col])
+
+ Dlocation, Dlevel, Dsize = self._dz.get_tile_coordinates(level,(col, row))
+ Ddimension = tuple([pow(2,(self._dz.level_count - 1 - level)) * x for x in self._dz.get_tile_dimensions(level,(col, row))])
+ startIndY_current_level_conv = (int((Dlocation[1]) / Img_Fact))
+ endIndY_current_level_conv = (int((Dlocation[1] + Ddimension[1]) / Img_Fact))
+ startIndX_current_level_conv = (int((Dlocation[0]) / Img_Fact))
+ endIndX_current_level_conv = (int((Dlocation[0] + Ddimension[0]) / Img_Fact))
+ # print(Ddimension, Dlocation, Dlevel, Dsize, self._dz.level_count , level, col, row)
+
+ #startIndY_current_level_conv = (int((Dlocation[1]) / Img_Fact))
+ #endIndY_current_level_conv = (int((Dlocation[1] + Ddimension[1]) / Img_Fact))
+ #startIndX_current_level_conv = (int((Dlocation[0]) / Img_Fact))
+ #endIndX_current_level_conv = (int((Dlocation[0] + Ddimension[0]) / Img_Fact))
+ TileMask = mask[startIndY_current_level_conv:endIndY_current_level_conv, startIndX_current_level_conv:endIndX_current_level_conv]
+ PercentMasked = mask[startIndY_current_level_conv:endIndY_current_level_conv, startIndX_current_level_conv:endIndX_current_level_conv].mean()
+
+ # print(Ddimension, startIndY_current_level_conv, endIndY_current_level_conv, startIndX_current_level_conv, endIndX_current_level_conv)
+
+
+ if self._mask_type == 0:
+ # keep ROI outside of the mask
+ PercentMasked = 1.0 - PercentMasked
+ # print("Invert Mask percentage")
+
+ # if PercentMasked > 0:
+ # print("PercentMasked_p %.3f" % (PercentMasked))
+ # else:
+ # print("PercentMasked_0 %.3f" % (PercentMasked))
+
+
+ else:
+ PercentMasked = 1.0
+ TileMask = []
+
+ if not os.path.exists(tilename):
+ self._queue.put((self._associated, level, (col, row),
+ tilename, self._format, tilename_bw, PercentMasked, self._SaveMasks, TileMask, self._normalize))
+ self._tile_done()
+
+ def _tile_done(self):
+ self._processed += 1
+ count, total = self._processed, self._dz.tile_count
+ if count % 100 == 0 or count == total:
+ #print("Tiling %s: wrote %d/%d tiles" % (
+ # self._associated or 'slide', count, total),
+ # end='\r', file=sys.stderr)
+ if count == total:
+ print(file=sys.stderr)
+
+ def _write_dzi(self):
+ with open('%s.dzi' % self._basename, 'w') as fh:
+ fh.write(self.get_dzi())
+
+ def get_dzi(self):
+ return self._dz.get_dzi(self._format)
+
+
+ def jpg_mask_read(self, xmldir):
+ # Original size of the image
+ ImgMaxSizeX_orig = float(self._dz.level_dimensions[-1][0])
+ ImgMaxSizeY_orig = float(self._dz.level_dimensions[-1][1])
+ # Number of centers at the highest resolution
+ cols, rows = self._dz.level_tiles[-1]
+ # Img_Fact = int(ImgMaxSizeX_orig / 1.0 / cols)
+ Img_Fact = 1
+ try:
+ # xmldir: change extension from xml to *jpg
+ xmldir = xmldir[:-4] + "mask.jpg"
+ # xmlcontent = read xmldir image
+ xmlcontent = imread(xmldir)
+ xmlcontent = xmlcontent - np.min(xmlcontent)
+ mask = xmlcontent / np.max(xmlcontent)
+ # we want image between 0 and 1
+ xml_valid = True
+ except:
+ xml_valid = False
+ print("error with minidom.parse(xmldir)")
+ return [], xml_valid, 1.0
+
+ return mask, xml_valid, Img_Fact
+
+
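+ # xml_read: parse an Aperio-style XML annotation file and rasterise the selected label's
+ # Region polygons into a binary mask (positive ROIs filled, NegativeROA regions cleared),
+ # downscaled by NewFact so the long side of the mask stays below ~15000 px.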
+ def xml_read(self, xmldir, Attribute_Name, Fieldxml):
+
+ # Original size of the image
+ ImgMaxSizeX_orig = float(self._dz.level_dimensions[-1][0])
+ ImgMaxSizeY_orig = float(self._dz.level_dimensions[-1][1])
+ # Number of centers at the highest resolution
+ cols, rows = self._dz.level_tiles[-1]
+
+ NewFact = max(ImgMaxSizeX_orig, ImgMaxSizeY_orig) / min(max(ImgMaxSizeX_orig, ImgMaxSizeY_orig),15000.0)
+ # Img_Fact =
+ # read_region(location, level, size)
+ # dz.get_tile_coordinates(14,(0,2))
+ # ((0, 1792), 1, (320, 384))
+
+ Img_Fact = float(ImgMaxSizeX_orig) / 5.0 / float(cols)
+
+ # print("image info:")
+ # print(ImgMaxSizeX_orig, ImgMaxSizeY_orig, cols, rows)
+ try:
+ xmlcontent = minidom.parse(xmldir)
+ xml_valid = True
+ except:
+ xml_valid = False
+ print("error with minidom.parse(xmldir)")
+ return [], xml_valid, 1.0
+
+ xy = {}
+ xy_neg = {}
+ NbRg = 0
+ labelIDs = xmlcontent.getElementsByTagName('Annotation')
+ # print("%d labels" % len(labelIDs) )
+ for labelID in labelIDs:
+ if (Attribute_Name==[]) | (Attribute_Name==''):
+ isLabelOK = True
+ else:
+ try:
+ labeltag = labelID.getElementsByTagName('Attribute')[0]
+ if (Attribute_Name==labeltag.attributes[Fieldxml].value):
+ # if (Attribute_Name==labeltag.attributes['Value'].value):
+ # if (Attribute_Name==labeltag.attributes['Name'].value):
+ isLabelOK = True
+ else:
+ isLabelOK = False
+ except:
+ isLabelOK = False
+ if Attribute_Name == "non_selected_regions":
+ isLabelOK = True
+
+ #print("label ID, tag:")
+ #print(labelID, Attribute_Name, labeltag.attributes['Name'].value)
+ #if Attribute_Name==labeltag.attributes['Name'].value:
+ if isLabelOK:
+ regionlist = labelID.getElementsByTagName('Region')
+ for region in regionlist:
+ vertices = region.getElementsByTagName('Vertex')
+ NbRg += 1
+ regionID = region.attributes['Id'].value + str(NbRg)
+ NegativeROA = region.attributes['NegativeROA'].value
+ # print("%d vertices" % len(vertices))
+ if len(vertices) > 0:
+ #print( len(vertices) )
+ if NegativeROA=="0":
+ xy[regionID] = []
+ for vertex in vertices:
+ # get the x value of the vertex / convert them into index in the tiled matrix of the base image
+ # x = int(round(float(vertex.attributes['X'].value) / ImgMaxSizeX_orig * (cols*Img_Fact)))
+ # y = int(round(float(vertex.attributes['Y'].value) / ImgMaxSizeY_orig * (rows*Img_Fact)))
+ x = int(round(float(vertex.attributes['X'].value) / NewFact))
+ y = int(round(float(vertex.attributes['Y'].value) / NewFact))
+ xy[regionID].append((x,y))
+ #print(vertex.attributes['X'].value, vertex.attributes['Y'].value, x, y )
+
+ elif NegativeROA=="1":
+ xy_neg[regionID] = []
+ for vertex in vertices:
+ # get the x value of the vertex / convert them into index in the tiled matrix of the base image
+ # x = int(round(float(vertex.attributes['X'].value) / ImgMaxSizeX_orig * (cols*Img_Fact)))
+ # y = int(round(float(vertex.attributes['Y'].value) / ImgMaxSizeY_orig * (rows*Img_Fact)))
+ x = int(round(float(vertex.attributes['X'].value) / NewFact))
+ y = int(round(float(vertex.attributes['Y'].value) / NewFact))
+ xy_neg[regionID].append((x,y))
+
+
+ #xy_a = np.array(xy[regionID])
+
+ # print("%d xy" % len(xy))
+ #print(xy)
+ # print("%d xy_neg" % len(xy_neg))
+ #print(xy_neg)
+ # print("Img_Fact:")
+ # print(NewFact)
+ # img = Image.new('L', (int(cols*Img_Fact), int(rows*Img_Fact)), 0)
+ img = Image.new('L', (int(ImgMaxSizeX_orig/NewFact), int(ImgMaxSizeY_orig/NewFact)), 0)
+ for regionID in xy.keys():
+ xy_a = xy[regionID]
+ ImageDraw.Draw(img,'L').polygon(xy_a, outline=255, fill=255)
+ for regionID in xy_neg.keys():
+ xy_a = xy_neg[regionID]
+ ImageDraw.Draw(img,'L').polygon(xy_a, outline=255, fill=0)
+ #img = img.resize((cols,rows), Image.ANTIALIAS)
+ mask = np.array(img)
+ #print(mask.shape)
+ if Attribute_Name == "non_selected_regions":
+ # scipy.misc.toimage(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg"))
+ Image.fromarray(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg"))
+ else:
+ if self._mask_type==0:
+ # scipy.misc.toimage(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + "_inv.jpeg"))
+ Image.fromarray(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + "_inv.jpeg"))
+ else:
+ # scipy.misc.toimage(mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg"))
+ Image.fromarray(mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg"))
+ #print(mask)
+ return mask / 255.0, xml_valid, NewFact
+ # Img_Fact
+
+
+class DeepZoomStaticTiler(object):
+ """Handles generation of tiles and metadata for all images in a slide."""
+
+ def __init__(self, slidepath, basename, format, tile_size, overlap,
+ limit_bounds, quality, workers, with_viewer, Bkg, basenameJPG, xmlfile, mask_type, ROIpc, oLabel, ImgExtension, SaveMasks, Mag, normalize, Fieldxml):
+ if with_viewer:
+ # Check extra dependency before doing a bunch of work
+ import jinja2
+ #print("line226 - %s " % (slidepath) )
+ self._slide = open_slide(slidepath)
+ self._basename = basename
+ self._basenameJPG = basenameJPG
+ self._xmlfile = xmlfile
+ self._mask_type = mask_type
+ self._format = format
+ self._tile_size = tile_size
+ self._overlap = overlap
+ self._limit_bounds = limit_bounds
+ self._queue = JoinableQueue(2 * workers)
+ self._workers = workers
+ self._with_viewer = with_viewer
+ self._Bkg = Bkg
+ self._ROIpc = ROIpc
+ self._dzi_data = {}
+ self._xmlLabel = oLabel
+ self._ImgExtension = ImgExtension
+ self._SaveMasks = SaveMasks
+ self._Mag = Mag
+ self._normalize = normalize
+ self._Fieldxml = Fieldxml
+
+ for _i in range(workers):
+ TileWorker(self._queue, slidepath, tile_size, overlap,
+ limit_bounds, quality, self._Bkg, self._ROIpc).start()
+
+ def run(self):
+ self._run_image()
+ if self._with_viewer:
+ for name in self._slide.associated_images:
+ self._run_image(name)
+ self._write_html()
+ self._write_static()
+ self._shutdown()
+
+ def _run_image(self, associated=None):
+ """Run a single image from self._slide."""
+ if associated is None:
+ image = self._slide
+ if self._with_viewer:
+ basename = os.path.join(self._basename, VIEWER_SLIDE_NAME)
+ else:
+ basename = self._basename
+ else:
+ image = ImageSlide(self._slide.associated_images[associated])
+ basename = os.path.join(self._basename, self._slugify(associated))
+ # print("enter DeepZoomGenerator")
+ dz = DeepZoomGenerator(image, self._tile_size, self._overlap,limit_bounds=self._limit_bounds)
+ # print("enter DeepZoomImageTiler")
+ tiler = DeepZoomImageTiler(dz, basename, self._format, associated,self._queue, self._slide, self._basenameJPG, self._xmlfile, self._mask_type, self._xmlLabel, self._ROIpc, self._ImgExtension, self._SaveMasks, self._Mag, self._normalize, self._Fieldxml)
+ tiler.run()
+ self._dzi_data[self._url_for(associated)] = tiler.get_dzi()
+
+
+
+ def _url_for(self, associated):
+ if associated is None:
+ base = VIEWER_SLIDE_NAME
+ else:
+ base = self._slugify(associated)
+ return '%s.dzi' % base
+
+ def _write_html(self):
+ import jinja2
+ env = jinja2.Environment(loader=jinja2.PackageLoader(__name__),autoescape=True)
+ template = env.get_template('slide-multipane.html')
+ associated_urls = dict((n, self._url_for(n))
+ for n in self._slide.associated_images)
+ try:
+ mpp_x = self._slide.properties[openslide.PROPERTY_NAME_MPP_X]
+ mpp_y = self._slide.properties[openslide.PROPERTY_NAME_MPP_Y]
+ mpp = (float(mpp_x) + float(mpp_y)) / 2
+ except (KeyError, ValueError):
+ mpp = 0
+ # Embed the dzi metadata in the HTML to work around Chrome's
+ # refusal to allow XmlHttpRequest from file:///, even when
+ # the originating page is also a file:///
+ data = template.render(slide_url=self._url_for(None),slide_mpp=mpp,associated=associated_urls, properties=self._slide.properties, dzi_data=json.dumps(self._dzi_data))
+ with open(os.path.join(self._basename, 'index.html'), 'w') as fh:
+ fh.write(data)
+
+ def _write_static(self):
+ basesrc = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+ 'static')
+ basedst = os.path.join(self._basename, 'static')
+ self._copydir(basesrc, basedst)
+ self._copydir(os.path.join(basesrc, 'images'),
+ os.path.join(basedst, 'images'))
+
+ def _copydir(self, src, dest):
+ if not os.path.exists(dest):
+ os.makedirs(dest)
+ for name in os.listdir(src):
+ srcpath = os.path.join(src, name)
+ if os.path.isfile(srcpath):
+ shutil.copy(srcpath, os.path.join(dest, name))
+
+ @classmethod
+ def _slugify(cls, text):
+ text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode()
+ return re.sub('[^a-z0-9]+', '_', text)
+
+ def _shutdown(self):
+ for _i in range(self._workers):
+ self._queue.put(None)
+ self._queue.join()
+
+
+
+def ImgWorker(queue):
+ # print("ImgWorker started")
+ while True:
+ cmd = queue.get()
+ if cmd is None:
+ queue.task_done()
+ break
+ # print("Execute: %s" % (cmd))
+ subprocess.Popen(cmd, shell=True).wait()
+ queue.task_done()
+
+def xml_read_labels(xmldir, Fieldxml):
+ try:
+ xmlcontent = minidom.parse(xmldir)
+ xml_valid = True
+ except:
+ xml_valid = False
+ print("error with minidom.parse(xmldir)")
+ return [], xml_valid
+ labeltag = xmlcontent.getElementsByTagName('Attribute')
+ xml_labels = []
+ for xmllabel in labeltag:
+ xml_labels.append(xmllabel.attributes[Fieldxml].value)
+ #xml_labels.append(xmllabel.attributes['Name'].value)
+ # xml_labels.append(xmllabel.attributes['Value'].value)
+ if xml_labels==[]:
+ xml_labels = ['']
+ # print(xml_labels)
+ return xml_labels, xml_valid
+
+
+if __name__ == '__main__':
+ parser = OptionParser(usage='Usage: %prog [options] ')
+
+ parser.add_option('-L', '--ignore-bounds', dest='limit_bounds',
+ default=True, action='store_false',
+ help='display entire scan area')
+ parser.add_option('-e', '--overlap', metavar='PIXELS', dest='overlap',
+ type='int', default=1,
+ help='overlap of adjacent tiles [1]')
+ parser.add_option('-f', '--format', metavar='{jpeg|png}', dest='format',
+ default='jpeg',
+ help='image format for tiles [jpeg]')
+ parser.add_option('-j', '--jobs', metavar='COUNT', dest='workers',
+ type='int', default=4,
+ help='number of worker processes to start [4]')
+ parser.add_option('-o', '--output', metavar='NAME', dest='basename',
+ help='base name of output file')
+ parser.add_option('-Q', '--quality', metavar='QUALITY', dest='quality',
+ type='int', default=90,
+ help='JPEG compression quality [90]')
+ parser.add_option('-r', '--viewer', dest='with_viewer',
+ action='store_true',
+ help='generate directory tree with HTML viewer')
+ parser.add_option('-s', '--size', metavar='PIXELS', dest='tile_size',
+ type='int', default=254,
+ help='tile size [254]')
+ parser.add_option('-B', '--Background', metavar='PIXELS', dest='Bkg',
+ type='float', default=50,
+ help='Max background threshold [50]; percentage of background allowed')
+ parser.add_option('-x', '--xmlfile', metavar='NAME', dest='xmlfile',
+ help='xml file if needed')
+ parser.add_option('-F', '--Fieldxml', metavar='{Name|Value}', dest='Fieldxml',
+ default='Value',
+ help='which field of the xml file the label is saved in')
+ parser.add_option('-m', '--mask_type', metavar='COUNT', dest='mask_type',
+ type='int', default=1,
+ help='if xml file is used, keep tiles within the ROI (1) or outside of it (0)')
+ parser.add_option('-R', '--ROIpc', metavar='PIXELS', dest='ROIpc',
+ type='float', default=50,
+ help='To be used with xml file - minimum percentage of tile covered by ROI (white)')
+ parser.add_option('-l', '--oLabelref', metavar='NAME', dest='oLabelref',
+ help='To be used with xml file - Only tile for label which contains the characters in oLabel')
+ parser.add_option('-S', '--SaveMasks', metavar='NAME', dest='SaveMasks',
+ default=False,
+ help='set to yes if you want to save ALL masks for ALL tiles (will be saved in same directory with suffix)')
+ parser.add_option('-t', '--tmp_dcm', metavar='NAME', dest='tmp_dcm',
+ help='base name of output folder to save intermediate dcm images converted to jpg (we assume the patient ID is the folder name in which the dcm images are originally saved)')
+ parser.add_option('-M', '--Mag', metavar='PIXELS', dest='Mag',
+ type='float', default=-1,
+ help='Magnification at which tiling should be done (-1 for all)')
+ parser.add_option('-N', '--normalize', metavar='NAME', dest='normalize',
+ help='if normalization is needed, N lists the mean and std for each channel. For example \'57,22,-8,20,10,5\' with the first 3 numbers being the targeted means, and then the targeted stds')
+
+
+
+
+ (opts, args) = parser.parse_args()
+
+
+ try:
+ slidepath = args[0]
+ except IndexError:
+ parser.error('Missing slide argument')
+ if opts.basename is None:
+ opts.basename = os.path.splitext(os.path.basename(slidepath))[0]
+ if opts.xmlfile is None:
+ opts.xmlfile = ''
+
+ try:
+ if opts.normalize is not None:
+ opts.normalize = [float(x) for x in opts.normalize.split(',')]
+ if len(opts.normalize) != 6:
+ opts.normalize = ''
+ parser.error("ERROR: NO NORMALIZATION APPLIED: input vector does not have the right length - 6 values expected")
+ else:
+ opts.normalize = ''
+
+ except:
+ opts.normalize = ''
+ parser.error("ERROR: NO NORMALIZATION APPLIED: input vector does not have the right format")
+ #if ss != '':
+ # if os.path.isdir(opts.xmlfile):
+
+
+ # Initialization
+ # imgExample = "/ifs/home/coudrn01/NN/Lung/RawImages/*/*svs"
+ # tile_size = 512
+ # max_number_processes = 10
+ # NbrCPU = 4
+
+ # get images from the data/ file.
+
+ files = glob(slidepath)
+ #ImgExtension = os.path.splitext(slidepath)[1]
+ ImgExtension = slidepath.split('*')[-1]
+ #files
+ #len(files)
+ # print(args)
+ # print(args[0])
+ # print(slidepath)
+ # print(files)
+ # print("***********************")
+
+ '''
+ dz_queue = JoinableQueue()
+ procs = []
+ print("Nb of processes:")
+ print(opts.max_number_processes)
+ for i in range(opts.max_number_processes):
+ p = Process(target = ImgWorker, args = (dz_queue,))
+ #p.deamon = True
+ p.setDaemon = True
+ p.start()
+ procs.append(p)
+ '''
+ files = sorted(files)
+ print(len(files), ' to process')
+ import time
+ time.sleep(5)
+ for imgNb in tqdm(range(len(files))):
+ filename = files[imgNb]
+ #print(filename)
+ opts.basenameJPG = os.path.splitext(os.path.basename(filename))[0]
+ #print("processing: " + opts.basenameJPG + " with extension: " + ImgExtension)
+ #opts.basenameJPG = os.path.splitext(os.path.basename(slidepath))[0]
+ #if os.path.isdir("%s_files" % (basename)):
+ # print("EXISTS")
+ #else:
+ # print("Not Found")
+
+ if ("dcm" in ImgExtension) :
+ print("convert %s dcm to jpg" % filename)
+ if opts.tmp_dcm is None:
+ parser.error('Missing output folder for dcm>jpg intermediate files')
+ elif not os.path.isdir(opts.tmp_dcm):
+ parser.error('Missing output folder for dcm>jpg intermediate files')
+
+ if filename[-3:] == 'jpg':
+ continue
+ ImageFile=dicom.read_file(filename)
+ im1 = ImageFile.pixel_array
+ maxVal = float(im1.max())
+ minVal = float(im1.min())
+ height = im1.shape[0]
+ width = im1.shape[1]
+ image = np.zeros((height,width,3), 'uint8')
+ image[...,0] = ((im1[:,:].astype(float) - minVal) / (maxVal - minVal) * 255.0).astype(int)
+ image[...,1] = ((im1[:,:].astype(float) - minVal) / (maxVal - minVal) * 255.0).astype(int)
+ image[...,2] = ((im1[:,:].astype(float) - minVal) / (maxVal - minVal) * 255.0).astype(int)
+ # dcm_ID = os.path.basename(os.path.dirname(filename))
+ # opts.basenameJPG = dcm_ID + "_" + opts.basenameJPG
+ filename = os.path.join(opts.tmp_dcm, opts.basenameJPG + ".jpg")
+ # print(filename)
+ imsave(filename,image)
+
+ output = os.path.join(opts.basename, opts.basenameJPG)
+
+ try:
+ DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+ except Exception as e:
+ print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0]))
+ print(e)
+
+ #elif ("jpg" in ImgExtension) :
+ # output = os.path.join(opts.basename, opts.basenameJPG)
+ # if os.path.exists(output + "_files"):
+ # print("Image %s already tiled" % opts.basenameJPG)
+ # continue
+
+ # DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+
+ elif opts.xmlfile != '':
+ xmldir = os.path.join(opts.xmlfile, opts.basenameJPG + '.xml')
+ # print("xml:")
+ # print(xmldir)
+ if os.path.isfile(xmldir):
+ if (opts.mask_type==1) or (opts.oLabelref!=''):
+ # either mask inside ROI, or mask outside but a reference label exist
+ xml_labels, xml_valid = xml_read_labels(xmldir, opts.Fieldxml)
+ if (opts.mask_type==1):
+ # No inverse mask
+ Nbr_ROIs_ForNegLabel = 1
+ elif (opts.oLabelref!=''):
+ # Inverse mask and a label reference exist
+ Nbr_ROIs_ForNegLabel = 0
+
+ for oLabel in xml_labels:
+ # print("label is %s and ref is %s" % (oLabel, opts.oLabelref))
+ if (opts.oLabelref in oLabel) or (opts.oLabelref==''):
+ # is a label is identified
+ if (opts.mask_type==0):
+ # Inverse mask and label exist in the image
+ Nbr_ROIs_ForNegLabel += 1
+ # there is a label, and map is to be inverted
+ output = os.path.join(opts.basename, oLabel+'_inv', opts.basenameJPG)
+ if not os.path.exists(os.path.join(opts.basename, oLabel+'_inv')):
+ os.makedirs(os.path.join(opts.basename, oLabel+'_inv'))
+ else:
+ Nbr_ROIs_ForNegLabel += 1
+ output = os.path.join(opts.basename, oLabel, opts.basenameJPG)
+ if not os.path.exists(os.path.join(opts.basename, oLabel)):
+ os.makedirs(os.path.join(opts.basename, oLabel))
+ if 1:
+ #try:
+ DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, oLabel, ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+ #except:
+ # print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0]))
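+# Usage: source this file (e.g. `source ./set_env.sh`) so the variables below are exported
+# in the shell before running the tiling / inference scripts, which read them from the environment.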
+ if Nbr_ROIs_ForNegLabel==0:
+ print("label %s is not in that image; invert everything" % (opts.oLabelref))
+ # a label ref was given, and inverse mask is required but no ROI with this label in that map --> take everything
+ oLabel = opts.oLabelref
+ output = os.path.join(opts.basename, opts.oLabelref+'_inv', opts.basenameJPG)
+ if not os.path.exists(os.path.join(opts.basename, oLabel+'_inv')):
+ os.makedirs(os.path.join(opts.basename, oLabel+'_inv'))
+ if 1:
+ #try:
+ DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, oLabel, ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+ #except:
+ # print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0]))
+
+ else:
+ # Background
+ oLabel = "non_selected_regions"
+ output = os.path.join(opts.basename, oLabel, opts.basenameJPG)
+ if not os.path.exists(os.path.join(opts.basename, oLabel)):
+ os.makedirs(os.path.join(opts.basename, oLabel))
+ try:
+ DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, oLabel, ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+ except Exception as e:
+ print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0]))
+ print(e)
+
+ else:
+ if (ImgExtension == ".jpg") | (ImgExtension == ".dcm") :
+ print("Input image to be tiled is jpg or dcm and not svs - will be treated as such")
+ output = os.path.join(opts.basename, opts.basenameJPG)
+ try:
+ DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+ except Exception as e:
+ print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0]))
+ print(e)
+
+
+ else:
+ print("No xml file found for slide %s.svs (expected: %s). Directory or xml file does not exist" % (opts.basenameJPG, xmldir) )
+ continue
+ else:
+ output = os.path.join(opts.basename, opts.basenameJPG)
+ if os.path.exists(output + "_files"):
+ print("Image %s already tiled" % opts.basenameJPG)
+ continue
+ try:
+ #if True:
+ DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run()
+ except Exception as e:
+ print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0]))
+ print(e)
+ '''
+ dz_queue.join()
+ for i in range(opts.max_number_processes):
+ dz_queue.put( None )
+ '''
+
+ print("End")
diff --git a/src/vis_graphcam.py b/src/vis_graphcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..dec1c01aff96ddb7c1f03761b4b31cb1a02cec4c
--- /dev/null
+++ b/src/vis_graphcam.py
@@ -0,0 +1,210 @@
+from PIL import Image
+from matplotlib.pyplot import imshow, show
+import matplotlib.pyplot as plt
+from torchvision import models, transforms
+from torch.autograd import Variable
+from torch.nn import functional as F
+import torch
+import torch.nn as nn
+from torch import topk
+import numpy as np
+import os
+import skimage.transform
+import cv2
+import math
+import openslide
+import argparse
+import pickle
+
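+# show_cam_on_image: overlay a JET colour map of the (0-1) CAM mask on the RGB image and
+# renormalise so the blended result stays in [0, 1].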
+def show_cam_on_image(img, mask):
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+ heatmap = np.float32(heatmap) / 255
+ cam = heatmap + np.float32(img)
+ cam = cam / np.max(cam)
+ return cam
+
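+# cam_to_mask: paint each patch's CAM value into an overview-sized mask (same shape as the
+# grayscale thumbnail), using the patch grid coordinates encoded in the patch file names
+# ('x_y.jpeg') and the per-patch block size (w_s, h_s) in the downscaled overview image.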
+def cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s):
+ mask = np.full_like(gray, 0.).astype(np.float32)
+ for ind1, patch in enumerate(patches):
+ x, y = patch.split('.')[0].split('_')
+ x, y = int(x), int(y)
+ #if y <5 or x>w-5 or y>h-5:
+ # continue
+ mask[int(y*h_s):int((y+1)*h_s), int(x*w_s):int((x+1)*w_s)].fill(cam_matrix[ind1][0])
+
+ return mask
+
+def main(args):
+ label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb'))
+
+ label_name_from_id = dict()
+ for label_name, label_id in label_map.items():
+ label_name_from_id[label_id] = label_name
+
+ n_class = len(label_map)#args.n_class
+ file_name, label = open(args.path_file, 'r').readlines()[-1].split('\t')
+ label = label.rstrip().strip()
+ #site, file_name = file_name.split('/')
+ file_path = os.path.join(args.path_patches, '{}_files/20.0/'.format(file_name))
+ print(file_name)
+ print(label)
+
+ p = torch.load('graphcam/prob.pt').cpu().detach().numpy()[0]
+ file_path = os.path.join(args.path_patches, '{}_files/20.0/'.format(file_name))
+ #ori = openslide.OpenSlide(os.path.join(args.path_WSI, '{}.svs').format(file_name))
+ ORIGINAL_FILEPATH = os.path.join(args.path_WSI,'TCGA',label, '{}.svs'.format(file_name))
+ print('L', ORIGINAL_FILEPATH)
+ ori = openslide.OpenSlide(ORIGINAL_FILEPATH)
+ patch_info = open(os.path.join(args.path_graph, file_name, 'c_idx.txt'), 'r')
+
+ width, height = ori.dimensions
+
+ REDUCTION_FACTOR = 10
+ w, h = int(width/512), int(height/512)
+ w_r, h_r = int(width/20), int(height/20)
+ resized_img = ori.get_thumbnail((width,height))#ori.get_thumbnail((w_r,h_r))
+ resized_img = resized_img.resize((w_r,h_r))
+ ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
+ print('ratios ', ratio_w, ratio_h)
+ w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
+ print(w_s, h_s)
+
+ patch_info = patch_info.readlines()
+ patches = []
+ xmax, ymax = 0, 0
+ for patch in patch_info:
+ x, y = patch.strip('\n').split('\t')
+ if xmax < int(x): xmax = int(x)
+ if ymax < int(y): ymax = int(y)
+ patches.append('{}_{}.jpeg'.format(x,y))
+
+ output_img = np.asarray(resized_img)[:,:,::-1].copy()
+ #-----------------------------------------------------------------------------------------------------#
+ # GraphCAM
+ print('visualize GraphCAM')
+ assign_matrix = torch.load('graphcam/s_matrix_ori.pt')
+ m = nn.Softmax(dim=1)
+ assign_matrix = m(assign_matrix)
+
+ # Thresholding for better visualization
+ p = np.clip(p, 0.4, 1)
+
+
+
+ output_img_copy =np.copy(output_img)
+ gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min())
+ cam_matrices = []
+ masks = []
+ visualizations = []
+ print(len(patches))
+ os.makedirs('graphcam_vis', exist_ok=True)
+ for class_i in range(n_class):
+
+ # Load graphcam for each class
+ cam_matrix = torch.load(f'graphcam/cam_{class_i}.pt')
+ print(cam_matrix.shape)
+ cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1,0))
+ cam_matrix = cam_matrix.cpu()
+ print(assign_matrix.shape)
+ print(cam_matrix.shape)
+ # Normalize the graphcam
+ cam_matrix = (cam_matrix - cam_matrix.min()) / (cam_matrix.max() - cam_matrix.min())
+ cam_matrix = cam_matrix.detach().numpy()
+ cam_matrix = p[class_i] * cam_matrix
+ cam_matrix = np.clip(cam_matrix, 0, 1)
+ print(cam_matrix.shape)
+ #print()
+
+
+ mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s)
+ print('mask shape ', mask.shape)
+ print('imgtf attr ', image_transformer_attribution.shape)
+ vis = show_cam_on_image(image_transformer_attribution, mask)
+ vis = np.uint8(255 * vis)
+
+ cam_matrices.append(cam_matrix)
+ masks.append(mask)
+ visualizations.append(vis)
+ print()
+ cv2.imwrite('graphcam_vis/{}_all_types_cam_{}.png'.format(file_name, label_name_from_id[class_i] ), vis)
+ h, w, _ = output_img.shape
+ if h > w:
+ vis_merge = cv2.hconcat([output_img] + visualizations)
+ else:
+ vis_merge = cv2.vconcat([output_img] + visualizations)
+
+
+ cv2.imwrite('graphcam_vis/{}_all_types_cam_all.png'.format(file_name), vis_merge)
+ cv2.imwrite('graphcam_vis/{}_all_types_ori.png'.format(file_name ), output_img)
+
+ '''
+ # Load graphcam for differnet class
+ cam_matrix_0 = torch.load('graphcam/cam_0.pt')
+ cam_matrix_0 = torch.mm(assign_matrix, cam_matrix_0.transpose(1,0))
+ cam_matrix_0 = cam_matrix_0.cpu()
+ cam_matrix_1 = torch.load('graphcam/cam_1.pt')
+ cam_matrix_1 = torch.mm(assign_matrix, cam_matrix_1.transpose(1,0))
+ cam_matrix_1 = cam_matrix_1.cpu()
+ cam_matrix_2 = torch.load('graphcam/cam_2.pt')
+ cam_matrix_2 = torch.mm(assign_matrix, cam_matrix_2.transpose(1,0))
+ cam_matrix_2 = cam_matrix_2.cpu()
+
+ # Normalize the graphcam
+ cam_matrix_0 = (cam_matrix_0 - cam_matrix_0.min()) / (cam_matrix_0.max() - cam_matrix_0.min())
+ cam_matrix_0 = cam_matrix_0.detach().numpy()
+ cam_matrix_0 = p[0] * cam_matrix_0
+ cam_matrix_0 = np.clip(cam_matrix_0, 0, 1)
+ cam_matrix_1 = (cam_matrix_1 - cam_matrix_1.min()) / (cam_matrix_1.max() - cam_matrix_1.min())
+ cam_matrix_1 = cam_matrix_1.detach().numpy()
+ cam_matrix_1 = p[1] * cam_matrix_1
+ cam_matrix_1 = np.clip(cam_matrix_1, 0, 1)
+ cam_matrix_2 = (cam_matrix_2 - cam_matrix_2.min()) / (cam_matrix_2.max() - cam_matrix_2.min())
+ cam_matrix_2 = cam_matrix_2.detach().numpy()
+ cam_matrix_2 = p[2] * cam_matrix_2
+ cam_matrix_2 = np.clip(cam_matrix_2, 0, 1)
+
+ output_img_copy =np.copy(output_img)
+
+ gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min())
+
+ mask0 = cam_to_mask(gray, patches, cam_matrix_0, w, h, w_s, h_s)
+ vis0 = show_cam_on_image(image_transformer_attribution, mask0)
+ vis0 = np.uint8(255 * vis0)
+ mask1 = cam_to_mask(gray, patches, cam_matrix_1, w, h, w_s, h_s)
+ vis1 = show_cam_on_image(image_transformer_attribution, mask1)
+ vis1 = np.uint8(255 * vis1)
+ mask2 = cam_to_mask(gray, patches, cam_matrix_2, w, h, w_s, h_s)
+ vis2 = show_cam_on_image(image_transformer_attribution, mask2)
+ vis2 = np.uint8(255 * vis2)
+
+ ##########################################
+ h, w, _ = output_img.shape
+ if h > w:
+ vis_merge = cv2.hconcat([output_img, vis0, vis1, vis2])
+ else:
+ vis_merge = cv2.vconcat([output_img, vis0, vis1, vis2])
+
+ #cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_all.png'.format(file_name, site), vis_merge)
+
+ #cv2.imwrite('graphcam_vis/{}_{}_all_types_ori.png'.format(file_name, site), output_img)
+ #cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_luad.png'.format(file_name, site), vis1)
+ #cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_lscc.png'.format(file_name, site), vis2)
+ cv2.imwrite('graphcam_vis/{}_all_types_cam_all.png'.format(file_name, ), vis_merge)
+
+ cv2.imwrite('graphcam_vis/{}_all_types_ori.png'.format(file_name ), output_img)
+ cv2.imwrite('graphcam_vis/{}_all_types_cam_luad.png'.format(file_name ), vis1)
+ cv2.imwrite('graphcam_vis/{}_all_types_cam_lscc.png'.format(file_name ), vis2)
+
+ '''
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='GraphCAM')
+ parser.add_argument('--path_file', type=str, default='test.txt', help='txt file contains test sample')
+ parser.add_argument('--path_patches', type=str, default='', help='')
+ parser.add_argument('--path_WSI', type=str, default='', help='')
+ parser.add_argument('--path_graph', type=str, default='', help='')
+ parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/utils/__pycache__/dataset.cpython-38.pyc b/utils/__pycache__/dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b210956f1abfd54032c697a9462b73ddd18ea8f9
Binary files /dev/null and b/utils/__pycache__/dataset.cpython-38.pyc differ
diff --git a/utils/__pycache__/lr_scheduler.cpython-38.pyc b/utils/__pycache__/lr_scheduler.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82a76349d767f6eddb063baf8e248265b4c8fc75
Binary files /dev/null and b/utils/__pycache__/lr_scheduler.cpython-38.pyc differ
diff --git a/utils/__pycache__/metrics.cpython-38.pyc b/utils/__pycache__/metrics.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c91051c4f87985c42a6bf7b4e15eca2d30f50a44
Binary files /dev/null and b/utils/__pycache__/metrics.cpython-38.pyc differ
diff --git a/utils/dataset.py b/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..86766344717e223362c30caba1f70e5d8f99cef0
--- /dev/null
+++ b/utils/dataset.py
@@ -0,0 +1,147 @@
+import os
+import torch
+import torch.utils.data as data
+import numpy as np
+from PIL import Image, ImageFile
+import random
+from torchvision.transforms import ToTensor
+from torchvision import transforms
+import cv2
+import pickle
+import torch.nn.functional as F
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+def collate_features(batch):
+ img = torch.cat([item[0] for item in batch], dim = 0)
+ coords = np.vstack([item[1] for item in batch])
+ return [img, coords]
+
+def eval_transforms(pretrained=False):
+ if pretrained:
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+
+ else:
+ mean = (0.5,0.5,0.5)
+ std = (0.5,0.5,0.5)
+
+ trnsfrms_val = transforms.Compose(
+ [
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize(mean = mean, std = std)
+ ]
+ )
+
+ return trnsfrms_val
+
+class GraphDataset(data.Dataset):
+ """input and label image dataset"""
+
+ def __init__(self, root, ids, metadata_path, target_patch_size=-1):
+ super(GraphDataset, self).__init__()
+ """
+ Args:
+
+ fileDir(string): directory with all the input images.
+ transform(callable, optional): Optional transform to be applied on a sample
+ """
+ self.root = root
+ self.ids = ids
+ #self.target_patch_size = target_patch_size
+ self.classdict = pickle.load(open(os.path.join(metadata_path, 'label_map.pkl'), 'rb' )) # {'normal': 0, 'luad': 1, 'lscc': 2} #
+ #self.classdict = {'normal': 0, 'tumor': 1} #
+ #self.classdict = {'Normal': 0, 'TCGA-LUAD': 1, 'TCGA-LUSC': 2}
+ self._up_kwargs = {'mode': 'bilinear'}
+
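+ # __getitem__: each sample is a whole slide; load its precomputed patch features ('features.pt')
+ # and the patch adjacency matrix ('adj_s.pt') from <root>/simclr_files/<file_name>/, falling back
+ # to zero features / an all-ones adjacency when the files are missing.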
+ def __getitem__(self, index):
+ sample = {}
+ info = self.ids[index].replace('\n', '')
+ #file_name, label = info.split('\t')[0].rsplit('.', 1)[0], info.split('\t')[1]
+ file_name, label = info.split('\t')[0], info.split('\t')[1]
+
+
+ sample['label'] = self.classdict[label]
+ sample['id'] = file_name
+
+
+ file_path = os.path.join(self.root, 'simclr_files')
+ #feature_path = os.path.join(self.root, file_name, 'features.pt')
+ feature_path = os.path.join(file_path, file_name, 'features.pt')
+
+ if os.path.exists(feature_path):
+ features = torch.load(feature_path, map_location=lambda storage, loc: storage)
+ else:
+ print(feature_path + ' does not exist')
+ features = torch.zeros(1, 512)
+
+ #adj_s_path = os.path.join(self.root, file_name, 'adj_s.pt')
+ adj_s_path = os.path.join(file_path, file_name, 'adj_s.pt')
+ if os.path.exists(adj_s_path):
+ adj_s = torch.load(adj_s_path, map_location=lambda storage, loc: storage)
+ else:
+ print(adj_s_path + ' does not exist')
+ adj_s = torch.ones(features.shape[0], features.shape[0])
+
+ #features = features.unsqueeze(0)
+ sample['image'] = features
+ sample['adj_s'] = adj_s #adj_s.to(torch.double)
+ # return {'image': image.astype(np.float32), 'label': label.astype(np.int64)}
+
+ return sample
+
+
+ def __len__(self):
+ return len(self.ids)
+
+
+''' def __getitem__(self, index):
+ sample = {}
+ info = self.ids[index].replace('\n', '')
+ file_name, label = info.split('\t')[0].rsplit('.', 1)[0], info.split('\t')[1]
+ site, file_name = file_name.split('/')
+
+ # if site =='CCRCC':
+ # file_path = self.root + 'CPTAC_CCRCC_features/simclr_files'
+ if site =='LUAD' or site =='LSCC':
+ site = 'LUNG'
+ file_path = self.root + 'CPTAC_{}_features/simclr_files'.format(site) #_pre# with # rushin
+
+ # For NLST only
+ if site =='NLST':
+ file_path = self.root + 'NLST_Lung_features/simclr_files'
+
+ # For TCGA only
+ if site =='TCGA':
+ file_name = info.split('\t')[0]
+ _, file_name = file_name.split('/')
+ file_path = self.root + 'TCGA_LUNG_features/simclr_files' #_resnet_with
+
+ sample['label'] = self.classdict[label]
+ sample['id'] = file_name
+
+ #feature_path = os.path.join(self.root, file_name, 'features.pt')
+ feature_path = os.path.join(file_path, file_name, 'features.pt')
+
+ if os.path.exists(feature_path):
+ features = torch.load(feature_path, map_location=lambda storage, loc: storage)
+ else:
+ print(feature_path + ' not exists')
+ features = torch.zeros(1, 512)
+
+ #adj_s_path = os.path.join(self.root, file_name, 'adj_s.pt')
+ adj_s_path = os.path.join(file_path, file_name, 'adj_s.pt')
+ if os.path.exists(adj_s_path):
+ adj_s = torch.load(adj_s_path, map_location=lambda storage, loc: storage)
+ else:
+ print(adj_s_path + ' not exists')
+ adj_s = torch.ones(features.shape[0], features.shape[0])
+
+ #features = features.unsqueeze(0)
+ sample['image'] = features
+ sample['adj_s'] = adj_s #adj_s.to(torch.double)
+ # return {'image': image.astype(np.float32), 'label': label.astype(np.int64)}
+
+ return sample
+'''
\ No newline at end of file
diff --git a/utils/lr_scheduler.py b/utils/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5913949bdddf15602e55f82985c0a6b7b6656e23
--- /dev/null
+++ b/utils/lr_scheduler.py
@@ -0,0 +1,71 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## ECE Department, Rutgers University
+## Email: zhang.hang@rutgers.edu
+## Copyright (c) 2017
+##
+## This source code is licensed under the MIT-style license found in the
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import math
+
+class LR_Scheduler(object):
+ """Learning Rate Scheduler
+
+ Step mode: ``lr = baselr * 0.1 ^ floor(epoch / lr_step)``
+
+ Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter / maxiter))``
+
+ Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
+
+ Args:
+ mode: lr scheduler mode, one of ``'cos'``, ``'poly'`` or ``'step'``
+ base_lr: base learning rate
+ num_epochs: total number of training epochs
+ lr_step: epoch interval for the decay in ``'step'`` mode
+ warmup_epochs: number of epochs of linear learning-rate warmup
+ iters_per_epoch: number of iterations per epoch
+ """
+ def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
+ lr_step=0, warmup_epochs=0):
+ self.mode = mode
+ print('Using {} LR Scheduler!'.format(self.mode))
+ self.lr = base_lr
+ if mode == 'step':
+ assert lr_step
+ self.lr_step = lr_step
+ self.iters_per_epoch = iters_per_epoch
+ self.N = num_epochs * iters_per_epoch
+ self.epoch = -1
+ self.warmup_iters = warmup_epochs * iters_per_epoch
+
+ def __call__(self, optimizer, i, epoch, best_pred):
+ T = epoch * self.iters_per_epoch + i
+ if self.mode == 'cos':
+ lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
+ elif self.mode == 'poly':
+ lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
+ elif self.mode == 'step':
+ lr = self.lr * (0.1 ** (epoch // self.lr_step))
+ else:
+ raise NotImplementedError('unknown lr scheduler mode: {}'.format(self.mode))
+ # warm up lr schedule
+ if self.warmup_iters > 0 and T < self.warmup_iters:
+ lr = lr * 1.0 * T / self.warmup_iters
+ if epoch > self.epoch:
+ print('\n=> Epoch %i, learning rate = %.7f, previous best = %.4f' % (epoch + 1, lr, best_pred))
+ self.epoch = epoch
+ assert lr >= 0
+ self._adjust_learning_rate(optimizer, lr)
+
+ def _adjust_learning_rate(self, optimizer, lr):
+ if len(optimizer.param_groups) == 1:
+ optimizer.param_groups[0]['lr'] = lr
+ else:
+ # apply the same lr to every parameter group that currently has a positive lr
+ for i in range(len(optimizer.param_groups)):
+ if optimizer.param_groups[i]['lr'] > 0:
+ optimizer.param_groups[i]['lr'] = lr
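+
+# Minimal usage sketch (optimizer, train_loader and best_pred are illustrative names,
+# not taken from this repository):
+# scheduler = LR_Scheduler('poly', base_lr=1e-3, num_epochs=100, iters_per_epoch=len(train_loader))
+# for epoch in range(100):
+# for i, batch in enumerate(train_loader):
+# scheduler(optimizer, i, epoch, best_pred) # sets the lr on optimizer.param_groups
+# ... forward / backward ...
+# optimizer.step()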
diff --git a/utils/metrics.py b/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7799650cc535052f0f9753cd8317f2e740830e
--- /dev/null
+++ b/utils/metrics.py
@@ -0,0 +1,47 @@
+# Adapted from score written by wkentaro
+# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
+
+import numpy as np
+
+class ConfusionMatrix(object):
+
+ def __init__(self, n_classes):
+ self.n_classes = n_classes
+ # axis = 0: prediction
+ # axis = 1: target
+ self.confusion_matrix = np.zeros((n_classes, n_classes))
+
+ def _fast_hist(self, label_true, label_pred, n_class):
+ hist = np.zeros((n_class, n_class))
+ hist[label_pred, label_true] += 1
+
+ return hist
+
+ def update(self, label_trues, label_preds):
+ for lt, lp in zip(label_trues, label_preds):
+ tmp = self._fast_hist(lt.item(), lp.item(), self.n_classes) #lt.item(), lp.item()
+ self.confusion_matrix += tmp
+
+ def get_scores(self):
+ """Returns the overall accuracy: correctly classified samples (the diagonal of
+ the confusion matrix) divided by the total number of samples seen so far."""
+ hist = self.confusion_matrix
+
+ if sum(hist.sum(axis=1)) != 0:
+ acc = sum(np.diag(hist)) / sum(hist.sum(axis=1))
+ else:
+ acc = 0.0
+
+ return acc
+
+ def plotcm(self):
+ print(self.confusion_matrix)
+
+ def reset(self):
+ self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
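+
+# Minimal usage sketch (preds/targets are illustrative 1-D tensors of class indices):
+# cm = ConfusionMatrix(n_classes=3)
+# cm.update(targets, preds)
+# print(cm.get_scores()) # overall accuracy
+# cm.plotcm() # print the raw confusion matrix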
\ No newline at end of file
diff --git a/weights/feature_extractor/config.yaml b/weights/feature_extractor/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c8f4309e6cbefa7270b1beb7c639d9551b325a8
--- /dev/null
+++ b/weights/feature_extractor/config.yaml
@@ -0,0 +1,23 @@
+batch_size: 256
+epochs: 20
+eval_every_n_epochs: 1
+fine_tune_from: ''
+log_every_n_steps: 25
+weight_decay: 10e-6
+fp16_precision: False
+n_gpu: 2
+gpu_ids: (0,1)
+
+model:
+ out_dim: 512
+ base_model: "resnet18"
+
+dataset:
+ s: 1
+ input_shape: (224,224,3)
+ num_workers: 10
+ valid_size: 0.1
+
+loss:
+ temperature: 0.5
+ use_cosine_similarity: True
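+
+# Usage note (an assumption about how this config is read, not verified against the
+# training code): values are typically loaded with yaml.safe_load(open('config.yaml'))
+# and accessed as e.g. config['model']['base_model'] or config['loss']['temperature'];
+# tuple-like entries such as input_shape arrive as plain strings and must be parsed by
+# the consumer.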
diff --git a/weights/feature_extractor/model.pth b/weights/feature_extractor/model.pth
new file mode 100644
index 0000000000000000000000000000000000000000..86030da63ff1604b29d80d2bbed137d96526439d
--- /dev/null
+++ b/weights/feature_extractor/model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c156857743ee3bd7b353fe8d34d2250e4c153e930c457b929ec461cafcd15fe4
+size 46779101
diff --git a/weights/graph_transformer/GraphCAM.pth b/weights/graph_transformer/GraphCAM.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7593168e87fc9e8b7adc62c297c80e2f190b768f
--- /dev/null
+++ b/weights/graph_transformer/GraphCAM.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ab59c0a07d8a566f22ccece1d0ba4f05271be0c3927c5362bb4b7e220c432cb
+size 577432