from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3TextModel,
    Gemma3PreTrainedModel,
    GEMMA3_INPUTS_DOCSTRING,
)
from transformers.modeling_outputs import (
    SequenceClassifierOutputWithPast,
    BaseModelOutputWithPast,
)
from transformers.utils.doc import add_start_docstrings_to_model_forward
from transformers.utils.generic import can_return_tuple
from transformers.utils import logging
from transformers.cache_utils import Cache
from typing import Optional
from torch import nn
import torch

logger = logging.get_logger(__name__)


class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
    config_class = Gemma3TextConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma3TextModel(config)
        # Classification head applied to the pooled hidden state of the last non-padding token
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss); if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right-padding, take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
            )

        # Pool the logits at the last non-padding position of each sequence
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                pooled_logits=pooled_logits,
                config=self.config,
            )

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
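

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module above). The tiny
# config values below are assumptions chosen so the snippet runs quickly on
# CPU with randomly initialized weights; in practice you would load a real
# Gemma 3 checkpoint via Gemma3ForSequenceClassification.from_pretrained(
# "<your-gemma3-checkpoint>", num_labels=...) and fine-tune the `score` head.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = Gemma3TextConfig(
        vocab_size=256,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=8,
        pad_token_id=0,
        num_labels=2,
    )
    model = Gemma3ForSequenceClassification(config)

    # Two right-padded sequences (0 is the pad token); the head pools the
    # logits of the last non-padding token in each row.
    input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
    attention_mask = (input_ids != config.pad_token_id).long()
    labels = torch.tensor([1, 0])

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.loss)          # cross-entropy loss over the two labels
    print(outputs.logits.shape)  # torch.Size([2, 2]): (batch_size, num_labels)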