dim commited on
Commit
e22762e
·
1 Parent(s): 8672de0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +99 -0
README.md CHANGED
@@ -1,6 +1,105 @@
1
  ---
2
  library_name: peft
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  ## Training procedure
5
 
6
 
 
1
  ---
2
  library_name: peft
3
  ---
4
+
5
+ ```python
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
7
+ import torch
8
+ from peft import PeftModel, PeftConfig
9
+
10
+
11
+ class GoralConversation:
12
+ def __init__(
13
+ self,
14
+ message_template=" <s> {role}\n{content} </s>\n",
15
+ system_prompt="Ты — Горал, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им.",
16
+ start_token_id=1,
17
+ bot_token_id=9225,
18
+ ):
19
+ self.message_template = message_template
20
+ self.start_token_id = start_token_id
21
+ self.bot_token_id = bot_token_id
22
+ self.messages = [{"role": "system", "content": system_prompt}]
23
+
24
+ def get_start_token_id(self):
25
+ return self.start_token_id
26
+
27
+ def get_bot_token_id(self):
28
+ return self.bot_token_id
29
+
30
+ def add_user_message(self, message):
31
+ self.messages.append({"role": "user", "content": message})
32
+
33
+ def add_bot_message(self, message):
34
+ self.messages.append({"role": "bot", "content": message})
35
+
36
+ def get_prompt(self, tokenizer):
37
+ final_text = ""
38
+ for message in self.messages:
39
+ message_text = self.message_template.format(**message)
40
+ final_text += message_text
41
+ final_text += tokenizer.decode(
42
+ [
43
+ self.start_token_id,
44
+ ]
45
+ )
46
+ final_text += " "
47
+ final_text += tokenizer.decode([self.bot_token_id])
48
+ return final_text.strip()
49
+
50
+
51
+ def generate(model, tokenizer, prompt, generation_config):
52
+ data = tokenizer(
53
+ prompt,
54
+ return_tensors="pt",
55
+ truncation=True,
56
+ max_length=2048,
57
+ )
58
+ data = {k: v.to(model.device) for k, v in data.items()}
59
+ output_ids = model.generate(**data, generation_config=generation_config)[0]
60
+ output_ids = output_ids[len(data["input_ids"][0]) :]
61
+ output = tokenizer.decode(output_ids, skip_special_tokens=True)
62
+ return output.strip()
63
+
64
+
65
+ weights_path = "dim/xglm-4.5b_dolly_oasst1_chip2"
66
+ access_token = ""
67
+
68
+ config = PeftConfig.from_pretrained(weights_path)
69
+ model = AutoModelForCausalLM.from_pretrained(
70
+ config.base_model_name_or_path,
71
+ load_in_8bit=True,
72
+ torch_dtype=torch.float16,
73
+ device_map={"": 0},
74
+ token=access_token,
75
+ )
76
+ model = PeftModel.from_pretrained(
77
+ model,
78
+ weights_path,
79
+ torch_dtype=torch.float16,
80
+ )
81
+ model.eval()
82
+
83
+ tokenizer = AutoTokenizer.from_pretrained(weights_path)
84
+ generation_config = GenerationConfig.from_pretrained(weights_path)
85
+ generation_config.do_sample = False
86
+
87
+
88
+ inp = "Напишите интересный пост в блоге о недавней поездке на Гавайи, рассказывая о культурном опыте и достопримечательностях, которые обязательно нужно увидеть."
89
+ conversation = GoralConversation(
90
+ start_token_id=0,
91
+ bot_token_id=7425,
92
+ )
93
+ conversation.add_user_message(inp)
94
+ prompt = conversation.get_prompt(tokenizer)
95
+
96
+ output = generate(model, tokenizer, prompt, generation_config)
97
+ print(inp)
98
+ print(output)
99
+ # Я был там! Это было незабываемое путешествие, которое я никогда не забуду. Мы посетили все основные достопримечательности острова, включая пляжи, вулканы, пещеры, национальные парки и многое другое. Впечатления от посещения были потрясающими, а культура - уникальной. Поездка была отличным способом исследовать остров и узнать больше об истории его жителей. Надеюсь, что вы также захотите посетить это место!
100
+ ```
101
+
102
+
103
  ## Training procedure
104
 
105