naxalpha commited on
Commit
9c22e75
·
1 Parent(s): 0af0146

add readme

Browse files
.ipynb_checkpoints/README-checkpoint.md CHANGED
@@ -1,3 +1,38 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ # [Gated State Space](https://arxiv.org/abs/2206.13947)
6
+
7
+ This repo contains pretrain model for the gated state space paper. The model has been trained on [C4 dataset](https://huggingface.co/datasets/c4). I have used [Lucidrains' implementation](https://github.com/lucidrains/gated-state-spaces-pytorch) ([commit](https://github.com/lucidrains/gated-state-spaces-pytorch/tree/32cd036e775112cc469e94fa1165fe111393708b)) for the model. I think the main benefit of this model is the ability to scale beyond the training context length. As authors noted in the paper, they trained the model on 4k sequence length but it generalized beyond that length. I have written a blog post on how I started the training [here](https://naxalpha.substack.com/p/devlog-experiment-a2a468-gated-state).
8
+
9
+ [Here](https://wandb.ai/naxalpha/gated-state-space/reports/Gated-State-Space-Training-v1--VmlldzozMTYzMzY3?accessToken=zy10rrpofi9k7l52aqwiej8bk0ub302rdswfkxmf8y94dt2j6z4kxbca6ar3sc52) are the training logs (report) on W&B. Since the training is still going on, this repo currently contains the checkpoint @ ~200k step. The report contains both the loss and the output results. While the generated text somewhat makes sense, given the loss is still ~4.2, it is not useful out-of-the-box for most tasks. So need to fine-tune it before using it for anything else.
10
+
11
+ ## How to use this.
12
+
13
+ Since it is not based on [transformers](https://github.com/huggingface/transformers/) library, it is a bit tricky to use the model out of the box. Here are the general steps:
14
+
15
+ 1. `pip install gated-state-spaces-pytorch`
16
+ 2. Download the model weights from [here](https://huggingface.co/naxalpha/gated-state-space/raw/main/model.pt).
17
+ 3. Download the config from [here](https://huggingface.co/naxalpha/gated-state-space/raw/main/config.json).
18
+ 4. Following code to patch the original model:
19
+ ```python
20
+ model = AutoregressiveWrapper(
21
+ GatedStateSpacesLM(
22
+ **config
23
+ ),
24
+ )
25
+ model.net.to_logits = nn.Sequential(
26
+ nn.LayerNorm(f_emb),
27
+ model.net.to_logits,
28
+ )
29
+ ```
30
+ 5. Load the state dict: `model.load_state_dict(torch.load('model.pt'))`
31
+ 6. If you want to fine-tune the model, you can freeze the embeddings:
32
+ ```python
33
+ model.net.token_emb.weight.requires_grad_(False)
34
+ model.net.token_emb.weight.copy_(emb)
35
+
36
+ model.net.to_logits[1].weight.requires_grad_(False)
37
+ model.net.to_logits[1].weight.copy_(emb)
38
+ ```
README.md CHANGED
@@ -1,3 +1,38 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ # [Gated State Space](https://arxiv.org/abs/2206.13947)
6
+
7
+ This repo contains pretrain model for the gated state space paper. The model has been trained on [C4 dataset](https://huggingface.co/datasets/c4). I have used [Lucidrains' implementation](https://github.com/lucidrains/gated-state-spaces-pytorch) ([commit](https://github.com/lucidrains/gated-state-spaces-pytorch/tree/32cd036e775112cc469e94fa1165fe111393708b)) for the model. I think the main benefit of this model is the ability to scale beyond the training context length. As authors noted in the paper, they trained the model on 4k sequence length but it generalized beyond that length. I have written a blog post on how I started the training [here](https://naxalpha.substack.com/p/devlog-experiment-a2a468-gated-state).
8
+
9
+ [Here](https://wandb.ai/naxalpha/gated-state-space/reports/Gated-State-Space-Training-v1--VmlldzozMTYzMzY3?accessToken=zy10rrpofi9k7l52aqwiej8bk0ub302rdswfkxmf8y94dt2j6z4kxbca6ar3sc52) are the training logs (report) on W&B. Since the training is still going on, this repo currently contains the checkpoint @ ~200k step. The report contains both the loss and the output results. While the generated text somewhat makes sense, given the loss is still ~4.2, it is not useful out-of-the-box for most tasks. So need to fine-tune it before using it for anything else.
10
+
11
+ ## How to use this.
12
+
13
+ Since it is not based on [transformers](https://github.com/huggingface/transformers/) library, it is a bit tricky to use the model out of the box. Here are the general steps:
14
+
15
+ 1. `pip install gated-state-spaces-pytorch`
16
+ 2. Download the model weights from [here](https://huggingface.co/naxalpha/gated-state-space/raw/main/model.pt).
17
+ 3. Download the config from [here](https://huggingface.co/naxalpha/gated-state-space/raw/main/config.json).
18
+ 4. Following code to patch the original model:
19
+ ```python
20
+ model = AutoregressiveWrapper(
21
+ GatedStateSpacesLM(
22
+ **config
23
+ ),
24
+ )
25
+ model.net.to_logits = nn.Sequential(
26
+ nn.LayerNorm(f_emb),
27
+ model.net.to_logits,
28
+ )
29
+ ```
30
+ 5. Load the state dict: `model.load_state_dict(torch.load('model.pt'))`
31
+ 6. If you want to fine-tune the model, you can freeze the embeddings:
32
+ ```python
33
+ model.net.token_emb.weight.requires_grad_(False)
34
+ model.net.token_emb.weight.copy_(emb)
35
+
36
+ model.net.to_logits[1].weight.requires_grad_(False)
37
+ model.net.to_logits[1].weight.copy_(emb)
38
+ ```