# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch.distributed as dist

from .comm import (all_to_all, gather_forward_split_backward,
                   split_forward_gather_backward)
from .setup_distributed import (get_inner_sequence_parallel_group,
                                get_inner_sequence_parallel_world_size,
                                get_sequence_parallel_group,
                                get_sequence_parallel_world_size,
                                init_inner_sequence_parallel,
                                is_inner_sequence_parallel_initialized)


def pre_process_for_sequence_parallel_attn(query_states,
                                           key_states,
                                           value_states,
                                           scatter_dim=2,
                                           gather_dim=1):
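    """Rearrange q/k/v for sequence parallel attention.

    Each input is shaped ``(b, s // sp, h, d)``, where ``sp`` is the sequence
    parallel world size. An all-to-all over the sequence parallel group
    scatters the head dim and gathers the sequence dim, so every rank attends
    over the full sequence with a subset of heads. When ``h`` is not divisible
    by ``sp``, an inner sequence parallel group of size ``insp`` additionally
    splits the head dim ``d`` to keep the all-to-all balanced, and the outputs
    are shaped ``(b, s, insp * h // sp, d)``.
    """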
    b, s_div_sp, h, d = query_states.shape
    sp = get_sequence_parallel_world_size()
    if not is_inner_sequence_parallel_initialized():
        # Smallest inner sp size such that (h * insp) is divisible by sp,
        # e.g. h = 6, sp = 4 -> insp = 2.
        insp = sp // math.gcd(h, sp)
        init_inner_sequence_parallel(insp)
    else:
        insp = get_inner_sequence_parallel_world_size()

    def pre_process_for_inner_sp(q, k, v):
        if scatter_dim != 2 or gather_dim != 1:
            raise NotImplementedError(
                'Currently only `scatter_dim == 2` and `gather_dim == 1` '
                f'is supported. But got scatter_dim = {scatter_dim} and '
                f'gather_dim = {gather_dim}.')

        # (b, s_div_sp, h, d) ->
        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
        # (b, s_div_sp, insp*h, d/insp)
        q = q.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)
        k = k.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)
        v = v.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
                   d // insp).transpose(3, 4).flatten(2, 4)

        return q, k, v

    def post_process_for_inner_sp(q, k, v):
        # (b, s, insp*h/sp, d/insp) -> (b, s, insp*h/sp, d)
        q = gather_forward_split_backward(q, -1,
                                          get_inner_sequence_parallel_group())
        k = gather_forward_split_backward(k, -1,
                                          get_inner_sequence_parallel_group())
        v = gather_forward_split_backward(v, -1,
                                          get_inner_sequence_parallel_group())
        return q, k, v

    assert (h * insp) % sp == 0, \
        ('The number of attention heads should be divisible by '
         '(sequence_parallel_world_size // sequence_parallel_inner_world_size)'
         f'. But got n_head = {h}, sequence_parallel_world_size = '
         f'{sp} and sequence_parallel_inner_world_size = {insp}.')

    if insp > 1:
        query_states, key_states, value_states = pre_process_for_inner_sp(
            query_states, key_states, value_states)

    # (b, s_div_sp, insp*h, d/insp) -> (b, s, insp*h/sp, d/insp)
    sequence_parallel_group = get_sequence_parallel_group()
    query_states = all_to_all(
        query_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)
    key_states = all_to_all(
        key_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)
    value_states = all_to_all(
        value_states,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)

    if insp > 1:
        query_states, key_states, value_states = post_process_for_inner_sp(
            query_states, key_states, value_states)

    return query_states, key_states, value_states


def post_process_for_sequence_parallel_attn(attn_output,
                                            scatter_dim=1,
                                            gather_dim=2):
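    """Convert the attention output back to the sequence parallel layout.

    The input is shaped ``(b, s, insp * h // sp, d)``, the layout produced by
    ``pre_process_for_sequence_parallel_attn``. An all-to-all over the
    sequence parallel group scatters the sequence dim and gathers the head
    dim, returning a tensor shaped ``(b, s // sp, h, d)``. When ``insp > 1``,
    the head dim is re-split across the inner sequence parallel group before
    the all-to-all and merged back afterwards.
    """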
    sp = get_sequence_parallel_world_size()
    insp = get_inner_sequence_parallel_world_size()
    b, s, h_mul_insp_div_sp, d = attn_output.shape
    h = h_mul_insp_div_sp * sp // insp
    s_div_sp = s // sp

    if insp > 1:
        # (b, s, insp*h/sp, d) -> (b, s, insp*h/sp, d/insp)
        attn_output = split_forward_gather_backward(
            attn_output, -1, get_inner_sequence_parallel_group())

    # (b, s, insp*h/sp, d/insp) -> (b, s_div_sp, insp*h, d/insp)
    sequence_parallel_group = get_sequence_parallel_group()
    output = all_to_all(
        attn_output,
        sequence_parallel_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim)

    if insp > 1:
        # (b, s_div_sp, insp*h, d/insp) ->
        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
        # (b, s_div_sp, h, d)
        output = output.view(b, s_div_sp, sp // insp, insp, h * insp // sp,
                             d // insp).transpose(3, 4).reshape(
                                 b, s_div_sp, h, d)

    return output


def sequence_parallel_wrapper(local_attn):
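    """Wrap a local attention function so it runs with sequence parallelism.

    The returned function pre-processes q/k/v with
    ``pre_process_for_sequence_parallel_attn``, calls ``local_attn`` on the
    rearranged shards, and post-processes the result with
    ``post_process_for_sequence_parallel_attn``. The rearrangement is only
    applied when ``torch.distributed`` is initialized, the sequence parallel
    world size is larger than 1 and ``training=True`` is passed.
    """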

    def sequence_parallel_attn(query_states, key_states, value_states, *args,
                               **kwargs):
        # `training` is popped here so it is not forwarded to `local_attn`.
        training = kwargs.pop('training', True)
        enable_sequence_parallel = (
            dist.is_initialized() and get_sequence_parallel_world_size() > 1
            and training)
        if enable_sequence_parallel:
            query_states, key_states, value_states = \
                pre_process_for_sequence_parallel_attn(
                    query_states, key_states, value_states)

        out = local_attn(query_states, key_states, value_states, *args,
                         **kwargs)

        if enable_sequence_parallel:
            out = post_process_for_sequence_parallel_attn(out).contiguous()

        return out

    return sequence_parallel_attn
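

# --- Illustrative usage sketch (not part of the original OpenMMLab file) ---
# `_example_local_attn` is a hypothetical local attention shown only to
# demonstrate how `sequence_parallel_wrapper` is applied. Inside the wrapped
# call, q/k/v arrive as (b, s, h_local, d) shards (full sequence, subset of
# heads), so any per-rank attention with this signature works. Callers pass
# their (b, s // sp, h, d) shards and, optionally, `training=...`.
@sequence_parallel_wrapper
def _example_local_attn(query_states, key_states, value_states):
    import torch.nn.functional as F  # local import, sketch only

    # (b, s, h_local, d) -> (b, h_local, s, d) for scaled dot product attn.
    q, k, v = (x.transpose(1, 2)
               for x in (query_states, key_states, value_states))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)  # back to (b, s, h_local, d)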