Jackmin108 committed on
Commit
65e9690
·
1 Parent(s): 4ee2970

fix: device

Browse files

Signed-off-by: Meow <[email protected]>

Files changed (4) hide show
  1. embedding.py +1 -1
  2. mha.py +2 -2
  3. mlp.py +2 -2
  4. modeling_xlm_roberta.py +1 -1
embedding.py CHANGED
@@ -51,7 +51,7 @@ class XLMRobertaEmbeddings(nn.Module):
51
  unique_tasks = torch.unique(adapter_mask).tolist()
52
  embedding_dtype = next(self.word_embeddings.parameters()).dtype
53
  embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
54
- dtype=embedding_dtype).to(input_ids.device)
55
  for task_id in unique_tasks:
56
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
57
  task_input_ids = input_ids[task_indices]
 
51
  unique_tasks = torch.unique(adapter_mask).tolist()
52
  embedding_dtype = next(self.word_embeddings.parameters()).dtype
53
  embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
54
+ dtype=embedding_dtype, device=input_ids.device)
55
  for task_id in unique_tasks:
56
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
57
  task_input_ids = input_ids[task_indices]
mha.py CHANGED
@@ -650,7 +650,7 @@ class MHA(nn.Module):
650
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
651
  qkv_dtype = next(self.Wqkv.parameters()).dtype
652
  qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
653
- dtype=qkv_dtype).to(x.device)
654
  for task_id in unique_tasks:
655
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
656
  task_tensor = x[task_indices]
@@ -755,7 +755,7 @@ class MHA(nn.Module):
755
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
756
  out_dtype = next(self.out_proj.parameters()).dtype
757
  out = torch.empty(inp.shape[0], self.out_proj.out_features,
758
- dtype=out_dtype).to(inp.device)
759
  for task_id in unique_tasks:
760
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
761
  task_tensor = inp[task_indices]
 
650
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
651
  qkv_dtype = next(self.Wqkv.parameters()).dtype
652
  qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
653
+ dtype=qkv_dtype, device=x.device)
654
  for task_id in unique_tasks:
655
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
656
  task_tensor = x[task_indices]
 
755
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
756
  out_dtype = next(self.out_proj.parameters()).dtype
757
  out = torch.empty(inp.shape[0], self.out_proj.out_features,
758
+ dtype=out_dtype, device=inp.device)
759
  for task_id in unique_tasks:
760
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
761
  task_tensor = inp[task_indices]
mlp.py CHANGED
@@ -52,7 +52,7 @@ class Mlp(nn.Module):
52
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
53
  fc1_dtype = next(self.fc1.parameters()).dtype
54
  y = torch.empty(x.shape[0], self.fc1.out_features,
55
- dtype=fc1_dtype).to(x.device)
56
  for task_id in unique_tasks:
57
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
  task_tensor = x[task_indices]
@@ -67,7 +67,7 @@ class Mlp(nn.Module):
67
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
68
  fc2_dtype = next(self.fc2.parameters()).dtype
69
  out = torch.empty(y.shape[0], self.fc2.out_features,
70
- dtype=fc2_dtype).to(y.device)
71
  for task_id in unique_tasks:
72
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
  task_tensor = y[task_indices]
 
52
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
53
  fc1_dtype = next(self.fc1.parameters()).dtype
54
  y = torch.empty(x.shape[0], self.fc1.out_features,
55
+ dtype=fc1_dtype, device=x.device)
56
  for task_id in unique_tasks:
57
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
  task_tensor = x[task_indices]
 
67
  unique_tasks = torch.unique(cu_adapter_mask).tolist()
68
  fc2_dtype = next(self.fc2.parameters()).dtype
69
  out = torch.empty(y.shape[0], self.fc2.out_features,
70
+ dtype=fc2_dtype, device=y.device)
71
  for task_id in unique_tasks:
72
  task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
  task_tensor = y[task_indices]
modeling_xlm_roberta.py CHANGED
@@ -317,7 +317,7 @@ class XLMRobertaPooler(nn.Module):
317
  unique_tasks = torch.unique(adapter_mask).tolist()
318
  pool_dtype = next(self.dense.parameters()).dtype
319
  pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
320
- dtype=pool_dtype).to(first_token_tensor.device)
321
  for task_id in unique_tasks:
322
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
323
  task_first_token_tensor = first_token_tensor[task_indices]
 
317
  unique_tasks = torch.unique(adapter_mask).tolist()
318
  pool_dtype = next(self.dense.parameters()).dtype
319
  pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
320
+ dtype=pool_dtype, device=first_token_tensor.device)
321
  for task_id in unique_tasks:
322
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
323
  task_first_token_tensor = first_token_tensor[task_indices]