File size: 5,028 Bytes
8683813 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import torch
from annotator.oneformer.detectron2.utils.file_io import PathManager
from .torchscript_patch import freeze_training_mode, patch_instances
# Public API of this module. The legacy alias ``export_torchscript_with_instances``
# (defined further down in this file) is not listed here.
__all__ = ["scripting_with_instances", "dump_torchscript_IR"]
def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for scripting to support it out of the box. This function is made to support scripting
    a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes being "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force the scripting compiler to
       use it when trying to compile ``Instances``.

    After this function returns, the patching is reverted, so users should be able to
    script another model using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only supports models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding type that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data type not defined in detectron2 is not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format

    Raises:
        AssertionError: if ``model.training`` is True.
    """
    # Explicit check instead of a bare ``assert`` so the guard is not stripped
    # when Python runs with ``-O``; AssertionError is kept so existing callers
    # that catch it keep working.
    if model.training:
        raise AssertionError(
            "Currently we only support exporting models in evaluation mode to torchscript"
        )
    # Both context managers come from ``torchscript_patch``; per the docstring
    # above, they patch scripting for the duration of the call and are reverted
    # on exit (TODO confirm exact semantics in torchscript_patch).
    with freeze_training_mode(model), patch_instances(fields):
        scripted_model = torch.jit.script(model)
    return scripted_model
# Backward-compatible alias for the old public name: it is the very same
# function object as ``scripting_with_instances``. It is not listed in ``__all__``.
export_torchscript_with_instances = scripting_with_instances
def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/ScriptFunction in various formats (code, graph,
    inlined graph). Useful for debugging.

    The following files are written into ``dir``:

    * ``model_ts_code.txt``: pretty-printed TorchScript code of the model and, for
      modules, of every submodule (recursively).
    * ``model_ts_IR.txt``: the graph IR.
    * ``model_ts_IR_inlined.txt``: the graph IR with submodule calls inlined.
    * ``model.txt``: ``str(model)``; skipped when ``model`` is a ScriptFunction.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files. A leading ``~`` is expanded and
            the directory is created via ``PathManager.mkdirs``.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        # A TracedModule keeps its compiled module in ``_actual_script_module``;
        # unwrap it so the private ``_c`` accessors below can be reached.
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    def get_code(mod):
        """Return the TorchScript source of ``mod``, or None if none can be found."""
        # Try a few ways to get code using private attributes.
        try:
            # This contains more information than just `mod.code`
            return _get_script_mod(mod)._c.code
        except AttributeError:
            pass
        try:
            return mod.code
        except AttributeError:
            return None

    def dump_code(f, prefix, mod):
        """Write the code of ``mod`` and, recursively, of all its children to ``f``."""
        code = get_code(mod)
        name = prefix or "root model"
        if code is None:
            f.write(f"Could not found code for {name} (type={mod.original_name})\n")
            f.write("\n")
        else:
            f.write(f"\nCode for {name}, type={mod.original_name}:\n")
            f.write(code)
            f.write("\n")
            f.write("-" * 80)
        for child_name, child in mod.named_children():
            dump_code(f, prefix + "." + child_name, child)

    def _get_graph(model):
        try:
            # Recursively dump IR of all modules
            return _get_script_mod(model)._c.dump_to_str(True, False, False)
        except AttributeError:
            return model.graph.str()

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:
        if isinstance(model, torch.jit.ScriptFunction):
            code = get_code(model)
            # get_code() may return None, and f.write(None) raises TypeError;
            # record the failure (as dump_code does for modules) and keep going
            # so the remaining dump files are still produced.
            if code is None:
                f.write("Could not find code for the root ScriptFunction\n")
            else:
                f.write(code)
        else:
            dump_code(f, "", model)

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))
|