File size: 4,665 Bytes
7d48fbe
8783164
7d48fbe
c311b69
8783164
 
 
 
7d48fbe
 
0914710
 
1e100ac
8783164
 
 
 
 
 
 
 
 
1e100ac
8783164
 
 
 
 
0914710
c311b69
0914710
 
8783164
c311b69
 
 
0914710
 
 
 
 
 
 
 
 
 
 
 
 
 
c311b69
0914710
 
 
 
7d48fbe
8783164
 
7d48fbe
5350122
8783164
0914710
 
 
 
 
 
 
 
7d48fbe
c311b69
7d48fbe
 
 
 
6756dd2
7d48fbe
6756dd2
7d48fbe
 
0914710
fd5c95e
 
 
 
 
c311b69
fd5c95e
 
 
 
 
0914710
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from datetime import datetime
from spaces import GPU as SPACES_GPU

from samgis_core.utilities.type_hints import LlistFloat, DictStrInt
from samgis_lisa_on_zero.io.geo_helpers import get_vectorized_raster_as_geojson
from samgis_lisa_on_zero.io.raster_helpers import write_raster_png, write_raster_tiff
from samgis_lisa_on_zero.io.tms2geotiff import download_extent
from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, LISA_INFERENCE_FN

# Log message prefix used by lisa_predict when the WRITE_TMP_ON_DISK env var is set
# (i.e. when the raw downloaded raster is also written to that folder).
msg_write_tmp_on_disk = "found option to write images and geojson output..."


def load_model_and_inference_fn(inference_function_name_key: str):
    """
    Lazily build and cache the inference function for the given model key.

    The function lives in ``models_dict[inference_function_name_key]["inference"]``.
    If that slot is already populated, nothing happens. Otherwise a new inference
    function is built from ``app_helpers.parse_args([])`` (no CLI arguments) via
    ``app_helpers.get_inference_model_by_args``, wrapped with the ``spaces.GPU``
    decorator, and stored in the slot.

    Args:
        inference_function_name_key: key of the model entry inside ``models_dict``

    Returns:
        None. The inference function is stored in ``models_dict`` as a side effect.
    """
    from samgis_lisa_on_zero import app_logger
    from lisa_on_cuda.utils import app_helpers
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    model_slot = models_dict[inference_function_name_key]
    if model_slot["inference"] is not None:
        # already instantiated by a previous call: reuse the cached function
        return

    app_logger.info(f"missing inference function {inference_function_name_key}, instantiating it now!")
    default_args = app_helpers.parse_args([])
    model_slot["inference"] = app_helpers.get_inference_model_by_args(
        default_args,
        internal_logger0=app_logger,
        inference_decorator=SPACES_GPU,
    )


def lisa_predict(
        bbox: LlistFloat,
        prompt: str,
        zoom: float,
        inference_function_name_key: str = LISA_INFERENCE_FN,
        source: str = DEFAULT_URL_TILES,
        source_name: str = None
) -> DictStrInt:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary, instantiate the LISA inference function (cached in ``models_dict``)
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. optionally write the raw raster to the folder named by the WRITE_TMP_ON_DISK env var
    4. run the inference function on the image with the input prompt, getting a mask and an output string
    5. get a geo-referenced geojson from the prediction mask

    Args:
        bbox: coordinates bounding box as two points ``(pt0, pt1)``. They are read as
            ``pt0 = (n, e)`` and ``pt1 = (s, w)``, i.e. index 0 is the latitude-like
            value and index 1 the longitude-like value of each corner.
        prompt: machine learning input prompt (converted to ``str`` before inference)
        zoom: level of detail used to download the tiles
        inference_function_name_key: key of the inference model inside ``models_dict``
        source: tile provider source passed to ``download_extent``
        source_name: name of tile provider, used in temp filenames and in the embedding cache key

    Returns:
        dict containing the ``"output_string"`` produced by the inference function, merged
        with the keys returned by ``get_vectorized_raster_as_geojson(mask, transform)``.
    """
    from os import getenv
    from samgis_lisa_on_zero import app_logger
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    app_logger.info("start lisa inference...")
    load_model_and_inference_fn(inference_function_name_key)
    app_logger.debug(f"using a {inference_function_name_key} instance model...")
    inference_fn = models_dict[inference_function_name_key]["inference"]

    # bbox is unpacked as (north-east, south-west) corners: see the w/s/e/n mapping below
    pt0, pt1 = bbox
    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")

    # the same coordinate prefix is reused for temp filenames and for the embedding cache key
    prefix = f"w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}_"
    folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
    if folder_write_tmp_on_disk:
        now = datetime.now().strftime('%Y%m%d_%H%M%S')
        app_logger.info(f"{msg_write_tmp_on_disk} with coords {prefix}, shape:{img.shape}, {len(img.shape)}.")
        filename_prefix = f"{source_name}_{prefix}_{now}_"
        # 2-D rasters are written as GeoTIFF, 3-channel rasters as PNG; other shapes are skipped
        if len(img.shape) == 2:
            write_raster_tiff(img, transform, filename_prefix, "raw_tiff", folder_write_tmp_on_disk)
        elif len(img.shape) == 3 and img.shape[2] == 3:
            write_raster_png(img, transform, filename_prefix, "raw_img", folder_write_tmp_on_disk)
    else:
        app_logger.info("keep all temp data in memory...")

    app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
    app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
    app_logger.info(f"lisa_zero, prompt:{prompt}.")
    prompt_str = str(prompt)
    app_logger.info(f"lisa_zero, img type:{type(img)}.")
    embedding_key = f"{source_name}_z{zoom}_{prefix}"
    # the inference function returns a 3-tuple; its first element is not used here
    _, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
    app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
    app_logger.info(f"lisa_zero, output_string:{output_string}.")
    app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
    app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
    return {
        "output_string": output_string,
        **get_vectorized_raster_as_geojson(mask, transform)
    }