miittnnss commited on
Commit
dd74c6d
1 Parent(s): d2d0116

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +5 -1
pipeline.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
 
2
  class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
3
  def __init__(self, input_size=45, hidden_size=512, output_size=45, num_layers=2, dropout=0.5):
@@ -23,7 +27,7 @@ class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
23
  class PreTrainedPipeline():
24
  def __init__(self, path=""):
25
  self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
26
- self.chars = ""
27
  self.char_to_index = {char: index for index, char in enumerate(self.chars)}
28
  self.index_to_char = {index: char for char, index in self.char_to_index.items()}
29
  self.output_size = len(chars)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+ from huggingface_hub import PyTorchModelHubMixin
5
 
6
  class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
7
  def __init__(self, input_size=45, hidden_size=512, output_size=45, num_layers=2, dropout=0.5):
 
27
  class PreTrainedPipeline():
28
  def __init__(self, path=""):
29
  self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
30
+ self.chars = "!',.;ACDFGHIMORSTWabcdefghijklmnopqrstuvwxy"
31
  self.char_to_index = {char: index for index, char in enumerate(self.chars)}
32
  self.index_to_char = {index: char for char, index in self.char_to_index.items()}
33
  self.output_size = len(chars)