# Copyright (c) OpenMMLab. All rights reserved.
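"""Sequence parallel pre/post-processing around the local attention call.

Before attention, q/k/v are exchanged across the sequence parallel group with
all-to-all so that each rank holds the full sequence but only a slice of the
attention heads; afterwards the output is redistributed back to the
sequence-sharded layout. When the number of heads is not divisible by the
sequence parallel world size, an "inner" sequence parallel group additionally
splits the head dimension.
"""
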
import math

import torch.distributed as dist

from .comm import (all_to_all, gather_forward_split_backward,
                   split_forward_gather_backward)
from .setup_distributed import (get_inner_sequence_parallel_group,
                                get_inner_sequence_parallel_world_size,
                                get_sequence_parallel_group,
                                get_sequence_parallel_world_size,
                                init_inner_sequence_parallel,
                                is_inner_sequence_parallel_initialized)


def pre_process_for_sequence_parallel_attn(query_states,
                                           key_states,
                                           value_states,
                                           scatter_dim=2,
                                           gather_dim=1):
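    """Redistribute q/k/v from sequence-sharded to head-sharded layout.

    Each rank starts with a (b, s // sp, h, d) shard of the sequence. An
    all-to-all over the sequence parallel group gathers the full sequence
    while scattering heads, so the local attention sees every token but only
    (h * insp) // sp heads per rank. When the head count is not divisible by
    sp, the head dim d is first split across the inner sequence parallel
    group of size insp (chosen so that (h * insp) % sp == 0) and gathered
    back along the last dim after the all-to-all.
    """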
    b, s_div_sp, h, d = query_states.shape
    sp = get_sequence_parallel_world_size()
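    # insp is the smallest group size satisfying (h * insp) % sp == 0, so the
    # (possibly head-dim-split) heads can be scattered evenly over sp ranks.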
    if not is_inner_sequence_parallel_initialized():
        insp = sp // math.gcd(h, sp)
        init_inner_sequence_parallel(insp)
    else:
        insp = get_inner_sequence_parallel_world_size()

    def pre_process_for_inner_sp(q, k, v):
        if scatter_dim != 2 or gather_dim != 1:
            raise NotImplementedError(
                'Currently only `scatter_dim == 2` and `gather_dim == 1` '
                f'is supported. But got scatter_dim = {scatter_dim} and '
                f'gather_dim = {gather_dim}.')

        # (b, s_div_sp, h, d) ->
        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
        # (b, s_div_sp, insp*h, d/insp)
        q = q.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)
        k = k.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)
        v = v.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)

        return q, k, v

    def post_process_for_inner_sp(q, k, v):
        # (b, s, insp*h/sp, d/insp) -> (b, s, insp*h/sp, d)
        q = gather_forward_split_backward(q, -1,
                                          get_inner_sequence_parallel_group())
        k = gather_forward_split_backward(k, -1,
                                          get_inner_sequence_parallel_group())
        v = gather_forward_split_backward(v, -1,
                                          get_inner_sequence_parallel_group())
        return q, k, v

    assert (h * insp) % sp == 0, \
        ('The number of attention heads should be divisible by '
         '(sequence_parallel_world_size // sequence_parallel_inner_world_size)'
         f'. But got n_head = {h}, sequence_parallel_world_size = '
         f'{sp} and sequence_parallel_inner_world_size = {insp}.')

    if insp > 1:
        query_states, key_states, value_states = pre_process_for_inner_sp(
            query_states, key_states, value_states)

    # (b, s_div_sp, insp*h, d/insp) -> (b, s, insp*h/sp, d/insp)
    sequence_parallel_group = get_sequence_parallel_group()
    query_states = all_to_all(
        query_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)
    key_states = all_to_all(
        key_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)
    value_states = all_to_all(
        value_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)

    if insp > 1:
        query_states, key_states, value_states = post_process_for_inner_sp(
            query_states, key_states, value_states)

    return query_states, key_states, value_states


def post_process_for_sequence_parallel_attn(attn_output,
                                            scatter_dim=1,
                                            gather_dim=2):
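    """Reverse `pre_process_for_sequence_parallel_attn`.

    The head-sharded attention output of shape (b, s, insp * h // sp, d) is
    redistributed back to the sequence-sharded layout (b, s // sp, h, d) via
    all-to-all over the sequence parallel group, splitting the head dim over
    the inner group first when insp > 1.
    """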
    sp = get_sequence_parallel_world_size()
    insp = get_inner_sequence_parallel_world_size()
    b, s, h_mul_insp_div_sp, d = attn_output.shape
    h = h_mul_insp_div_sp * sp // insp
    s_div_sp = s // sp

    if insp > 1:
        # (b, s, insp*h/sp, d) -> (b, s, insp*h/sp, d/insp)
        attn_output = split_forward_gather_backward(
            attn_output, -1, get_inner_sequence_parallel_group())

    # (b, s, insp*h/sp, d/insp) -> (b, s_div_sp, insp*h, d/insp)
    sequence_parallel_group = get_sequence_parallel_group()
    output = all_to_all(
        attn_output,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)

    if insp > 1:
        # (b, s_div_sp, insp*h, d/insp) ->
        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
        # (b, s_div_sp, h, d)
        output = output.view(b, s_div_sp, sp // insp, insp, h * insp // sp,
                             d // insp).transpose(3, 4).reshape(
                                 b, s_div_sp, h, d)

    return output


def sequence_parallel_wrapper(local_attn):
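    """Wrap `local_attn` so it can run under sequence parallelism.

    The returned callable is a drop-in replacement for `local_attn`: when
    torch.distributed is initialized, the sequence parallel world size is
    greater than 1 and `training=True` (popped from kwargs), q/k/v are
    redistributed with `pre_process_for_sequence_parallel_attn` before the
    call and the output is restored with
    `post_process_for_sequence_parallel_attn` afterwards; otherwise
    `local_attn` is called unchanged.

    Illustrative usage (`local_flash_attn` stands for any local attention
    callable with a `(q, k, v, ...)` signature):

        attn = sequence_parallel_wrapper(local_flash_attn)
        out = attn(q, k, v, training=True)
    """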

    def sequence_parallel_attn(query_states, key_states, value_states, *args,
                               **kwargs):
        training = kwargs.pop('training', True)
        enable_sequence_parallel = (
            dist.is_initialized() and get_sequence_parallel_world_size() > 1
            and training)
        if enable_sequence_parallel:
            query_states, key_states, value_states = \
                pre_process_for_sequence_parallel_attn(
                    query_states, key_states, value_states)

        out = local_attn(query_states, key_states, value_states, *args,
                         **kwargs)

        if enable_sequence_parallel:
            out = post_process_for_sequence_parallel_attn(out).contiguous()

        return out

    return sequence_parallel_attn