File size: 8,975 Bytes
d49f7bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations  # so we can refer to class Type inside class
import numpy as np
import numpy.typing as npt
from animated_drawings.model.vectors import Vectors
from animated_drawings.model.quaternions import Quaternions
import logging
from typing import Union, Optional, List, Tuple


class Transform():
    """Base class from which all other scene objects descend.

    A Transform keeps separate 4x4 float32 translation, rotation and scale
    matrices. They are combined into a local matrix (translate @ rotate @ scale),
    and the local matrix is left-multiplied by the parent's world matrix to give
    this node's world matrix. Setters only flip ``dirty_bit``; the matrices are
    recomputed lazily by ``update_transforms``.
    """

    def __init__(self,
                 parent: Optional[Transform] = None,
                 name: Optional[str] = None,
                 children: Optional[List[Transform]] = None,
                 offset: Union[npt.NDArray[np.float32], Vectors, None] = None,
                 **kwargs
                 ) -> None:
        """
        parent: transform stored as this node's parent. NOTE: only this node's
            parent pointer is set; this node is NOT appended to parent's children,
            so the parent will not recurse into it. Use parent.add_child(node)
            for a two-way link.
        name: optional identifier, searched by get_transform_by_name.
        children: transforms to attach via add_child. Defaults to none.
        offset: optional initial translation, applied via offset().
        kwargs: forwarded unchanged to super().__init__().
        """
        super().__init__(**kwargs)

        self._parent: Optional[Transform] = parent

        # Fresh list per instance; `children` defaults to None rather than a
        # shared mutable [] default.
        self._children: List[Transform] = []
        if children is not None:
            for child in children:
                self.add_child(child)

        self.name: Optional[str] = name

        self._translate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._rotate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._scale_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)

        if offset is not None:
            self.offset(offset)

        self._local_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._world_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self.dirty_bit: bool = True  # are world/local transforms stale?

    def update_transforms(self, parent_dirty_bit: bool = False, recurse_on_children: bool = True, update_ancestors: bool = False) -> None:
        """
        Recompute stale transforms.

        If update_ancestors is true, first walk up to the root ancestor and call
        update_transforms() on it. That call recurses over every node reachable
        through children lists.
        If own dirty bit is set, recompute the local matrix.
        If own or parent's dirty bit is set, recompute the world matrix.
        Unless recurse_on_children is false, recurse on children, passing along
        whether this node's world matrix was recomputed.
        Finally, clear own dirty bit.
        """
        if update_ancestors:
            ancestor, ancestor_parent = self, self.get_parent()
            while ancestor_parent is not None:
                ancestor, ancestor_parent = ancestor_parent, ancestor_parent.get_parent()
            ancestor.update_transforms()
            # The per-node steps below still run. For nodes reachable from the
            # root they find nothing stale. They still matter for a node whose
            # parent pointer was set without add_child (e.g. the `parent=` ctor arg).

        if self.dirty_bit:
            self.compute_local_transform()
        if self.dirty_bit or parent_dirty_bit:
            self.compute_world_transform()

        if recurse_on_children:
            for c in self.get_children():
                c.update_transforms(self.dirty_bit or parent_dirty_bit)

        self.dirty_bit = False

    def compute_local_transform(self) -> None:
        """Set the local matrix to translate @ rotate @ scale."""
        self._local_transform = self._translate_m @ self._rotate_m @ self._scale_m

    def compute_world_transform(self) -> None:
        """Set the world matrix to parent_world @ local, or to local for a root.

        Reads the parent's cached world matrix as-is and does not refresh the
        parent itself.
        """
        if self._parent is not None:
            self._world_transform = self._parent._world_transform @ self._local_transform
        else:
            self._world_transform = self._local_transform

    def get_world_transform(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Return a copy of the transform's world matrix.

        If update_ancestors is true, first call update_transforms(update_ancestors=True)
        so that the world matrix is current.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform)

    def set_scale(self, scale: float) -> None:
        """Set a uniform scale factor on all three axes."""
        self._scale_m[:-1, :-1] = scale * np.identity(3, dtype=np.float32)
        self.dirty_bit = True

    def set_position(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """Set the absolute values of the translational elements of the local transform.

        pos: Vectors, or ndarray of shape (3,) or (1, 3).
        Raises AssertionError (after logging at CRITICAL) for any other shape.
        """
        if isinstance(pos, Vectors):
            pos = pos.vs

        if pos.shape == (1, 3):
            pos = np.squeeze(pos)
        elif pos.shape != (3,):
            msg = f'bad vector dim passed to set_position. Found: {pos.shape}'
            logging.critical(msg)
            # Explicit raise (not `assert False`) so the check survives `python -O`.
            raise AssertionError(msg)

        self._translate_m[:-1, -1] = pos
        self.dirty_bit = True

    def get_local_position(self) -> npt.NDArray[np.float32]:
        """Ensure the local transform is up to date and return a copy of its xyz translation."""
        if self.dirty_bit:
            # dirty_bit is intentionally left set: the world matrix is still stale.
            self.compute_local_transform()
        return np.copy(self._local_transform[:-1, -1])

    def get_world_position(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Return a copy of the world-space xyz coordinates.

        If update_ancestors is true, first update ancestor transforms so that
        the world matrix is current before reading it.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)

        return np.copy(self._world_transform[:-1, -1])

    def offset(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """Add pos to the current translation. For Vectors, only the first row is used."""

        if isinstance(pos, Vectors):
            pos = pos.vs[0]
        assert isinstance(pos, np.ndarray)  # type narrowing for static checkers

        self.set_position(self._translate_m[:-1, -1] + pos)

    def look_at(self, fwd_: Union[npt.NDArray[np.float32], Vectors, None]) -> None:
        """Given a forward vector, rotate the transform to face that position.

        Builds right/up/fwd basis vectors from fwd_ and the helper vector (0, 1, 0),
        normalises them, and stores them as columns 0/1/2 of the rotation matrix.
        If fwd_ is None, this transform's own world position is used as the
        forward vector. Raises AssertionError unless the wrapped vector has
        shape (1, 3).
        """
        if fwd_ is None:
            fwd_ = Vectors(self.get_world_position())
        elif isinstance(fwd_, np.ndarray):
            fwd_ = Vectors(fwd_)
        fwd: Vectors = fwd_.copy()  # norming will change the vector

        if fwd.vs.shape != (1, 3):
            msg = f'look_at fwd_ vector must have shape [1,3]. Found: {fwd.vs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)

        tmp: Vectors = Vectors([0.0, 1.0, 0.0])

        # If fwd is (anti)parallel to tmp, the cross product below is zero and the
        # basis collapses (normalising would produce NaNs), so nudge tmp.
        # Compare fwd's *direction*, so that scaled vectors such as (0, 5, 0)
        # are caught as well as exact unit vectors.
        fwd_len = np.linalg.norm(fwd.vs)
        fwd_dir = fwd.vs / fwd_len if fwd_len > 0 else fwd.vs
        if np.allclose(fwd_dir, tmp.vs) or np.allclose(fwd_dir, -tmp.vs):
            tmp.vs[0] += 0.001

        right: Vectors = tmp.cross(fwd)
        up: Vectors = fwd.cross(right)

        fwd.norm()
        right.norm()
        up.norm()

        rotate_m = np.identity(4, dtype=np.float32)
        rotate_m[:-1, 0] = np.squeeze(right.vs)
        rotate_m[:-1, 1] = np.squeeze(up.vs)
        rotate_m[:-1, 2] = np.squeeze(fwd.vs)

        self._rotate_m = rotate_m
        self.dirty_bit = True

    def get_right_up_fwd_vectors(self) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """Return the xyz parts of columns 0, 1 and 2 of the inverse world matrix,
        as (right, up, fwd). Ancestors are updated first via get_world_transform()."""
        inverted: npt.NDArray[np.float32] = np.linalg.inv(self.get_world_transform())
        right: npt.NDArray[np.float32] = inverted[:-1, 0]
        up: npt.NDArray[np.float32] = inverted[:-1, 1]
        fwd: npt.NDArray[np.float32] = inverted[:-1, 2]

        return right, up, fwd

    def set_rotation(self, q: Quaternions) -> None:
        """Replace the rotation with the one described by a single quaternion (q.qs shape (1, 4))."""
        if q.qs.shape != (1, 4):
            msg = f'set_rotation q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)
        self._rotate_m = q.to_rotation_matrix()
        self.dirty_bit = True

    def rotation_offset(self, q: Quaternions) -> None:
        """Compose q with the current rotation as q * current (q.qs shape (1, 4))."""
        if q.qs.shape != (1, 4):
            msg = f'rotation_offset q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)
        self._rotate_m = (q * Quaternions.from_rotation_matrix(self._rotate_m)).to_rotation_matrix()
        self.dirty_bit = True

    def add_child(self, child: Transform) -> None:
        """Append child and set its parent to self.

        Does not remove child from a previous parent's children list.
        """
        self._children.append(child)
        child.set_parent(self)

    def get_children(self) -> List[Transform]:
        """Return the internal children list (not a copy)."""
        return self._children

    def set_parent(self, parent: Transform) -> None:
        """Set the parent pointer only (see add_child for a two-way link) and mark dirty."""
        self._parent = parent
        self.dirty_bit = True

    def get_parent(self) -> Optional[Transform]:
        """Return the parent transform, or None for a root."""
        return self._parent

    def get_transform_by_name(self, name: str) -> Optional[Transform]:
        """Depth-first (pre-order) search of self and descendants for a matching name.

        Return the first match, or None if nothing matches.
        """

        # are we match?
        if self.name == name:
            return self

        # recurse to check if a child is match
        for c in self.get_children():
            transform_or_none = c.get_transform_by_name(name)
            if transform_or_none is not None:  # if we found it
                return transform_or_none

        # no match
        return None

    def draw(self, recurse: bool = True, **kwargs) -> None:
        """Draw this transform via _draw, then, if recurse, draw each child (which recurses further)."""
        self._draw(**kwargs)

        if recurse:
            for child in self.get_children():
                child.draw(**kwargs)

    def _draw(self, **kwargs) -> None:
        """Transforms default to not being drawn. Subclasses must implement how they appear"""