lisa-on-cuda / tests /test_app_helpers.py
alessandro trinca tornidor
tests: fix broken test case because of missing pathlib import
71ce930
import logging
import unittest
class TestAppBuilders(unittest.TestCase):
def test_default_creation(self):
from lisa_on_cuda.utils import utils
placeholders = utils.create_placeholder_variables()
self.assertIsInstance(placeholders, dict)
assert placeholders["no_seg_out"].shape == (512, 512, 3)
assert placeholders["error_happened"].shape == (512, 512, 3)
def test_parse_args(self):
from lisa_on_cuda.utils import app_helpers
from lisa_on_cuda.utils import utils
test_args_parse = app_helpers.parse_args([])
assert vars(test_args_parse) == {
'version': 'xinlai/LISA-13B-llama2-v1-explanatory',
'vis_save_path': str(utils.VIS_OUTPUT),
'precision': 'fp16',
'image_size': 1024,
'model_max_length': 512,
'lora_r': 8,
'vision_tower': 'openai/clip-vit-large-patch14',
'local_rank': 0,
'load_in_8bit': False,
'load_in_4bit': True,
'use_mm_start_end': True,
'conv_type': 'llava_v1'
}
def test_inference(self):
import cv2
import numpy as np
from lisa_on_cuda.utils import app_helpers, constants, utils
max_diff = 0.02
logging.info("starting...")
logging.warning("Remember: before running again 'get_inference_model_by_args(test_args_parse)' free some memory")
test_args_parse = app_helpers.parse_args([])
inference_fn = app_helpers.get_inference_model_by_args(test_args_parse)
idx_example = 0
input_prompt, input_image_path = constants.examples[idx_example]
logging.info("running inference function with input prompt '{}'.".format(input_prompt))
_, output_mask, output_str = inference_fn(
input_prompt,
utils.ROOT / input_image_path
)
logging.info(f"output_str: {output_str}.")
expected_mask = cv2.imread(
str(utils.ROOT / "tests" / "imgs" / f"example{idx_example}_mask_0.png"),
cv2.IMREAD_GRAYSCALE
)
tot = output_mask.size
count = np.sum(output_mask != expected_mask)
perc = 100 * count / tot
logging.info(f"diff 1 vs 1b: {perc:.2f}!")
try:
assert np.array_equal(output_mask, expected_mask)
except AssertionError:
try:
logging.error("failed equality assertion!")
logging.info(f"assert now that perc diff between ndarrays is minor than {max_diff}.")
assert perc < max_diff
except AssertionError as ae:
logging.error("failed all assertions, writing debug files...")
import datetime
now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
output_folder = utils.ROOT / "tests" / "imgs"
prefix = f"broken_test_example{idx_example + 1}_{now_str}"
cv2.imwrite(
str(output_folder / f"{prefix}.png"),
output_mask
)
with open(output_folder / f"{prefix}__input_prompt.txt",
"w") as dst:
dst.write(input_prompt)
with open(output_folder / f"{prefix}__output_str.txt",
"w") as dst:
dst.write(output_str)
logging.info(f"Written files with prefix '{prefix}' in {output_folder} folder.")
raise ae
logging.info("end")
if __name__ == '__main__':
    # Allow running this file directly, in addition to discovery by a test runner:
    # unittest.main() collects and runs the TestCase classes defined above.
    unittest.main()