File size: 11,526 Bytes
71d94dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
# Copyright (c) Facebook, Inc. and its affiliates.

import os
import sys
import tempfile
from contextlib import ExitStack, contextmanager
from copy import deepcopy
from unittest import mock
import torch
from torch import nn

# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
import detectron2  # noqa F401
from detectron2.structures import Boxes, Instances
from detectron2.utils.env import _import_file

# Incremented each time a scripted Instances class is generated, so that every
# generated class (and the temporary module that holds it) gets a unique name.
_counter = 0


def _clear_jit_cache():
    """
    Drop the results TorchScript has cached from previous compilations,
    both for free functions and for modules.
    """
    from torch.jit._state import _jit_caching_layer
    from torch.jit._recursive import concrete_type_store

    # cache used for free functions
    _jit_caching_layer.clear()
    # cache used for modules
    concrete_type_store.type_store.clear()


def _add_instances_conversion_methods(newInstances):
    """
    Attach a `from_instances` converter to the scripted Instances class.

    Args:
        newInstances: the generated class; it is modified in place.
    """
    scripted_cls_name = newInstances.__name__

    @torch.jit.unused
    def from_instances(instances: Instances):
        """
        Build a scripted Instances object that holds deep copies of
        every field of the original `instances`.
        """
        source_fields = instances.get_fields()
        converted = newInstances(instances.image_size)
        for name in source_fields:
            # the generated class stores each field in a "_<name>" attribute
            assert hasattr(
                converted, f"_{name}"
            ), f"No attribute named {name} in {scripted_cls_name}"
            setattr(converted, name, deepcopy(source_fields[name]))
        return converted

    newInstances.from_instances = from_instances


@contextmanager
def patch_instances(fields):
    """
    A contextmanager, under which the Instances class in detectron2 is replaced
    by a statically-typed scriptable class, defined by `fields`.
    See more in `scripting_with_instances`.

    Args:
        fields (dict[name: type]): the fields of the generated class; forwarded
            to `_gen_instance_module`.

    Yields:
        the generated class, with a `from_instances` method attached.
    """

    with tempfile.TemporaryDirectory(prefix="detectron2") as tmp_dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=tmp_dir, delete=False
    ) as f:
        # Bound before the `try` so that the cleanup below can tell whether the
        # generated module was imported. Otherwise a failure before the import
        # would raise NameError in `finally` and hide the real error.
        module = None
        try:
            # Objects that use Instances should not reuse previously-compiled
            # results in cache, because `Instances` could be a new class each time.
            _clear_jit_cache()

            cls_name, s = _gen_instance_module(fields)
            f.write(s)
            f.flush()
            # the file has to be fully written and closed before it is imported
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            _ = torch.jit.script(new_instances)
            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)

            _add_instances_conversion_methods(new_instances)
            yield new_instances
        finally:
            # Remove each marker independently: one of them being absent must
            # not prevent the removal of the other.
            for attr in ("__torch_script_class__", "_jit_override_qualname"):
                try:
                    delattr(Instances, attr)
                except AttributeError:
                    pass
            if module is not None:
                sys.modules.pop(module.__name__, None)


def _gen_instance_class(fields):
    """
    Generate the source code of a statically-typed class that mimics `Instances`.

    The generated class stores every field in an `Optional` attribute named
    `_<field>`, exposes it through a property, and implements `__len__`, `has`,
    `to`, `__getitem__`, `cat` and `get_fields`.

    Args:
        fields (dict[name: type])

    Returns:
        tuple[str, str]: the name of the generated class and its source code.
        The name embeds the module-level `_counter`, which is incremented on
        every call.
    """

    class _FieldType:
        def __init__(self, name, type_):
            assert isinstance(name, str), f"Field name must be str, got {name}"
            self.name = name
            self.type_ = type_
            # fully-qualified name, used as the type annotation in generated code
            self.annotation = f"{type_.__module__}.{type_.__name__}"

    fields = [_FieldType(k, v) for k, v in fields.items()]

    def indent(level, s):
        return " " * 4 * level + s

    lines = []

    global _counter
    _counter += 1

    cls_name = "ScriptedInstances{}".format(_counter)

    field_names = tuple(x.name for x in fields)
    extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int], {extra_args}):
        self.image_size = image_size
        self._field_names = {field_names}
"""
    )

    for f in fields:
        lines.append(
            indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
        )

    for f in fields:
        lines.append(
            f"""
    @property
    def {f.name}(self) -> {f.annotation}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{f.name}
        assert t is not None, "{f.name} is None and cannot be accessed!"
        return t

    @{f.name}.setter
    def {f.name}(self, value: {f.annotation}) -> None:
        self._{f.name} = value
"""
        )

    # support method `__len__`
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support method `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for f in fields:
        lines.append(
            f"""
        if name == "{f.name}":
            return self._{f.name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support method `to`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def to(self, device: torch.device) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        if hasattr(f.type_, "to"):
            lines.append(
                f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t.to(device)
"""
            )
        else:
            # For now, ignore fields that cannot be moved to devices.
            # Maybe can support other tensor-like classes (e.g. __torch_function__)
            pass
    lines.append(
        """
        return ret
"""
    )

    # support method `getitem`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def __getitem__(self, item) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t[item]
"""
        )
    lines.append(
        """
        return ret
"""
    )

    # support method `cat`
    # this version does not contain checks that all instances have same size and fields
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            values: List[{f.annotation}] = [x.{f.name} for x in instances]
            if torch.jit.isinstance(t, torch.Tensor):
                ret._{f.name} = torch.cat(values, dim=0)
            else:
                ret._{f.name} = t.cat(values)
"""
        )
    lines.append(
        """
        return ret"""
    )

    # support method `get_fields()`
    lines.append(
        """
    def get_fields(self) -> Dict[str, Tensor]:
        ret = {}
    """
    )
    for f in fields:
        # `stmt` is a complete statement. An `assert` cannot be the right-hand
        # side of an assignment, so the unsupported-type case must not be
        # spliced after `ret[...] =` (that would generate a SyntaxError).
        if f.type_ == Boxes:
            stmt = f'ret["{f.name}"] = t.tensor'
        elif f.type_ == torch.Tensor:
            stmt = f'ret["{f.name}"] = t'
        else:
            stmt = f'assert False, "unsupported type {str(f.type_)}"'
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            {stmt}
        """
        )
    lines.append(
        """
        return ret"""
    )
    return cls_name, os.linesep.join(lines)


