Update model.py
Browse files
model.py
CHANGED
@@ -1,38 +1,30 @@
|
|
1 |
import tensorflow as tf
|
2 |
-
from tensorflow.keras import
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
def
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
#
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
test_labels = tf.keras.utils.to_categorical(test_labels, 10)
|
28 |
-
|
29 |
-
# Create an instance of your model
|
30 |
-
model = create_model()
|
31 |
-
|
32 |
-
# Compile the model
|
33 |
-
model.compile(optimizer='adam',
|
34 |
-
loss='categorical_crossentropy',
|
35 |
-
metrics=['accuracy'])
|
36 |
-
|
37 |
-
# Train the model
|
38 |
-
model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_data=(test_images, test_labels))
|
|
|
1 |
import tensorflow as tf
|
2 |
+
from tensorflow.keras.layers import Input, Embedding, LayerNormalization, MultiHeadAttention, Dense, Add, Dropout, Layer
|
3 |
+
from tensorflow.keras.models import Model
|
4 |
+
from tensorflow.keras.optimizers import Adam
|
5 |
+
from tensorflow.keras.losses import SparseCategoricalCrossentropy
|
6 |
+
import numpy as np
|
7 |
|
8 |
+
class VoidChatModel(tf.keras.Model):
|
9 |
+
def __init__(self, vocab_size, seq_len, num_layers=6, num_heads=8, emb_dim=512, mlp_dim=2048, dropout_rate=0.1):
|
10 |
+
super(VoidChatModel, self).__init__()
|
11 |
+
self.vocab_size = vocab_size
|
12 |
+
self.seq_len = seq_len
|
13 |
+
self.num_layers = num_layers
|
14 |
+
self.num_heads = num_heads
|
15 |
+
self.emb_dim = emb_dim
|
16 |
+
self.mlp_dim = mlp_dim
|
17 |
+
self.dropout_rate = dropout_rate
|
18 |
+
|
19 |
+
# Embedding layer
|
20 |
+
self.embedding = Embedding(input_dim=vocab_size, output_dim=emb_dim)
|
21 |
+
|
22 |
+
# Transformer layers
|
23 |
+
self.transformer_blocks = [TransformerBlock(num_heads, emb_dim, mlp_dim, dropout_rate) for _ in range(num_layers)]
|
24 |
+
|
25 |
+
# Output layer
|
26 |
+
self.output_layer = Dense(vocab_size, activation='softmax')
|
27 |
+
|
28 |
+
def call(self, input_ids, training=False):
|
29 |
+
# Embedding layer
|
30 |
+
x = self.embedding(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|