File size: 4,012 Bytes
a0bcaae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb0f5a9
 
a0bcaae
 
 
bb0f5a9
 
 
 
 
 
 
a0bcaae
 
 
 
 
bb0f5a9
a0bcaae
 
 
 
 
 
 
 
 
 
 
bb0f5a9
a0bcaae
bb0f5a9
 
a0bcaae
 
 
 
 
 
 
 
 
 
 
 
bb0f5a9
 
 
 
a0bcaae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb0f5a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import re
import numpy as np
import imgui
import PIL.Image
from gui_utils import imgui_utils
from . import renderer
import torch
import torchvision

# ----------------------------------------------------------------------------


class CaptureWidget:
    """Widget that saves the current result image or a GUI screenshot.

    Captures are written into ``self.path`` as ``NNNNN.png``, where ``NNNNN``
    is one greater than the largest numeric filename prefix already present
    in that directory. When the visualizer result carries a ``w`` entry, it
    is stored next to the image as ``NNNNN.npy``.
    """

    def __init__(self, viz):
        """Bind the widget to the visualizer object `viz` and reset state."""
        self.viz = viz
        # Default output directory: `_screenshots` one level above this file.
        self.path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', '_screenshots'))
        self.dump_image = False   # A "Save image" request is pending.
        self.dump_gui = False     # A "Save GUI" request is pending.
        self.defer_frames = 0     # Calls to wait before serving a request.
        self.disabled_time = 0    # Countdown during which buttons are disabled.

    def _next_file_id(self):
        """Return the next free numeric file id within `self.path`.

        The id is one more than the largest leading decimal number found
        among the names of regular files in the directory, or 0 if none of
        them start with a digit.
        """
        file_id = 0
        for entry in os.scandir(self.path):
            if entry.is_file():
                match = re.fullmatch(r'(\d+).*', entry.name)
                if match:
                    file_id = max(file_id, int(match.group(1)) + 1)
        return file_id

    def dump_png(self, image):
        """Write `image` to the capture directory as a numbered PNG file.

        `image` must be a ``uint8`` array of shape (height, width, channels).
        A single-channel image is saved as grayscale; otherwise the first
        three channels are saved as RGB. If the visualizer result contains
        ``w``, it is additionally saved under the same number as ``.npy``.

        Failures are not propagated: they are recorded in
        ``viz.result.error`` as a `renderer.CapturedException`.
        """
        viz = self.viz
        try:
            _height, _width, channels = image.shape
            # Explicit check instead of `assert`, which is removed under -O.
            if image.dtype != np.uint8:
                raise TypeError(
                    f'expected uint8 image, got dtype {image.dtype}')
            os.makedirs(self.path, exist_ok=True)
            file_id = self._next_file_id()
            if channels == 1:
                pil_image = PIL.Image.fromarray(image[:, :, 0], 'L')
            else:
                pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB')
            pil_image.save(os.path.join(self.path, f'{file_id:05d}.png'))
            # GUI screenshots can be taken when the result has no `w`;
            # in that case only the PNG is written instead of failing.
            if 'w' in viz.result:
                np.save(os.path.join(
                    self.path, f'{file_id:05d}.npy'), viz.result.w)
        except Exception:
            # Deliberately broad: any failure is surfaced through the
            # result's error field rather than raised out of the widget.
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            viz.result.error = renderer.CapturedException()

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the widget (if `show`) and serve pending capture requests.

        Pressing a button only records the request; the capture itself is
        performed by a later call, once `defer_frames` has counted down.
        """
        viz = self.viz
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Capture')
                imgui.same_line(viz.label_w)

                _changed, self.path = imgui_utils.input_text('##path', self.path, 1024,
                                                             flags=(
                                                                 imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                                                             width=(-1),
                                                             help_text='PATH')
                # Show the full path as a tooltip while hovering the
                # (inactive) text field.
                if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '':
                    imgui.set_tooltip(self.path)
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)):
                    self.dump_image = True
                    self.defer_frames = 2
                    self.disabled_time = 0.5
                imgui.same_line()
                if imgui_utils.button('Save GUI', width=viz.button_w, enabled=(self.disabled_time == 0)):
                    self.dump_gui = True
                    self.defer_frames = 2
                    self.disabled_time = 0.5

        # Count down the button lock-out by the elapsed frame time.
        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        elif self.dump_image:
            # The image may have disappeared since the button was pressed.
            if 'image' in viz.result:
                self.dump_png(viz.result.image)
            self.dump_image = False
        elif self.dump_gui:
            viz.capture_next_frame()
            self.dump_gui = False
        # Save any frame the visualizer has captured in the meantime.
        captured_frame = viz.pop_captured_frame()
        if captured_frame is not None:
            self.dump_png(captured_frame)

# ----------------------------------------------------------------------------