EvanTHU commited on
Commit
1908ce0
·
verified ·
1 Parent(s): f30b316

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +3 -5
models/unet.py CHANGED
@@ -840,16 +840,14 @@ class MotionCLR(nn.Module):
840
  def encode_text(self, raw_text, device):
841
  print("00000000")
842
  print(device)
843
- print(self.clip_model.device)
844
  print("00000000")
845
  with torch.no_grad():
846
  texts = clip.tokenize(raw_text, truncate=True).to(
847
  device
848
  ) # [bs, context_length] # if n_tokens > 77 -> will truncate
849
- x = self.clip_model.token_embedding(texts).type(
850
- self.clip_model.dtype
851
- ) # [batch_size, n_ctx, d_model]
852
- x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
853
  x = x.permute(1, 0, 2) # NLD -> LND
854
  x = self.clip_model.transformer(x)
855
  x = self.clip_model.ln_final(x).type(
 
840
  def encode_text(self, raw_text, device):
841
  print("00000000")
842
  print(device)
843
+ print(next(self.clip_model.parameters()).device)
844
  print("00000000")
845
  with torch.no_grad():
846
  texts = clip.tokenize(raw_text, truncate=True).to(
847
  device
848
  ) # [bs, context_length] # if n_tokens > 77 -> will truncate
849
+ x = self.clip_model.token_embedding(texts).type(self.clip_model.dtype).to(device) # [batch_size, n_ctx, d_model]
850
+ x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype).to(device)
 
 
851
  x = x.permute(1, 0, 2) # NLD -> LND
852
  x = self.clip_model.transformer(x)
853
  x = self.clip_model.ln_final(x).type(