# megalodon-random-small

Random weights for testing a Megalodon implementation.

## Architecture

The model has roughly 200M parameters across a stack of 12 `MegalodonBlock` layers:
```
======================================================================
Layer (type:depth-idx)                        Param #       Trainable
======================================================================
MegalodonForCausalLM                          227,707,904   True
├─ MegalodonModel                             194,151,424   True
│    ├─ Embedding                             33,554,432    True
│    ├─ ModuleList                            160,594,944   True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    ├─ MegalodonBlock                   13,382,912    True
│    │    └─ MegalodonBlock                   13,382,912    True
│    └─ TimestepNorm                          2,048         True
└─ MegalodonLMHead                            33,556,480    True
     ├─ TimestepNorm                          2,048         True
     └─ Linear                                33,554,432    True
======================================================================
Total params: 227,707,904
Trainable params: 227,707,904
Non-trainable params: 0
======================================================================
```
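The per-module counts are internally consistent (12 × 13,382,912 = 160,594,944 for the block stack), and the total can be reproduced directly. A minimal sketch, assuming the model loads as in the test snippets below:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "pszemraj/megalodon-random-small", trust_remote_code=True
)

# both should print 227,707,904, matching the summary above
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params:     {total:,}")
print(f"Trainable params: {trainable:,}")
```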
## Testing

### Dummy inputs

Run a forward pass with random token IDs to check the loss computation and output shapes:
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "pszemraj/megalodon-random-small", trust_remote_code=True
)

# dummy inputs: random token IDs with a full attention mask
input_ids = torch.randint(0, model.config.vocab_size, (2, 50))  # (batch_size, seq_length)
attention_mask = torch.ones(2, 50, dtype=torch.long)  # (batch_size, seq_length)

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
loss = outputs.loss
logits = outputs.logits
print(f"Loss: {loss.item()}")
print(f"Logits shape: {logits.shape}")  # expected: (batch_size, seq_length, vocab_size)
```
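Since the random checkpoint exists to exercise the implementation, it can also be worth checking the backward pass. A minimal sketch continuing from the snippet above:

```python
# backpropagate the loss and confirm gradients reached the parameters
loss.backward()
n_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
print(f"{n_with_grad} parameter tensors received gradients")
```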
### With a tokenizer

Run end-to-end generation from a text prompt:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "pszemraj/megalodon-random-small", trust_remote_code=True, device_map="auto"
)
tk = AutoTokenizer.from_pretrained("pszemraj/megalodon-random-small")

prompt = "write a deep poem about Stroopwafels (the Dutch ones)"
inputs = tk(prompt, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(
    tk.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
)
```
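Because the checkpoint holds random weights, the generated text will be nonsense; the point is only to confirm that `generate()` runs end to end through the custom model code.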