|
import json |
|
import os |
|
import pathlib |
|
import uuid |
|
|
|
import gradio as gr |
|
import uvicorn |
|
from fastapi import FastAPI, HTTPException, Request, status |
|
from fastapi.exceptions import RequestValidationError |
|
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists |
|
from pydantic import ValidationError |
|
from spaces import GPU as SPACES_GPU |
|
|
|
from samgis_core.utilities.fastapi_logger import setup_logging |
|
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR |
|
from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody |
|
|
|
|
|
# Logging level name, read from the LOGLEVEL env var (default "INFO") and upper-cased.
loglevel = os.environ.get('LOGLEVEL', 'INFO').upper()
# NOTE(review): the level *name* is passed as ``debug``; if setup_logging treats that argument as a
# boolean flag, any non-empty value (including "INFO") would enable debug logging — confirm.
app_logger = setup_logging(debug=loglevel)

# URL paths for the static frontend pages and the gradio UI, each overridable through an env var.
CUSTOM_INDEX_URL = os.environ.get("CUSTOM_INDEX_URL", "/static")
CUSTOM_SAMGIS_URL = os.environ.get("CUSTOM_SAMGIS_URL", "/samgis")
CUSTOM_LISA_URL = os.environ.get("CUSTOM_LISA_URL", "/lisa")
CUSTOM_GRADIO_URL = os.environ.get("CUSTOM_GRADIO_URL", "/")

FASTAPI_TITLE = "samgis-lisa-on-zero"
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
|
|
|
|
|
@app.middleware("http")
async def request_middleware(request, call_next):
    """Tag every HTTP request with a fresh UUID4 request id.

    The id is bound to ``app_logger`` (via ``contextualize``) for the whole request and returned to
    the client in the ``X-Request-ID`` response header. Any ``Exception`` escaping the downstream
    app is logged and replaced by a ``{"success": False}`` JSON response with status 500.
    """
    request_id = str(uuid.uuid4())
    with app_logger.contextualize(request_id=request_id):
        app_logger.info("Request started")
        try:
            response = await call_next(request)
        except Exception as ex:
            app_logger.error(f"Request failed: {ex}")
            response = JSONResponse(content={"success": False}, status_code=500)
        finally:
            # Runs on every exit path. The header is set below rather than here: when a
            # BaseException (e.g. asyncio.CancelledError) escapes call_next, ``response`` is
            # unbound and touching it in ``finally`` would replace the original error with an
            # UnboundLocalError.
            app_logger.info("Request ended")
        response.headers["X-Request-ID"] = request_id
        return response
|
|
|
|
|
@app.post("/post_test_dictlist")
def post_test_dictlist2(request_input: ApiRequestBody) -> JSONResponse:
    """Test endpoint: parse a dict-list prompt request and echo the parsed body back with status 200."""
    from samgis_lisa_on_zero.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt

    parsed_body = get_parsed_bbox_points_with_dictlist_prompt(request_input)
    app_logger.info(f"request_body:{parsed_body}.")
    return JSONResponse(content=parsed_body, status_code=200)
|
|
|
|
|
@app.get("/health")
async def health() -> JSONResponse:
    """Liveness probe: log the installed package versions and answer ``{"msg": "still alive..."}``.

    Each version is looked up independently, so a missing distribution only blanks (and logs) its
    own entry instead of hiding the versions of the packages checked after it.
    """
    import importlib.metadata
    from importlib.metadata import PackageNotFoundError

    def _version_or_empty(distribution_name: str) -> str:
        # Installed version of ``distribution_name``, or "" (after logging) when it is not installed.
        try:
            return importlib.metadata.version(distribution_name)
        except PackageNotFoundError as pe:
            app_logger.error(f"pe:{pe}.")
            return ""

    core_version = _version_or_empty('samgis_core')
    lisa_on_cuda_version = _version_or_empty('lisa-on-cuda')
    samgis_lisa_on_zero_version = _version_or_empty('samgis-lisa-on-zero')

    msg = "still alive, "
    msg += f"version:{samgis_lisa_on_zero_version}, core version:{core_version}, "
    msg += f"lisa-on-cuda version:{lisa_on_cuda_version}."
    app_logger.info(msg)
    return JSONResponse(status_code=200, content={"msg": "still alive..."})
