Spaces:
Paused
Paused
import logging | |
import unittest | |
class TestAppBuilders(unittest.TestCase):
    """Tests for the ``lisa_on_cuda.utils`` app helpers.

    Project modules (and ``cv2``/``numpy`` in ``test_inference``) are imported
    inside each test method rather than at module level.
    """

    def test_default_creation(self):
        """``create_placeholder_variables`` returns a dict of (512, 512, 3) arrays."""
        from lisa_on_cuda.utils import utils

        placeholders = utils.create_placeholder_variables()
        self.assertIsInstance(placeholders, dict)
        # TestCase assertions, unlike bare ``assert``, are not stripped by ``python -O``.
        self.assertEqual(placeholders["no_seg_out"].shape, (512, 512, 3))
        self.assertEqual(placeholders["error_happened"].shape, (512, 512, 3))

    def test_parse_args(self):
        """Parsing an empty argv yields exactly the expected default CLI options."""
        from lisa_on_cuda.utils import app_helpers
        from lisa_on_cuda.utils import utils

        test_args_parse = app_helpers.parse_args([])
        expected = {
            'version': 'xinlai/LISA-13B-llama2-v1-explanatory',
            'vis_save_path': str(utils.VIS_OUTPUT),
            'precision': 'fp16',
            'image_size': 1024,
            'model_max_length': 512,
            'lora_r': 8,
            'vision_tower': 'openai/clip-vit-large-patch14',
            'local_rank': 0,
            'load_in_8bit': False,
            'load_in_4bit': True,
            'use_mm_start_end': True,
            'conv_type': 'llava_v1',
        }
        # assertEqual on two dicts reports a per-key diff on failure.
        self.assertEqual(vars(test_args_parse), expected)

    def test_inference(self):
        """Run inference on one bundled example and compare against its reference mask.

        The produced mask must either equal
        ``ROOT/tests/imgs/example{idx_example}_mask_0.png`` (read as grayscale)
        exactly, or differ in less than ``max_diff`` percent of its elements.
        On failure the produced mask, the input prompt and the output text are
        written next to the reference images, then the test fails.
        """
        import cv2
        import numpy as np
        from lisa_on_cuda.utils import app_helpers, constants, utils

        # Tolerance in PERCENT of differing elements (0.02 means 0.02 %, not 2 %).
        max_diff = 0.02
        idx_example = 0
        imgs_folder = utils.ROOT / "tests" / "imgs"

        logging.info("starting...")
        logging.warning(
            "Remember: before running again "
            "'get_inference_model_by_args(test_args_parse)' free some memory"
        )

        # Load the reference first so a missing file fails fast, before the model load.
        expected_mask_path = imgs_folder / f"example{idx_example}_mask_0.png"
        expected_mask = cv2.imread(str(expected_mask_path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (it does not raise) when the file cannot be read;
        # without this check the comparison below reports a bogus 100 % difference.
        self.assertIsNotNone(
            expected_mask, f"could not read reference mask {expected_mask_path}"
        )

        test_args_parse = app_helpers.parse_args([])
        inference_fn = app_helpers.get_inference_model_by_args(test_args_parse)

        input_prompt, input_image_path = constants.examples[idx_example]
        logging.info("running inference function with input prompt '{}'.".format(input_prompt))
        _, output_mask, output_str = inference_fn(
            input_prompt,
            utils.ROOT / input_image_path
        )
        logging.info(f"output_str: {output_str}.")

        tot = output_mask.size
        count = np.count_nonzero(output_mask != expected_mask)
        perc = 100 * count / tot
        logging.info(f"diff 1 vs 1b: {perc:.2f}!")

        if not np.array_equal(output_mask, expected_mask):
            logging.error("failed equality assertion!")
            logging.info(f"assert now that perc diff between ndarrays is minor than {max_diff}.")
            # Written as ``not (<)`` so a NaN percentage still counts as a failure.
            if not perc < max_diff:
                logging.error("failed all assertions, writing debug files...")
                self._write_debug_files(
                    imgs_folder, idx_example, output_mask, input_prompt, output_str
                )
                self.fail(
                    f"{perc:.4f}% of mask elements differ from the reference "
                    f"(tolerance: {max_diff}%)"
                )
        logging.info("end")

    @staticmethod
    def _write_debug_files(output_folder, idx_example, output_mask, input_prompt, output_str):
        """Dump the artifacts of a failed ``test_inference`` comparison for inspection.

        Writes ``{prefix}.png`` (the produced mask), ``{prefix}__input_prompt.txt``
        and ``{prefix}__output_str.txt`` into ``output_folder``. The prefix is
        ``broken_test_example{idx_example}_{YYYYmmddHHMMSS}``, using the same
        0-based index as the reference mask's file name.
        """
        import datetime

        import cv2

        now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        prefix = f"broken_test_example{idx_example}_{now_str}"
        # cv2.imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(str(output_folder / f"{prefix}.png"), output_mask):
            logging.warning(f"could not write debug mask {prefix}.png in {output_folder}.")
        with open(output_folder / f"{prefix}__input_prompt.txt", "w", encoding="utf-8") as dst:
            dst.write(input_prompt)
        with open(output_folder / f"{prefix}__output_str.txt", "w", encoding="utf-8") as dst:
            dst.write(output_str)
        logging.info(f"Written files with prefix '{prefix}' in {output_folder} folder.")
if __name__ == "__main__":
    # Executed only when this file is run as a script; importing it does nothing here.
    unittest.main()