ibrim committed on
Commit
0597dc6
·
verified ·
1 Parent(s): 6cf1d95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -62
app.py CHANGED
@@ -22,74 +22,75 @@ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported
22
  compile = False # use PyTorch 2.0 to compile the model to be faster
23
  #exec(open('configurator.py').read()) # overrides from command line or config file
24
  # -----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
25
 
26
- torch.manual_seed(seed)
27
- torch.cuda.manual_seed(seed)
28
- torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
29
- torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
30
- device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
31
- ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
32
- ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 
 
 
 
 
 
 
 
33
 
34
- # model
35
- if init_from == 'resume':
36
- # init from a model saved in a specific directory
37
- ckpt_path = os.path.join(out_dir, 'ckpt.pt')
38
- checkpoint = torch.load(ckpt_path, map_location=device)
39
- gptconf = GPTConfig(**checkpoint['model_args'])
40
- model = GPT(gptconf)
41
- state_dict = checkpoint['model']
42
- unwanted_prefix = '_orig_mod.'
43
- for k,v in list(state_dict.items()):
44
- if k.startswith(unwanted_prefix):
45
- state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
46
- model.load_state_dict(state_dict)
47
- elif init_from.startswith('gpt2'):
48
- # init from a given GPT-2 model
49
- model = GPT.from_pretrained(init_from, dict(dropout=0.0))
50
 
51
- model.eval()
52
- model.to(device)
53
- if compile:
54
- model = torch.compile(model) # requires PyTorch 2.0 (optional)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # look for the meta pickle in case it is available in the dataset folder
57
- load_meta = False
58
- if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
59
- meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
60
- load_meta = os.path.exists(meta_path)
61
- if load_meta:
62
- print(f"Loading meta from {meta_path}...")
63
- with open(meta_path, 'rb') as f:
64
- meta = pickle.load(f)
65
- # TODO want to make this more general to arbitrary encoder/decoder schemes
66
- stoi, itos = meta['stoi'], meta['itos']
67
- encode = lambda s: [stoi[c] for c in s]
68
- decode = lambda l: ''.join([itos[i] for i in l])
69
- else:
70
- # ok let's assume gpt-2 encodings by default
71
- print("No meta.pkl found, assuming GPT-2 encodings...")
72
- enc = tiktoken.get_encoding("gpt2")
73
- encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
74
- decode = lambda l: enc.decode(l)
75
 
76
- # encode the beginning of the prompt
77
- if start.startswith('FILE:'):
78
- with open(start[5:], 'r', encoding='utf-8') as f:
79
- start = f.read()
80
- start_ids = encode(start)
81
- x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
 
82
 
83
- # run generation
84
- with torch.no_grad():
85
- with ctx:
86
- for k in range(num_samples):
87
- y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
88
- z = decode(y[0].tolist())
89
 
90
- def show_text(prompt=z):
91
- return prompt
92
- iface = gr.Interface(fn=show_text, inputs=[], outputs="textbox",
93
- title="GPT Text Generator", description="Enter a prompt to generate text.")
94
 
95
  iface.launch(share=True)
 
22
  compile = False # use PyTorch 2.0 to compile the model to be faster
23
  #exec(open('configurator.py').read()) # overrides from command line or config file
24
  # -----------------------------------------------------------------------------
25
+ def sample_from_trained_model(start="\n", init_from='resume', out_dir='out-shakespeare-char', num_samples=1,
26
+ max_new_tokens=500, temperature=0.8, top_k=200, seed=1337, device='cpu', compile=False):
27
+ torch.manual_seed(seed)
28
+ torch.cuda.manual_seed(seed)
29
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
30
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
31
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
32
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
33
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
34
 
35
+ # model
36
+ if init_from == 'resume':
37
+ # init from a model saved in a specific directory
38
+ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
39
+ checkpoint = torch.load(ckpt_path, map_location=device)
40
+ gptconf = GPTConfig(**checkpoint['model_args'])
41
+ model = GPT(gptconf)
42
+ state_dict = checkpoint['model']
43
+ unwanted_prefix = '_orig_mod.'
44
+ for k,v in list(state_dict.items()):
45
+ if k.startswith(unwanted_prefix):
46
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
47
+ model.load_state_dict(state_dict)
48
+ elif init_from.startswith('gpt2'):
49
+ # init from a given GPT-2 model
50
+ model = GPT.from_pretrained(init_from, dict(dropout=0.0))
51
 
52
+ model.eval()
53
+ model.to(device)
54
+ if compile:
55
+ model = torch.compile(model) # requires PyTorch 2.0 (optional)
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # look for the meta pickle in case it is available in the dataset folder
58
+ load_meta = False
59
+ if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
60
+ meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
61
+ load_meta = os.path.exists(meta_path)
62
+ if load_meta:
63
+ print(f"Loading meta from {meta_path}...")
64
+ with open(meta_path, 'rb') as f:
65
+ meta = pickle.load(f)
66
+ # TODO want to make this more general to arbitrary encoder/decoder schemes
67
+ stoi, itos = meta['stoi'], meta['itos']
68
+ encode = lambda s: [stoi[c] for c in s]
69
+ decode = lambda l: ''.join([itos[i] for i in l])
70
+ else:
71
+ # ok let's assume gpt-2 encodings by default
72
+ print("No meta.pkl found, assuming GPT-2 encodings...")
73
+ enc = tiktoken.get_encoding("gpt2")
74
+ encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
75
+ decode = lambda l: enc.decode(l)
76
 
77
+ # encode the beginning of the prompt
78
+ if start.startswith('FILE:'):
79
+ with open(start[5:], 'r', encoding='utf-8') as f:
80
+ start = f.read()
81
+ start_ids = encode(start)
82
+ x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ # run generation
85
+ with torch.no_grad():
86
+ with ctx:
87
+ for k in range(num_samples):
88
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
89
+ z = decode(y[0].tolist())
90
+ return z
91
 
 
 
 
 
 
 
92
 
93
+ iface = gr.Interface(fn=sample_from_trained_model, inputs=[], outputs="textbox",
94
+ title="GPT Shakespeare script Generator", description="Press button to generate shakespearean text")
 
 
95
 
96
  iface.launch(share=True)