# Spaces: Runtime error (page-status text captured alongside this file; not part of the module)
import os | |
import torch | |
import numpy as np | |
import imgui | |
import dnnlib | |
from gui_utils import imgui_utils | |
# ---------------------------------------------------------------------------- | |
class DragWidget:
    """Control-panel widget holding point-dragging and mask-painting state.

    The widget records handle points and their targets from mouse clicks,
    paints a mask with a circular brush, draws its imgui controls, and
    publishes the resulting state onto ``viz.args`` every time it is called.

    Points and targets are stored as ``[y, x]`` pairs. Clicks alternate
    between the two lists, so ``points[i]`` and ``targets[i]`` form a pair.
    """

    def __init__(self, viz):
        """Create the widget with no points and an all-ones 256x256 mask.

        Args:
            viz: Host visualizer object. ``__call__`` reads ``label_w``,
                ``button_w``, ``font_size``, ``frame_delta`` and ``result``
                from it and writes onto its ``args`` attribute.
        """
        self.viz = viz
        self.point = [-1, -1]       # Most recent clicked position as [y, x].
        self.points = []            # Committed handle points.
        self.targets = []           # Committed target points.
        self.is_point = True        # True: next commit goes to `points`.
        self.last_click = False     # `click` value seen by the previous add_point.
        self.is_drag = False
        self.iteration = 0          # Frames counted while dragging is active.
        self.mode = 'point'         # One of 'point', 'flexible', 'fixed'.
        self.r_mask = 50            # Brush radius used by draw_mask.
        self.show_mask = False
        self.mask = torch.ones(256, 256)
        # Keep width/height in sync with the default mask so draw_mask and the
        # 'Reset mask' button work even if init_mask has not been called yet.
        self.height, self.width = self.mask.shape
        # The next four values are only forwarded to viz.args by __call__;
        # their meaning is defined by whatever consumes viz.args.
        self.lambda_mask = 20
        self.feature_idx = 5
        self.r1 = 3
        self.r2 = 12
        self.path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0

    def action(self, click, down, x, y):
        """Route a mouse event to point placement or mask painting.

        Args:
            click: Forwarded to ``add_point`` when the mode is 'point'.
            down: When truthy in a mask mode, paint at ``(x, y)``.
            x: Horizontal position of the event.
            y: Vertical position of the event.
        """
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Track a click and commit the clicked position once it ends.

        While ``click`` is truthy the position is remembered as ``[y, x]``.
        On the first call where ``click`` is falsy after a truthy one, any
        running drag is stopped and the remembered position is appended to
        ``points`` or ``targets``, alternating between the two.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Record the mask size and replace the mask with ones of shape (h, w)."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a disc of radius ``r_mask`` centred on ``(x, y)`` into the mask.

        In 'flexible' mode the disc is set to 0, in 'fixed' mode to 1; any
        other mode leaves the mask untouched.
        """
        # NOTE(review): linspace(0, n, n) spans 0..n inclusive, so the grid is
        # slightly stretched relative to integer pixel indices. Kept as-is to
        # preserve the existing brush footprint.
        X = torch.linspace(0, self.width, self.width)
        Y = torch.linspace(0, self.height, self.height)
        yy, xx = torch.meshgrid(Y, X)
        circle = (xx - x)**2 + (yy - y)**2 < self.r_mask**2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def _start_drag(self):
        """Begin dragging, discarding a handle point that has no target yet."""
        self.is_drag = True
        if len(self.points) > len(self.targets):
            self.points = self.points[:len(self.targets)]
            # The dangling point is gone, so the next click must start a new
            # pair. Leaving is_point False would file that click as a target
            # and pair it with the wrong handle point.
            self.is_point = True

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the list of handle points."""
        self.points = points

    def reset_point(self):
        """Clear all points and targets and expect a handle point next."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read ``[y, x]`` integer pairs from ``<self.path>_<suffix>.txt``.

        Each non-blank line must hold exactly two whitespace-separated
        integers. If the file cannot be opened or a line is malformed, a
        message is printed and the points parsed so far are returned.

        Returns:
            A list of ``[y, x]`` integer pairs (possibly empty).
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue  # Tolerate blank lines such as a trailing newline.
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except (OSError, ValueError):
            # Best-effort load: report and fall through with what we have.
            print(f'Wrong point file path: {point_path}')
        return points

    def __call__(self, show=True):
        """Draw the controls (when ``show`` is truthy) and publish state.

        Regardless of ``show``, the widget decays ``disabled_time`` by
        ``viz.frame_delta``, counts down ``defer_frames``, advances the step
        counter while dragging, and writes the current drag/mask settings
        onto ``viz.args``.
        """
        viz = self.viz
        reset = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                has_image = 'image' in viz.result

                imgui.text('Drag')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Add point', width=viz.button_w, enabled=has_image):
                    self.mode = 'point'
                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled=has_image):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled=has_image):
                    self._start_drag()
                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled=has_image):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled=has_image):
                    self.mode = 'flexible'
                    self.show_mask = True
                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled=has_image):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled=has_image):
                    self.mask = torch.ones(self.height, self.width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox(
                    'Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int(
                        'Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int(
                        'Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1

        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Hand out shallow copies so later appends here do not alter the lists
        # already published on viz.args.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
# ---------------------------------------------------------------------------- | |