File size: 3,764 Bytes
5b24075
 
 
 
3d2072b
5b24075
 
 
63ad23b
5b24075
63ad23b
5b24075
63ad23b
 
5b24075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249c808
5b24075
 
 
 
 
 
 
 
 
 
3d2072b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b24075
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
import leafmap.foliumap as leafmap
from transformers import PretrainedConfig
from folium import Icon
import os

from messis.messis import Messis
from inference import perform_inference
from inference import generate_presigned_url

from dotenv import load_dotenv

# Populate os.environ from a local .env file (variables that are already set are
# not overridden), then configure the page to use the full browser width.
load_dotenv(override=False)
st.set_page_config(layout="wide")

# Hugging Face Hub location of the Messis checkpoint. The revision is pinned and
# shared by the config and the weights so both always come from the same commit,
# and both are cached in the same local directory.
_MODEL_REPO = 'crop-classification/messis'
_MODEL_REVISION = '47d9ca4'
_HF_CACHE_DIR = './hf_cache/'


# Load the model
@st.cache_resource
def load_model():
    """Load the pinned Messis model and its configuration.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects on
    subsequent script reruns instead of loading them again each time.

    Returns:
        A ``(model, config)`` tuple: the ``Messis`` model and the
        ``PretrainedConfig`` loaded from the same repo and revision.
    """
    # Use the same cache_dir for the config as for the weights; previously the
    # config went to the default HF cache while the model went to ./hf_cache/.
    config = PretrainedConfig.from_pretrained(
        _MODEL_REPO, cache_dir=_HF_CACHE_DIR, revision=_MODEL_REVISION
    )
    model = Messis.from_pretrained(
        _MODEL_REPO, cache_dir=_HF_CACHE_DIR, revision=_MODEL_REVISION
    )
    return model, config


model, config = load_model()

# Layout of the stacked GeoTIFF as assumed by the band-index arithmetic below:
# each timestep occupies this many consecutive bands, so (1-based) band b of
# timestep t is stored at index b + (t - 1) * _BANDS_PER_TIMESTEP.
_BANDS_PER_TIMESTEP = 6
_NUM_TIMESTEPS = 9
_DEFAULT_TIMESTEP = 5

# 1-based band indices within a single timestep for each display option.
# Adjust indices based on the actual bands in your GeoTIFF.
_BAND_OPTIONS = {
    "RGB": [1, 2, 3],
    "NIR": [4],
    "SWIR1": [5],
    "SWIR2": [6],
}

# (min, max) display stretch applied to the raw pixel values of each option.
_BAND_VALUE_RANGE = {
    "RGB": (89, 1878),
    "NIR": (165, 5468),
    "SWIR1": (120, 3361),
    "SWIR2": (94, 2700),
}

# Session-state key holding ((lat, lon), predictions) from the last inference run.
_PREDICTIONS_STATE_KEY = "inference_predictions"


def _add_predictions_layer(m, predictions):
    """Overlay predictions: green where the "Correct" property is truthy, else red."""
    m.add_data(
        predictions,
        layer_name="Predictions",
        column="Correct",
        add_legend=False,
        style_function=lambda feature: {
            "fillColor": "green" if feature["properties"]["Correct"] else "red",
            "color": "black",
            "weight": 0,
            "fillOpacity": 0.25,
        },
    )


def _add_satellite_layer(m, bands, value_range):
    """Add the stacked imagery COG, from ./data when USE_LOCAL_DATA == "True", else via a presigned URL."""
    vmin, vmax = value_range
    if os.environ.get("USE_LOCAL_DATA") == "True":
        m.add_raster(
            "./data/stacked_features_cog.tif",
            layer_name="Satellite Image",
            bands=bands,
            fit_bounds=True,
            vmin=vmin,
            vmax=vmax,
        )
    else:
        presigned_url = generate_presigned_url('messis-demo', 'stacked_features_cog.tif')
        m.add_cog_layer(
            url=presigned_url,
            name="Sentinel-2 Satellite Imagery",
            bands=bands,
            rescale=f"{vmin},{vmax}",
        )


def perform_inference_step():
    """Render the "Perform Crop Classification" page.

    Requires ``st.session_state["selected_location"]`` (a ``(lat, lon)`` pair
    set by the location-selection page); otherwise shows an error plus a link
    back to that page and returns. Builds sidebar controls for timestep and
    band, runs ``perform_inference`` when the button is clicked, and renders a
    map with the predictions, the satellite imagery and the selected point.

    Predictions are kept in ``st.session_state`` so they survive the reruns
    Streamlit triggers whenever a sidebar widget changes.
    """
    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return

    lat, lon = st.session_state["selected_location"]
    location = (lat, lon)

    # Sidebar
    st.sidebar.header("Settings")
    timestep = st.sidebar.slider("Select Timestep", 1, _NUM_TIMESTEPS, _DEFAULT_TIMESTEP)
    selected_band = st.sidebar.selectbox(
        "Select Satellite Band to Display", options=list(_BAND_OPTIONS), index=0
    )

    # Shift the per-timestep band indices into the selected timestep's slice of the stack.
    offset = (timestep - 1) * _BANDS_PER_TIMESTEP
    selected_bands = [band + offset for band in _BAND_OPTIONS[selected_band]]

    instructions = """
    Click the button "Perform Crop Classification".

    _Note:_ 
    - Messis will classify the crop types for the fields in your selected location.
    - Hover over the fields to see the predicted and true crop type.
    - The satellite images might take a few seconds to load.
    """
    st.sidebar.header("Instructions")
    st.sidebar.markdown(instructions)

    # Initialize the map
    m = leafmap.Map(center=location, zoom=10, draw_control=False)

    # Perform inference
    if st.button("Perform Crop Classification", type="primary"):
        with st.spinner("Running crop classification..."):
            predictions = perform_inference(lon, lat, model, config, debug=True)
        # st.button is True only on the rerun caused by the click. Persist the
        # result so that changing the timestep/band (which reruns the script)
        # does not discard the predictions and force another inference run.
        st.session_state[_PREDICTIONS_STATE_KEY] = (location, predictions)
        st.success("Inference completed!")

    # Only show stored predictions if they belong to the currently selected location.
    stored = st.session_state.get(_PREDICTIONS_STATE_KEY)
    if stored is not None and stored[0] == location:
        _add_predictions_layer(m, stored[1])

    # Add Satellite Imagery
    _add_satellite_layer(m, selected_bands, _BAND_VALUE_RANGE[selected_band])

    # Show the POI on the map
    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    # Display the map in the Streamlit app
    m.to_streamlit()

# Render the page when this module is executed as the main script.
if __name__ == "__main__":
    perform_inference_step()