File size: 3,615 Bytes
acbbf71
 
 
 
 
 
 
60fa201
acbbf71
 
 
 
 
 
 
60fa201
ca22ec3
acbbf71
 
 
 
ca22ec3
acbbf71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60fa201
acbbf71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ce930
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import logging
import unittest


class TestAppBuilders(unittest.TestCase):

    def test_default_creation(self):
        from lisa_on_cuda.utils import utils

        placeholders = utils.create_placeholder_variables()
        self.assertIsInstance(placeholders, dict)
        assert placeholders["no_seg_out"].shape == (512, 512, 3)
        assert placeholders["error_happened"].shape == (512, 512, 3)

    def test_parse_args(self):
        from lisa_on_cuda.utils import app_helpers
        from lisa_on_cuda.utils import utils

        test_args_parse = app_helpers.parse_args([])
        assert vars(test_args_parse) == {
            'version': 'xinlai/LISA-13B-llama2-v1-explanatory',
            'vis_save_path': str(utils.VIS_OUTPUT),
            'precision': 'fp16',
            'image_size': 1024,
            'model_max_length': 512,
            'lora_r': 8,
            'vision_tower': 'openai/clip-vit-large-patch14',
            'local_rank': 0,
            'load_in_8bit': False,
            'load_in_4bit': True,
            'use_mm_start_end': True,
            'conv_type': 'llava_v1'
        }

    def test_inference(self):
        import cv2
        import numpy as np
        from lisa_on_cuda.utils import app_helpers, constants, utils

        max_diff = 0.02

        logging.info("starting...")
        logging.warning("Remember: before running again 'get_inference_model_by_args(test_args_parse)' free some memory")
        test_args_parse = app_helpers.parse_args([])
        inference_fn = app_helpers.get_inference_model_by_args(test_args_parse)
        idx_example = 0
        input_prompt, input_image_path = constants.examples[idx_example]
        logging.info("running inference function with input prompt '{}'.".format(input_prompt))
        _, output_mask, output_str = inference_fn(
            input_prompt,
            utils.ROOT / input_image_path
        )
        logging.info(f"output_str: {output_str}.")
        expected_mask = cv2.imread(
            str(utils.ROOT / "tests" / "imgs" / f"example{idx_example}_mask_0.png"),
            cv2.IMREAD_GRAYSCALE
        )

        tot = output_mask.size
        count = np.sum(output_mask != expected_mask)
        perc = 100 * count / tot

        logging.info(f"diff 1 vs 1b: {perc:.2f}!")
        try:
            assert np.array_equal(output_mask, expected_mask)
        except AssertionError:
            try:
                logging.error("failed equality assertion!")
                logging.info(f"assert now that perc diff between ndarrays is minor than {max_diff}.")
                assert perc < max_diff
            except AssertionError as ae:
                logging.error("failed all assertions, writing debug files...")
                import datetime
                now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
                output_folder = utils.ROOT / "tests" / "imgs"
                prefix = f"broken_test_example{idx_example + 1}_{now_str}"
                cv2.imwrite(
                    str(output_folder / f"{prefix}.png"),
                    output_mask
                )
                with open(output_folder / f"{prefix}__input_prompt.txt",
                          "w") as dst:
                    dst.write(input_prompt)
                with open(output_folder / f"{prefix}__output_str.txt",
                          "w") as dst:
                    dst.write(output_str)
                logging.info(f"Written files with prefix '{prefix}' in {output_folder} folder.")
                raise ae
        logging.info("end")


if __name__ == "__main__":
    # Executed as a script (not imported, not collected by a runner):
    # discover and run the TestCase classes defined in this module.
    unittest.main(module="__main__")