File size: 4,665 Bytes
7d48fbe 8783164 7d48fbe c311b69 8783164 7d48fbe 0914710 1e100ac 8783164 1e100ac 8783164 0914710 c311b69 0914710 8783164 c311b69 0914710 c311b69 0914710 7d48fbe 8783164 7d48fbe 5350122 8783164 0914710 7d48fbe c311b69 7d48fbe 6756dd2 7d48fbe 6756dd2 7d48fbe 0914710 fd5c95e c311b69 fd5c95e 0914710 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from datetime import datetime
from spaces import GPU as SPACES_GPU
from samgis_core.utilities.type_hints import LlistFloat, DictStrInt
from samgis_lisa_on_zero.io.geo_helpers import get_vectorized_raster_as_geojson
from samgis_lisa_on_zero.io.raster_helpers import write_raster_png, write_raster_tiff
from samgis_lisa_on_zero.io.tms2geotiff import download_extent
from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, LISA_INFERENCE_FN
# Log message prefix used by lisa_predict when the WRITE_TMP_ON_DISK env variable
# is set, i.e. when the downloaded raster is also written to disk for debugging.
msg_write_tmp_on_disk: str = "found option to write images and geojson output..."
def load_model_and_inference_fn(inference_function_name_key: str):
    """
    Lazily build the inference function registered under the given key.

    The function lives in models_dict[inference_function_name_key]["inference"]. If that
    slot is already populated nothing happens. Otherwise it is created with lisa_on_cuda's
    app_helpers, using the arguments parsed from an empty argument list (presumably the
    parser defaults) and SPACES_GPU as inference decorator, and then stored back in the slot.

    Args:
        inference_function_name_key: key of the models_dict entry to populate

    Returns:
        None; the inference function is stored in models_dict as a side effect
    """
    from samgis_lisa_on_zero import app_logger
    from lisa_on_cuda.utils import app_helpers
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    model_entry = models_dict[inference_function_name_key]
    # guard clause: an already instantiated inference function is reused as-is
    if model_entry["inference"] is not None:
        return

    app_logger.info(f"missing inference function {inference_function_name_key}, instantiating it now!")
    default_args = app_helpers.parse_args([])
    model_entry["inference"] = app_helpers.get_inference_model_by_args(
        default_args,
        internal_logger0=app_logger,
        inference_decorator=SPACES_GPU
    )
def lisa_predict(
        bbox: LlistFloat,
        prompt: str,
        zoom: float,
        inference_function_name_key: str = LISA_INFERENCE_FN,
        source: str = DEFAULT_URL_TILES,
        source_name: str = None
) -> DictStrInt:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary, instantiate the inference function stored in models_dict
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. run the inference function on the downloaded image using the input prompt
    4. vectorize the predicted mask into geojson, geo-referenced via the raster transform

    When the WRITE_TMP_ON_DISK environment variable is set to a non-empty folder path, the
    downloaded raster is also written there: as tiff for 2D arrays, as png for arrays with
    shape (height, width, 3).

    Args:
        bbox: bounding box as two points, [[north, east], [south, west]]
        prompt: machine learning input prompt
        zoom: level of detail used when downloading the raster
        inference_function_name_key: key of the models_dict entry holding the inference function
        source: tile source passed to download_extent
        source_name: name of the tile provider; used only to build the temporary file names
            and the embedding_key passed to the inference function

    Returns:
        dict with the model text answer under "output_string", merged with all the keys
        returned by get_vectorized_raster_as_geojson(mask, transform)
    """
    from os import getenv
    from samgis_lisa_on_zero import app_logger
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    app_logger.info("start lisa inference...")
    load_model_and_inference_fn(inference_function_name_key)
    app_logger.debug(f"using a {inference_function_name_key} instance model...")
    inference_fn = models_dict[inference_function_name_key]["inference"]

    # bbox layout: pt0 = [north, east], pt1 = [south, west]
    pt0, pt1 = bbox
    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
    folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
    # coordinates prefix, reused both for the optional debug files and for the embedding key
    prefix = f"w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}_"
    if folder_write_tmp_on_disk:
        now = datetime.now().strftime('%Y%m%d_%H%M%S')
        # leading space: the message constant ends with "..." and has no trailing separator
        app_logger.info(msg_write_tmp_on_disk + f" with coords {prefix}, shape:{img.shape}, {len(img.shape)}.")
        output_filename_prefix = f"{source_name}_{prefix}_{now}_"
        # the two shape checks are mutually exclusive; other shapes are not written at all
        if img.shape and len(img.shape) == 2:
            write_raster_tiff(img, transform, output_filename_prefix, "raw_tiff", folder_write_tmp_on_disk)
        elif img.shape and len(img.shape) == 3 and img.shape[2] == 3:
            write_raster_png(img, transform, output_filename_prefix, "raw_img", folder_write_tmp_on_disk)
    else:
        app_logger.info("keep all temp data in memory...")
    app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
    app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
    app_logger.info(f"lisa_zero, prompt:{prompt}.")
    prompt_str = str(prompt)
    app_logger.info(f"lisa_zero, img type:{type(img)}.")
    embedding_key = f"{source_name}_z{zoom}_{prefix}"
    _, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
    app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
    app_logger.info(f"lisa_zero, output_string:{output_string}.")
    app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
    app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
    return {
        "output_string": output_string,
        **get_vectorized_raster_as_geojson(mask, transform)
    }
|