chandan-sreedhara committed · verified
Commit 6ba4e9d · 1 Parent(s): de3125b

Update README.md

Files changed (1)
README.md +23 -32
README.md CHANGED
@@ -17,52 +17,39 @@ The SimpleStories models are a tiny model family created for interpretability re
 
 ## Usage
 
-```bash
-pip install simple_stories_train
-```
-
 ```python
-from transformers import AutoTokenizer
 import torch
+from transformers import AutoTokenizer, LlamaForCausalLM
 
-from simple_stories_train.models.llama import Llama
-from simple_stories_train.models.model_configs import MODEL_CONFIGS
-
-# Select the model size you want to use
-model_size = "11M" # Options: "35M", "30M", "11M", "5M", "1.25M"
 
-# Load model configuration
-model_config = MODEL_CONFIGS[model_size]
+MODEL_SIZE = "11M"
+model_path = "SimpleStories/SimpleStories-{}".format(MODEL_SIZE)
 
-# Load appropriate model
-model_path = f"SimpleStories/SimpleStories-{model_size}"
-model = Llama.from_pretrained(model_path, model_config)
-device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
-model.to(device)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model = LlamaForCausalLM.from_pretrained(model_path)
+model.to("cuda")
 model.eval()
 
-# Load tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_path)
 
-# Define your prompt
 prompt = "The curious cat looked at the"
 
-inputs = tokenizer(prompt, return_tensors="pt")
-input_ids = inputs.input_ids.to(device)
 
-# Generate text
+inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+input_ids = inputs.input_ids.to("cuda")
+
+eos_token_id = 1
+
 with torch.no_grad():
     output_ids = model.generate(
-        idx=input_ids,
-        max_new_tokens=50,
-        temperature=0.0,
-        top_k=40,
-        eos_token_id=tokenizer.eos_token_id
-    )
-
-# Decode output
+        input_ids=input_ids,
+        max_new_tokens=400,
+        temperature=0.7,
+        do_sample=True,
+        eos_token_id=eos_token_id
+    )
+
 output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-print(f"Generated text:\n{output_text}")
+print(f"\nGenerated text:\n{output_text}")
 
 ```
 
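Note that the updated snippet assumes a CUDA device and hard-codes `eos_token_id = 1`. A minimal sketch (not part of this commit) of a more portable variant, which reuses the device fallback from the removed lines and prefers the tokenizer's own EOS id when it defines one:

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_path = "SimpleStories/SimpleStories-11M"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path)

# Device fallback mirroring the logic this commit removed: CUDA, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
model.eval()

# Prefer the tokenizer's EOS id; fall back to the hard-coded 1 used in the new README.
eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 1

inputs = tokenizer("The curious cat looked at the", return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    output_ids = model.generate(
        input_ids=inputs.input_ids.to(device),
        max_new_tokens=400,
        temperature=0.7,
        do_sample=True,
        eos_token_id=eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```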
 
@@ -99,3 +86,7 @@ The SimpleStories dataset is a collection of short stories generated by state-of
 - ASCII-only guarantee for the English dataset
 
 Read the dataset paper on [arXiv](https://arxiv.org/abs/2504.09184).
+
+## Training
+
+The training and evaluation scripts can be accessed at https://github.com/danbraunai/simple_stories_train
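
For loading checkpoints through the training package itself, the snippet removed from the README in this commit suggests the following usage. This is a sketch reassembled from those removed lines, assuming `simple_stories_train` (installable via `pip install simple_stories_train` per the old README) still exposes the same modules:

```python
# Reassembled from the lines this commit removed; assumes the
# simple_stories_train package API is unchanged.
from simple_stories_train.models.llama import Llama
from simple_stories_train.models.model_configs import MODEL_CONFIGS

model_size = "11M"  # Options per the old README: "35M", "30M", "11M", "5M", "1.25M"
model_config = MODEL_CONFIGS[model_size]
model = Llama.from_pretrained(f"SimpleStories/SimpleStories-{model_size}", model_config)
model.eval()
```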