import os
from typing import List, Union

import torch
from torch import Tensor, nn


class ClipTextEncoder(nn.Module):

    def __init__(
        self,
        modelpath: str = "openai/clip-vit-large-patch14",  # or "openai/clip-vit-base-patch32"
        finetune: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        from transformers import logging
        from transformers import AutoModel, AutoTokenizer
        logging.set_verbosity_error()

        # Tokenizer
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        self.text_model = AutoModel.from_pretrained(modelpath)

        # Don't train the model: keep it in eval mode and freeze all parameters
        if not finetune:
            self.text_model.eval()
            for p in self.text_model.parameters():
                p.requires_grad = False

        # Then configure the model
        self.max_length = self.tokenizer.model_max_length
        self.text_encoded_dim = self.text_model.config.text_config.hidden_size

    def forward(self, texts: List[str]):
        # Get prompt text embeddings
        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_model.device)
        txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)

        # Truncate to the maximum sequence length CLIP can handle
        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        # use pooled output if latent dim is two-dimensional
        # pooled = 0 if self.latent_dim[0] == 1 else 1  # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim)
        # text encoder forward; for the pooled CLIP features one would use get_text_features instead
        # (batch_size, seq_length, text_encoded_dim)
        text_embeddings = self.text_model.text_model(
            text_input_ids,
            # attention_mask=txt_att_mask,
        ).last_hidden_state
        return text_embeddings, txt_att_mask.bool()
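

# Usage sketch (illustrative addition, not from the original file): build the
# frozen encoder and run one forward pass. The expected shapes assume the
# default "openai/clip-vit-large-patch14" checkpoint (77-token context,
# 768-dim text hidden states); the prompt strings are placeholders.
if __name__ == "__main__":
    encoder = ClipTextEncoder(finetune=False)
    with torch.no_grad():
        embeddings, mask = encoder(["a person walks forward", "someone waves hello"])
    print(embeddings.shape)  # expected: torch.Size([2, 77, 768])
    print(mask.shape)        # expected: torch.Size([2, 77])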