|
|
|
|
|
@app.post("/post_test_string")
def post_test_string(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """Test endpoint: parse a string-prompt request and echo it back together with the default precision.

    ``app_helpers`` is the module already imported at file level from ``lisa_on_cuda.utils``.
    """
    from samgis_lisa_on_zero.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt

    parsed_body = get_parsed_bbox_points_with_string_prompt(request_input)
    app_logger.info(f"request_body:{parsed_body}.")
    default_precision = str(app_helpers.parse_args([]).precision)
    # NOTE(review): a copy of the whole parsed body (plus "precision") is stored under "content" and the
    # outer dict is returned, so the request fields appear twice in the response — confirm this is intended.
    parsed_body["content"] = {**parsed_body, "precision": default_precision}
    return JSONResponse(content=parsed_body, status_code=200)
|
|
|
|
|
def _log_project_folders_content() -> None:
    """Log an ``ls -l`` listing of PROJECT_ROOT_FOLDER and WORKDIR to help diagnose inference failures.

    Best effort: the command is run without a shell (folder paths are passed verbatim, even with
    spaces), and an OSError while running it is logged and swallowed so it never masks the
    inference error being reported.
    """
    import subprocess

    for folder_label, folder in (("project_root", PROJECT_ROOT_FOLDER), ("workdir", WORKDIR)):
        try:
            folder_content = subprocess.run(
                ["ls", "-l", f"{folder}/"], universal_newlines=True, stdout=subprocess.PIPE, check=False
            )
        except OSError as os_err:
            app_logger.error(f"unable to list {folder_label} folder '{folder}': {os_err}.")
            continue
        app_logger.error(f"{folder_label} folder 'ls -l' command output: {folder_content.stdout}.")


@app.post("/infer_lisa")
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """Run a LISA (string prompt) inference and return its output with the elapsed time.

    Returns:
        status 200 with ``{"body": <JSON string of {"duration_run": seconds, "output": ...}>}``.

    Raises:
        HTTPException: status 500 when the inference step fails (after logging folder listings).
        RequestValidationError: when a pydantic ValidationError is raised while parsing the request,
            so the RequestValidationError handler answers 422.
    """
    import time

    from samgis_lisa_on_zero.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt, get_source_name
    from samgis_lisa_on_zero.prediction_api import lisa
    from samgis_lisa_on_zero.utilities.constants import LISA_INFERENCE_FN

    app_logger.info("starting lisa inference request...")

    try:
        time_start_run = time.time()
        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
        app_logger.info(f"lisa body_request:{body_request}.")
        app_logger.info(f"lisa module:{lisa}.")
        try:
            source_name = get_source_name(request_input.source_type)
            app_logger.info(f"source_name = {source_name}.")
            output = lisa.lisa_predict(
                bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
                source=body_request["source"], source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
            )
            duration_run = time.time() - time_start_run
            app_logger.info(f"duration_run:{duration_run}.")
            body = {"duration_run": duration_run, "output": output}
            return JSONResponse(status_code=200, content={"body": json.dumps(body)})
        except Exception as inference_exception:
            _log_project_folders_content()
            app_logger.error(f"inference error:{inference_exception}.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
            ) from inference_exception
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        # pydantic's ValidationError cannot be built from a plain message (doing so raises TypeError),
        # so re-raise as FastAPI's RequestValidationError to reach the registered 422 handler.
        raise RequestValidationError(va1.errors()) from va1


@app.post("/infer_samgis")
def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
    """Run a plain SamGIS (dict-list prompt) inference and return its output with the elapsed time.

    Returns:
        status 200 with ``{"body": <JSON string of {"duration_run": seconds, "output": ...}>}``.

    Raises:
        HTTPException: status 500 when the inference step fails (after logging folder listings).
        RequestValidationError: when a pydantic ValidationError is raised while parsing the request,
            so the RequestValidationError handler answers 422.
    """
    import time

    from samgis_lisa_on_zero.prediction_api import predictors
    from samgis_lisa_on_zero.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt, get_source_name

    app_logger.info("starting plain samgis inference request...")

    try:
        time_start_run = time.time()
        body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
        app_logger.info(f"body_request:{body_request}.")
        try:
            source_name = get_source_name(request_input.source_type)
            app_logger.info(f"source_name = {source_name}.")
            output = predictors.samexporter_predict(
                bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
                source=body_request["source"], source_name=source_name
            )
            duration_run = time.time() - time_start_run
            app_logger.info(f"duration_run:{duration_run}.")
            body = {"duration_run": duration_run, "output": output}
            return JSONResponse(status_code=200, content={"body": json.dumps(body)})
        except Exception as inference_exception:
            _log_project_folders_content()
            app_logger.error(f"inference error:{inference_exception}.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
            ) from inference_exception
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        # pydantic's ValidationError cannot be built from a plain message (doing so raises TypeError),
        # so re-raise as FastAPI's RequestValidationError to reach the registered 422 handler.
        raise RequestValidationError(va1.errors()) from va1
|
|
|
|
|
@app.exception_handler(RequestValidationError)
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Log the validation errors, body, headers and query params of a rejected request; answer 422."""
    log_error = app_logger.error
    log_error(f"exception errors: {exc.errors()}.")
    log_error(f"exception body: {exc.body}.")
    log_error(f'request header: {dict(request.headers.items())}.')
    log_error(f'request query params: {dict(request.query_params.items())}.')
    unprocessable_content = {"msg": "Error - Unprocessable Entity"}
    return JSONResponse(content=unprocessable_content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Log a raised ``HTTPException`` with the request headers/query params and answer with its status.

    The response keeps the exception's own status code (falling back to 500) and headers. The body is a
    generic ``{"msg": "Error - <reason phrase>"}`` and does not include ``exc.detail``. A 500 still
    yields exactly ``{"msg": "Error - Internal Server Error"}``.
    """
    from http import HTTPStatus

    app_logger.error(f"exception: {str(exc)}.")
    headers = request.headers.items()
    app_logger.error(f'request header: {dict(headers)}.')
    params = request.query_params.items()
    app_logger.error(f'request query params: {dict(params)}.')

    status_code = exc.status_code or status.HTTP_500_INTERNAL_SERVER_ERROR
    try:
        reason_phrase = HTTPStatus(status_code).phrase
    except ValueError:  # status code not known to http.HTTPStatus
        reason_phrase = "Error"
    return JSONResponse(
        status_code=status_code,
        content={"msg": f"Error - {reason_phrase}"},
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
# Optional folder given by the WRITE_TMP_ON_DISK env var: when set it is (re)created, mounted as static
# files under /vis_output, and its content is listed as an HTML page at the /vis_output route.
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
if bool(write_tmp_on_disk):
    try:
        path_write_tmp_on_disk = pathlib.Path(write_tmp_on_disk)
        # Remove a stale file/symlink occupying the path so makedirs can create the folder; an existing
        # real directory is reused as-is (unlinking it can only fail and log misleading errors).
        if path_write_tmp_on_disk.is_symlink() or not path_write_tmp_on_disk.is_dir():
            try:
                path_write_tmp_on_disk.unlink(missing_ok=True)
            except OSError as err:  # also covers IsADirectoryError and PermissionError
                app_logger.error(f"{err} while removing old write_tmp_on_disk:{write_tmp_on_disk}.")
                app_logger.error(f"is file?{path_write_tmp_on_disk.is_file()}.")
                app_logger.error(f"is symlink?{path_write_tmp_on_disk.is_symlink()}.")
                app_logger.error(f"is folder?{path_write_tmp_on_disk.is_dir()}.")
        os.makedirs(write_tmp_on_disk, exist_ok=True)
        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
    except RuntimeError as rerr:
        app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
        raise
    templates = Jinja2Templates(directory=WORKDIR / "static")
    # Dedicated reference for list_files: the module-level name ``templates`` is rebound further down
    # in this module, and a handler resolving ``templates`` at request time would pick that one up.
    _vis_output_templates = templates

    @app.get("/vis_output", response_class=HTMLResponse)
    def list_files(request: Request):
        """Render list_files.html with the absolute URLs of the files found in write_tmp_on_disk."""
        # Drop query string/fragment so file names are appended to the path, not to "?...".
        base_url = str(request.url.replace(query="", fragment="")).rstrip("/")
        files_paths = sorted(f"{base_url}/{file_name}" for file_name in os.listdir(write_tmp_on_disk))
        app_logger.info(f"files_paths:{files_paths}.")
        return _vis_output_templates.TemplateResponse(
            "list_files.html", {"request": request, "files": files_paths}
        )
|
|
|
|
|
# Build the frontend into WORKDIR/static/dist (served by the static mounts below), then run the
# folder-creation helper from lisa_on_cuda.
static_dist_folder = WORKDIR / "static" / "dist"
frontend_builder.build_frontend(
    project_root_folder=frontend_builder.env_project_root_folder,
    input_css_path=frontend_builder.env_input_css_path,
    output_dist_folder=static_dist_folder,
)
create_folders_and_variables_if_not_exists.folders_creation()
app_logger.info("build_frontend ok!")

# NOTE(review): this rebinds the module-level ``templates`` name (created earlier with directory
# WORKDIR/static when WRITE_TMP_ON_DISK is set); anything resolving ``templates`` at request time
# sees this instance instead — confirm this is intended.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
# Serve the built frontend folder under /static, in StaticFiles HTML mode.
app.mount(path="/static", app=StaticFiles(directory=static_dist_folder, html=True), name="static")
|
|
|
|
|
app.mount(path=CUSTOM_SAMGIS_URL, app=StaticFiles(directory=static_dist_folder, html=True), name="samgis")


@app.get(CUSTOM_SAMGIS_URL)
async def samgis() -> FileResponse:
    """Serve the built SamGIS page (samgis.html) as HTML."""
    samgis_page = static_dist_folder / "samgis.html"
    return FileResponse(samgis_page, media_type="text/html")
|
|
|
|
|
|
|
app.mount(path=CUSTOM_LISA_URL, app=StaticFiles(directory=static_dist_folder, html=True), name="lisa")


@app.get(CUSTOM_LISA_URL)
async def lisa() -> FileResponse:
    """Serve the built LISA page (lisa.html) as HTML."""
    lisa_page = static_dist_folder / "lisa.html"
    return FileResponse(lisa_page, media_type="text/html")
|
|
|
|
|
|
|
# NOTE(review): with the default CUSTOM_INDEX_URL ("/static") this mount duplicates the "/static"
# mount registered earlier under the name "static" — confirm the duplicate is intended.
app.mount(path=CUSTOM_INDEX_URL, app=StaticFiles(directory=static_dist_folder, html=True), name="index")


@app.get(CUSTOM_INDEX_URL)
async def index() -> FileResponse:
    """Serve the built index page (index.html) as HTML."""
    index_page = static_dist_folder / "index.html"
    return FileResponse(index_page, media_type="text/html")
|
|
|
|
|
# Build the inference function from default CLI arguments, decorated with ``spaces.GPU``,
# wrap it in a gradio interface and mount that interface on the FastAPI app at CUSTOM_GRADIO_URL.
args = app_helpers.parse_args([])
app_helpers.app_logger.info(f"prepared default arguments:{args}.")
inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU)

app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
io = app_helpers.get_gradio_interface(inference_fn)
app_helpers.app_logger.info(f"created gradio interface, mounting gradio app on url {CUSTOM_GRADIO_URL} within FastAPI...")
app = gr.mount_gradio_app(app=app, blocks=io, path=CUSTOM_GRADIO_URL)
app_helpers.app_logger.info("mounted gradio app within fastapi")
|
|
|
|
|
if __name__ == '__main__':
    # Run the combined FastAPI/gradio app with uvicorn; on failure, report the error to both the
    # standard logging module and stdout, then re-raise it.
    try:
        uvicorn.run(app=app, host="0.0.0.0", port=7860)
    except Exception as ex:
        import logging

        failure_message = f"fastapi/gradio application {FASTAPI_TITLE}, exception:{ex}."
        logging.error(failure_message)
        print(failure_message)
        raise
|
|