dariakryvosheieva committed
Commit f733cc1 · verified
1 Parent(s): d28253c

Delete custom_st.py

Files changed (1)
  1. custom_st.py +0 -178
custom_st.py DELETED
@@ -1,178 +0,0 @@
- from typing import List, Dict, Tuple, Union, Any, Optional
-
- import os
- import json
- import torch
-
- from torch import nn
- from transformers import AutoConfig, AutoModel, AutoTokenizer
- from transformers.utils import is_flash_attn_2_available
-
-
- class Transformer(nn.Module):
-     def __init__(
-         self,
-         model_name_or_path: str,
-         max_seq_length: int = None,
-         model_args: Dict[str, Any] = None,
-         tokenizer_args: Dict[str, Any] = None,
-         config_args: Dict[str, Any] = None,
-         cache_dir: str = None,
-         do_lower_case: bool = False,
-         tokenizer_name_or_path: str = None,
-         **kwargs,
-     ) -> None:
-         super().__init__()
-         self.config_keys = ["max_seq_length", "do_lower_case"]
-         self.do_lower_case = do_lower_case
-         if model_args is None:
-             model_args = {}
-         if tokenizer_args is None:
-             tokenizer_args = {}
-         if config_args is None:
-             config_args = {}
-
-         self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
-
-         self.task_names = self.config.task_names
-
-         self.default_task = model_args.pop('default_task', None)
-
-         model_args["attn_implementation"] = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
-
-         self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
-
-         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
-             tokenizer_args["model_max_length"] = max_seq_length
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
-             cache_dir=cache_dir,
-             **tokenizer_args,
-         )
-
-         # No max_seq_length set. Try to infer from model
-         if max_seq_length is None:
-             if (
-                 hasattr(self.auto_model, "config")
-                 and hasattr(self.auto_model.config, "max_position_embeddings")
-                 and hasattr(self.tokenizer, "model_max_length")
-             ):
-                 max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)
-
-         self.max_seq_length = max_seq_length
-
-         if tokenizer_name_or_path is not None:
-             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
-
-
-     @property
-     def default_task(self):
-         return self._default_task
-
-
-     @default_task.setter
-     def default_task(self, task: Union[None, str]):
-         self._validate_task(task)
-         self._default_task = task
-
-
-     def _validate_task(self, task: str):
-         if task and task not in self.task_names:
-             raise ValueError(
-                 f"Unsupported task '{task}'. "
-                 f"Supported tasks are: {', '.join(self.config.task_names)}."
-             )
-
-
-     def forward(
-         self,
-         features: Dict[str, torch.Tensor],
-         task: Optional[str] = None
-     ) -> Dict[str, torch.Tensor]:
-         """
-         Forward pass through the model.
-         """
-         features.pop('prompt_length', None)
-         output_states = self.auto_model.forward(
-             **features,
-             output_attentions=False,
-             return_dict=True
-         )
-         output_tokens = output_states[0]
-         features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
-         return features
-
-
-     def get_word_embedding_dimension(self) -> int:
-         return self.auto_model.config.hidden_size
-
-
-     def tokenize(
-         self,
-         texts: Union[List[str], List[dict], List[Tuple[str, str]]],
-         padding: Union[str, bool] = True
-     ) -> Dict[str, torch.Tensor]:
-         """Tokenizes a text and maps tokens to token-ids"""
-         output = {}
-         if isinstance(texts[0], str):
-             to_tokenize = [texts]
-         elif isinstance(texts[0], dict):
-             to_tokenize = []
-             output["text_keys"] = []
-             for lookup in texts:
-                 text_key, text = next(iter(lookup.items()))
-                 to_tokenize.append(text)
-                 output["text_keys"].append(text_key)
-             to_tokenize = [to_tokenize]
-         else:
-             batch1, batch2 = [], []
-             for text_tuple in texts:
-                 batch1.append(text_tuple[0])
-                 batch2.append(text_tuple[1])
-             to_tokenize = [batch1, batch2]
-
-         # strip
-         to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
-
-         # Lowercase
-         if self.do_lower_case:
-             to_tokenize = [[s.lower() for s in col] for col in to_tokenize]
-
-         output.update(
-             self.tokenizer(
-                 *to_tokenize,
-                 padding=padding,
-                 truncation=True,
-                 return_tensors="pt",
-                 max_length=self.max_seq_length,
-             )
-         )
-         return output
-
-
-     def get_config_dict(self) -> Dict[str, Any]:
-         return {key: self.__dict__[key] for key in self.config_keys}
-
-
-     def save(self, output_path: str, safe_serialization: bool = True) -> None:
-         self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
-         self.tokenizer.save_pretrained(output_path)
-
-         with open(os.path.join(output_path, "sentence_transformer_config.json"), "w") as fOut:
-             json.dump(self.get_config_dict(), fOut, indent=2)
-
-
-     @classmethod
-     def load(cls, input_path: str) -> "Transformer":
-         config_name = "sentence_transformer_config.json"
-         stransformer_config_path = os.path.join(input_path, config_name)
-         with open(stransformer_config_path) as fIn:
-             config = json.load(fIn)
-         # Don't allow configs to set trust_remote_code
-         if "model_args" in config and "trust_remote_code" in config["model_args"]:
-             config["model_args"].pop("trust_remote_code")
-         if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
-             config["tokenizer_args"].pop("trust_remote_code")
-         if "config_args" in config and "trust_remote_code" in config["config_args"]:
-             config["config_args"].pop("trust_remote_code")
-         return cls(model_name_or_path=input_path, **config)
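For reference, the deleted class follows the sentence-transformers custom-module interface (tokenize / forward / get_config_dict / save / load). The sketch below is only a hypothetical illustration of how such a module is typically wired up by hand; it is not taken from this repository. It assumes custom_st.py is still available locally, uses a placeholder checkpoint path, and requires a checkpoint whose config defines the task_names field read in __init__.

from sentence_transformers import SentenceTransformer, models

from custom_st import Transformer  # the module deleted in this commit

# Placeholder path, not a real repo id; the checkpoint's config must expose task_names.
word_embedding_model = Transformer("path/to/checkpoint")
pooling = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="mean",  # pooling choice is an assumption, not specified by custom_st.py
)
model = SentenceTransformer(modules=[word_embedding_model, pooling])

embeddings = model.encode(["A sample sentence."])
print(embeddings.shape)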