File size: 5,066 Bytes
7d48fbe
 
0677d1d
c311b69
0677d1d
 
bb6e1c0
00f8875
 
 
8783164
7d48fbe
 
0914710
 
0677d1d
1e100ac
8783164
 
 
 
a5e4002
 
 
8783164
 
 
1e100ac
8783164
 
 
 
 
0677d1d
0914710
c311b69
0914710
 
8783164
c311b69
 
 
0914710
 
 
 
 
 
 
 
 
 
 
 
 
 
c311b69
0914710
 
 
 
7d48fbe
8783164
7d48fbe
a5e4002
 
 
5350122
a5e4002
 
 
8783164
a5e4002
0914710
a5e4002
0914710
 
 
 
 
 
7d48fbe
c311b69
7d48fbe
 
 
 
6756dd2
7d48fbe
6756dd2
7d48fbe
 
0914710
fd5c95e
a5e4002
fd5c95e
 
a5e4002
c311b69
fd5c95e
a5e4002
 
fd5c95e
0914710
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from datetime import datetime

from lisa_on_cuda.utils import session_logger
from samgis_core.utilities.type_hints import LlistFloat, DictStrInt
from spaces import GPU as SPACES_GPU

from samgis_lisa_on_zero import app_logger
from samgis_lisa_on_zero.io_package.geo_helpers import get_vectorized_raster_as_geojson
from samgis_lisa_on_zero.io_package.raster_helpers import write_raster_png, write_raster_tiff
from samgis_lisa_on_zero.io_package.tms2geotiff import download_extent
from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, LISA_INFERENCE_FN

# Prefix of the log message emitted by lisa_predict() when the WRITE_TMP_ON_DISK
# environment variable is set, i.e. when the raw raster is also dumped on disk.
msg_write_tmp_on_disk = (
    "found option to write images and geojson output..."
)


@session_logger.set_uuid_logging
def load_model_and_inference_fn(inference_function_name_key: str):
    """
    Make sure the inference function stored under the given key is instantiated.

    The "inference" slot of ``models_dict[inference_function_name_key]`` is checked: when it is
    still None the inference function is built with ``app_helpers.get_inference_model_by_args``
    (passing ``SPACES_GPU`` as inference decorator) and saved back in the same slot.
    When the slot is already populated nothing happens.

    Args:
        inference_function_name_key: key of the ``models_dict`` entry to initialize

    Returns:
        None: the inference function is published through ``models_dict``, not returned
    """
    from lisa_on_cuda.utils import app_helpers
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    # guard clause: the inference function is already there, nothing to instantiate
    if models_dict[inference_function_name_key]["inference"] is not None:
        return

    app_logger.info(
        f"missing inference function {inference_function_name_key}, "
        f"instantiating it now using inference_decorator {SPACES_GPU}!"
    )
    # empty argv: the argument parser is asked for its default values only
    default_args = app_helpers.parse_args([])
    models_dict[inference_function_name_key]["inference"] = app_helpers.get_inference_model_by_args(
        default_args,
        internal_logger0=app_logger,
        inference_decorator=SPACES_GPU
    )


@session_logger.set_uuid_logging
def lisa_predict(
        bbox: LlistFloat,
        prompt: str,
        zoom: float,
        inference_function_name_key: str = LISA_INFERENCE_FN,
        source: str = DEFAULT_URL_TILES,
        source_name: str = None
) -> DictStrInt:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary instantiate the inference function registered under 'inference_function_name_key'
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. get a prediction mask and an output string from the inference function using the input prompt
    4. get a geo-referenced geojson from the prediction mask

    If the environment variable WRITE_TMP_ON_DISK is set to a non-empty value, the downloaded
    raster is also written in that folder (tiff for 2-dimensional images, png for 3-channel images).

    Args:
        bbox: coordinates bounding box, a pair of points: the first one gives (north, east),
            the second one gives (south, west)
        prompt: machine learning input prompt
        zoom: Level of detail
        inference_function_name_key: key of the inference function within the models dict
        source: tiles source passed to the raster downloader
        source_name: name of tile provider; when None, the string version of 'source' is used

    Returns:
        dict with the "output_string" produced by the inference function, merged with the
        content returned by get_vectorized_raster_as_geojson() for the prediction mask
    """
    from os import getenv
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    if source_name is None:
        source_name = str(source)

    app_logger.info("start lisa inference...")
    app_logger.debug(f"type(source):{type(source)}, source:{source},")
    app_logger.debug(f"type(source_name):{type(source_name)}, source_name:{source_name}.")

    load_model_and_inference_fn(inference_function_name_key)
    app_logger.debug(f"using a '{inference_function_name_key}' instance model...")
    inference_fn = models_dict[inference_function_name_key]["inference"]
    # not every callable exposes __name__ (e.g. functools.partial): a log line must not abort the inference
    inference_fn_name = getattr(inference_fn, "__name__", repr(inference_fn))
    app_logger.info(f"loaded inference function '{inference_fn_name}'.")

    # bbox layout: pt0 = (north, east), pt1 = (south, west)
    pt0, pt1 = bbox
    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
    folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
    # the same coordinates prefix is used both for the files written on disk and for the embedding key
    prefix = f"w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}_"
    if folder_write_tmp_on_disk:
        now = datetime.now().strftime('%Y%m%d_%H%M%S')
        n_dims = len(img.shape)
        app_logger.info(f"{msg_write_tmp_on_disk} with coords {prefix}, shape:{img.shape}, {n_dims}.")
        output_name_prefix = f"{source_name}_{prefix}_{now}_"
        if n_dims == 2:
            # single band raster
            write_raster_tiff(img, transform, output_name_prefix, "raw_tiff", folder_write_tmp_on_disk)
        elif n_dims == 3 and img.shape[2] == 3:
            # three channels image
            write_raster_png(img, transform, output_name_prefix, "raw_img", folder_write_tmp_on_disk)
    else:
        app_logger.info("keep all temp data in memory...")

    app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
    app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
    app_logger.info(f"lisa_zero, prompt:{prompt}.")
    prompt_str = str(prompt)
    app_logger.info(f"lisa_zero, img type:{type(img)}.")
    embedding_key = f"{source_name}_z{zoom}_{prefix}"
    # the first element returned by the inference function is not needed here
    _, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
    app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
    app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
    app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
    return {
        "output_string": output_string,
        **get_vectorized_raster_as_geojson(mask, transform)
    }