jpohhhh commited on
Commit
1e0a1be
·
1 Parent(s): 52b2458

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -0
handler.py CHANGED
@@ -8,6 +8,20 @@ import time
8
  import os
9
  import torch
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  #Mean Pooling - Take attention mask into account for correct averaging
12
  def mean_pooling(model_output, attention_mask):
13
  token_embeddings = model_output[0] #First element of model_output contains all token embeddings
 
8
  import os
9
  import torch
10
 
11
+ def max_pooling(model_output):
12
+ # Get dimensions
13
+ _, Z, Y = model_output.shape
14
+ # Initialize an empty list with length Y (384 in your case)
15
+ output_array = [0] * Y
16
+ # Loop over secondary arrays (Z)
17
+ for i in range(Z):
18
+ # Loop over values in innermost arrays (Y)
19
+ for j in range(Y):
20
+ # If value is greater than current max, update max
21
+ if model_output[0][i][j] > output_array[j]:
22
+ output_array[j] = model_output[0][i][j]
23
+ return output_array
24
+
25
  #Mean Pooling - Take attention mask into account for correct averaging
26
  def mean_pooling(model_output, attention_mask):
27
  token_embeddings = model_output[0] #First element of model_output contains all token embeddings