def _gen_instance_module(fields):
    """
    Produce the full source of a module that defines the scripted Instances class.

    Args:
        fields: forwarded to `_gen_instance_class`.

    Returns:
        tuple[str, str]: the generated class name, and the module source
        (a fixed import header followed by the class definition).
    """
    # TODO: find a more automatic way to enable import of other classes
    header = """
from copy import deepcopy
import torch
from torch import Tensor
import typing
from typing import *

import detectron2
from detectron2.structures import Boxes, Instances

"""

    cls_name, cls_source = _gen_instance_class(fields)
    return cls_name, header + cls_source


def _import(path):
    """
    Import the file at `path` via `_import_file` and return the result.
    The module name is this module's name followed by the current `_counter`.
    """
    this_module_name = sys.modules[__name__].__name__
    module_name = f"{this_module_name}{_counter}"
    return _import_file(module_name, path, make_importable=True)


@contextmanager
def patch_builtin_len(modules=()):
    """
    Patch the builtin len() function of a few detectron2 modules
    to use __len__ instead, because __len__ does not convert values to
    integers and therefore is friendly to tracing.

    Args:
        modules (list[str]): names of extra modules to patch len(), in
            addition to those in detectron2.
    """

    def _call_dunder_len(obj):
        return obj.__len__()

    with ExitStack() as stack:
        targets = [
            "detectron2.modeling.roi_heads.fast_rcnn",
            "detectron2.modeling.roi_heads.mask_head",
            "detectron2.modeling.roi_heads.keypoint_head",
        ]
        targets.extend(modules)

        # install every patch first; ExitStack undoes all of them on exit
        patched = []
        for target in targets:
            patched.append(stack.enter_context(mock.patch(target + ".len")))
        for patched_len in patched:
            patched_len.side_effect = _call_dunder_len
        yield


def patch_nonscriptable_classes():
    """
    Apply patches on a few nonscriptable detectron2 classes.
    Should not have side-effects on eager usage.
    """
    # __prepare_scriptable__ can also be added to models for easier maintenance.
    # But it complicates the clean model code.

    from detectron2.modeling.backbone import ResNet, FPN

    # Due to https://github.com/pytorch/pytorch/issues/36061,
    # we change backbone to use ModuleList for scripting.
    # (note: this changes param names in state_dict)

    def prepare_resnet(self):
        # work on a copy: the stages are regrouped in a ModuleList and the
        # per-stage attributes are removed from the copy
        scriptable = deepcopy(self)
        scriptable.stages = nn.ModuleList(scriptable.stages)
        for stage_name in self.stage_names:
            delattr(scriptable, stage_name)
        return scriptable

    def prepare_fpn(self):
        # work on a copy: the convs are regrouped in ModuleLists and the
        # children whose name starts with "fpn_" are removed from the copy
        scriptable = deepcopy(self)
        scriptable.lateral_convs = nn.ModuleList(scriptable.lateral_convs)
        scriptable.output_convs = nn.ModuleList(scriptable.output_convs)
        fpn_children = [name for name, _ in self.named_children() if name.startswith("fpn_")]
        for child_name in fpn_children:
            delattr(scriptable, child_name)
        return scriptable

    ResNet.__prepare_scriptable__ = prepare_resnet
    FPN.__prepare_scriptable__ = prepare_fpn

    # Annotate some attributes to be constants for the purpose of scripting,
    # even though they are not constants in eager mode.
    from detectron2.modeling.roi_heads import StandardROIHeads

    if hasattr(StandardROIHeads, "__annotations__"):
        # copy first to avoid editing annotations of base class
        annotations = deepcopy(StandardROIHeads.__annotations__)
        for attr_name in ("mask_on", "keypoint_on"):
            annotations[attr_name] = torch.jit.Final[bool]
        StandardROIHeads.__annotations__ = annotations


# Applied unconditionally when this module is imported.
# These patches are not supposed to have side-effects.
patch_nonscriptable_classes()


@contextmanager
def freeze_training_mode(model):
    """
    A context manager that annotates the "training" attribute of every submodule
    to constant, so that the training codepath in these modules can be
    meta-compiled away. Upon exiting, the annotations are reverted.

    The annotations are reverted even if the body of the `with` statement raises.

    Args:
        model: an object whose `modules()` method returns the submodules to
            process. Classes that define `__constants__` are skipped.
    """
    classes = {type(x) for x in model.modules()}
    # __constants__ is the old way to annotate constants and not compatible
    # with __annotations__ .
    classes = {x for x in classes if not hasattr(x, "__constants__")}
    for cls in classes:
        cls.__annotations__["training"] = torch.jit.Final[bool]
    try:
        yield
    finally:
        # The annotations live on the classes, so without this `finally` an
        # exception in the caller's block would leave them frozen for good.
        for cls in classes:
            cls.__annotations__["training"] = bool