# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch.distributed as dist

from .comm import (all_to_all, gather_forward_split_backward,
                   split_forward_gather_backward)
from .setup_distributed import (get_inner_sequence_parallel_group,
                                get_inner_sequence_parallel_world_size,
                                get_sequence_parallel_group,
                                get_sequence_parallel_world_size,
                                init_inner_sequence_parallel,
                                is_inner_sequence_parallel_initialized)


def pre_process_for_sequence_parallel_attn(query_states,
                                           key_states,
                                           value_states,
                                           scatter_dim=2,
                                           gather_dim=1):
    """All-to-all q/k/v across the sequence parallel group so that each rank
    keeps the full sequence but only a slice of the attention heads:
    (b, s/sp, h, d) -> (b, s, insp*h/sp, d). ``insp`` (the inner sequence
    parallel size) exceeds 1 only when ``h`` is not divisible by ``sp``; in
    that case the head dim is split over the inner group for the all-to-all
    and gathered back afterwards.
    """
    b, s_div_sp, h, d = query_states.shape
    sp = get_sequence_parallel_world_size()

    if not is_inner_sequence_parallel_initialized():
        insp = sp // math.gcd(h, sp)
        init_inner_sequence_parallel(insp)
    else:
        insp = get_inner_sequence_parallel_world_size()

    def pre_process_for_inner_sp(q, k, v):
        # The reshapes below assume this exact layout; reject anything else.
        if scatter_dim != 2 or gather_dim != 1:
            raise NotImplementedError(
                'Currently only `scatter_dim == 2` and `gather_dim == 1` '
                f'are supported. But got scatter_dim = {scatter_dim} and '
                f'gather_dim = {gather_dim}.')

        # (b, s_div_sp, h, d) ->
        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
        # (b, s_div_sp, insp*h, d/insp)
        q = q.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)
        k = k.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)
        v = v.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)

        return q, k, v

    def post_process_for_inner_sp(q, k, v):
        # (b, s, insp*h/sp, d/insp) -> (b, s, insp*h/sp, d)
        q = gather_forward_split_backward(q, -1,
                                          get_inner_sequence_parallel_group())
        k = gather_forward_split_backward(k, -1,
                                          get_inner_sequence_parallel_group())
        v = gather_forward_split_backward(v, -1,
                                          get_inner_sequence_parallel_group())

        return q, k, v

    assert (h * insp) % sp == 0, \
        ('The number of attention heads should be divisible by '
         '(sequence_parallel_world_size // sequence_parallel_inner_world_size)'
         f'. But got n_head = {h}, sequence_parallel_world_size = '
         f'{sp} and sequence_parallel_inner_world_size = {insp}.')

    if insp > 1:
        query_states, key_states, value_states = pre_process_for_inner_sp(
            query_states, key_states, value_states)

    # (b, s_div_sp, insp*h, d/insp) -> (b, s, insp*h/sp, d/insp)
    sequence_parallel_group = get_sequence_parallel_group()
    query_states = all_to_all(
        query_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)
    key_states = all_to_all(
        key_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)
    value_states = all_to_all(
        value_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)

    if insp > 1:
        query_states, key_states, value_states = post_process_for_inner_sp(
            query_states, key_states, value_states)

    return query_states, key_states, value_states


def post_process_for_sequence_parallel_attn(attn_output,
                                            scatter_dim=1,
                                            gather_dim=2):
    """Reverse the layout change made by
    ``pre_process_for_sequence_parallel_attn``:
    (b, s, insp*h/sp, d) -> (b, s/sp, h, d).
    """
    sp = get_sequence_parallel_world_size()
    insp = get_inner_sequence_parallel_world_size()
    b, s, h_mul_insp_div_sp, d = attn_output.shape
    h = h_mul_insp_div_sp * sp // insp
    s_div_sp = s // sp

    if insp > 1:
        # (b, s, insp*h/sp, d) -> (b, s, insp*h/sp, d/insp)
        attn_output = split_forward_gather_backward(
            attn_output, -1, get_inner_sequence_parallel_group())

    # (b, s, insp*h/sp, d/insp) -> (b, s_div_sp, insp*h, d/insp)
    sequence_parallel_group = get_sequence_parallel_group()
    output = all_to_all(
        attn_output,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)

    if insp > 1:
        # (b, s_div_sp, insp*h, d/insp) ->
        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
        # (b, s_div_sp, h, d)
        output = output.view(b, s_div_sp, sp // insp, insp, h * insp // sp,
                             d // insp).transpose(3, 4).reshape(
                                 b, s_div_sp, h, d)

    return output


def sequence_parallel_wrapper(local_attn):
    """Wrap ``local_attn`` so that, when distributed training runs with a
    sequence parallel world size larger than 1, q/k/v are redistributed
    before the call and the attention output layout is restored afterwards.
    """

    def sequence_parallel_attn(query_states, key_states, value_states, *args,
                               **kwargs):
        training = kwargs.pop('training', True)
        enable_sequence_parallel = (
            dist.is_initialized() and get_sequence_parallel_world_size() > 1
            and training)
        if enable_sequence_parallel:
            query_states, key_states, value_states = \
                pre_process_for_sequence_parallel_attn(
                    query_states, key_states, value_states)

        out = local_attn(query_states, key_states, value_states, *args,
                         **kwargs)

        if enable_sequence_parallel:
            out = post_process_for_sequence_parallel_attn(out).contiguous()

        return out

    return sequence_parallel_attn
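

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). ``dummy_local_attn``
# is a hypothetical stand-in for a real attention kernel, built here on
# ``torch.nn.functional.scaled_dot_product_attention``. With sequence
# parallelism enabled, the wrapper performs the all-to-all before the call so
# inputs arrive with the full sequence and a per-rank slice of the heads, and
# restores the (b, s/sp, h, d) layout afterwards.
#
#     import torch.nn.functional as F
#
#     @sequence_parallel_wrapper
#     def dummy_local_attn(query_states, key_states, value_states):
#         # Inputs arrive as (b, s, heads_per_rank, d); SDPA expects
#         # (b, heads, s, d), so transpose in and out.
#         q, k, v = (x.transpose(1, 2)
#                    for x in (query_states, key_states, value_states))
#         out = F.scaled_dot_product_attention(q, k, v)
#         return out.transpose(1, 2)  # back to (b, s, heads_per_rank, d)
#
#     attn_out = dummy_local_attn(q, k, v, training=True)
# ---------------------------------------------------------------------------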