aynetdia commited on
Commit
d39c21b
·
1 Parent(s): 6d82c6c

fix last token pooling

Browse files
Files changed (1) hide show
  1. semscore.py +2 -0
semscore.py CHANGED
@@ -98,6 +98,8 @@ class SemScore(evaluate.Metric):
98
 
99
  @staticmethod
100
  def _last_token_pooling(last_hidden_states, attention_mask):
 
 
101
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
102
  if left_padding:
103
  return last_hidden_states[:, -1]
 
98
 
99
  @staticmethod
100
  def _last_token_pooling(last_hidden_states, attention_mask):
101
+ if not isinstance(last_hidden_states, torch.Tensor):
102
+ last_hidden_states = last_hidden_states.last_hidden_state
103
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
104
  if left_padding:
105
  return last_hidden_states[:, -1]