import torch
import torch.nn as nn
from transformers import RoFormerConfig, RoFormerForMaskedLM

class Roformer(nn.Module):
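    # Thin wrapper around Hugging Face's RoFormerForMaskedLM: builds the model
    # from an experiment config and a tokenizer, and exposes freeze/unfreeze
    # and save/load helpers for fine-tuning.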
    def __init__(self, config, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # build the RoFormer masked-LM configuration from the experiment config
        # and the tokenizer's vocabulary size
        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )
        
        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)
                
    def freeze_model(self):
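        # freeze every parameter of the wrapped model (no gradient updates)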
        for param in self.model.parameters():
            param.requires_grad = False
    
    def unfreeze_all_layers(self):
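        # make every parameter trainable again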
        for param in self.model.parameters():
            param.requires_grad = True
    
    def unfreeze_n_layers(self, n):
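        # unfreeze only the query/key projections of the final n encoder layers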
        num_layers = len(self.model.roformer.encoder.layer)

        for i, layer in enumerate(self.model.roformer.encoder.layer):
            if i >= num_layers - n:
                # unfreeze query projection weights
                for param in layer.attention.self.query.parameters():
                    param.requires_grad = True
                # unfreeze key projection weights
                for param in layer.attention.self.key.parameters():
                    param.requires_grad = True
            
    def forward(self, input_ids, attn_mask):
        # move inputs onto the model's device
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        # run the masked-LM head and return per-token vocabulary logits
        outputs = self.model(input_ids=input_ids, attention_mask=attn_mask)
        return outputs.logits
    
    def save_model(self, save_dir):
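        # persist both the model weights and the tokenizer to save_dir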
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
    
    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
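        # rebuild the wrapper from config, then swap in the pretrained weights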
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir).to(roformer.device)
        return roformer
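

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training code.  The config
    # namespace and the bert-base-uncased tokenizer below are placeholder
    # assumptions; substitute your project's config object and tokenizer.
    from types import SimpleNamespace
    from transformers import BertTokenizerFast

    cfg = SimpleNamespace(roformer=SimpleNamespace(
        hidden_size=256,
        n_layers=4,
        n_heads=4,
        max_position_embeddings=512,
    ))
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    model = Roformer(cfg, tokenizer)
    model.freeze_model()
    model.unfreeze_n_layers(2)  # fine-tune only the last two encoder layers

    batch = tokenizer(["a quick smoke test"], return_tensors="pt", padding=True)
    logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # (batch_size, seq_len, vocab_size)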