keminglu committed
Commit 7582cbf · 1 Parent(s): b831299

Update app.py

Files changed (1):
  app.py +64 -3
app.py CHANGED
@@ -3,12 +3,73 @@ import torch
 import json
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
+if torch.cuda.is_available():
+    use_cuda = True
+else:
+    use_cuda = False
+
 tokenizer = AutoTokenizer.from_pretrained("keminglu/pivoine-7b", use_auth_token="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", padding_side="left")
 model = AutoModelForCausalLM.from_pretrained("keminglu/pivoine-7b", use_auth_token="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", torch_dtype=torch.float16)
-#input_device = torch.device("cuda:5")
 model.requires_grad_(False)
 model.eval()
-#model = model.to(input_device)
+if use_cuda:
+    model = model.to("cuda")
 
 examples = json.load(open("examples.json"))
-description = open("description.txt").read()
+description = open("description.txt").read()
+
+def inference(context, instruction, num_beams: int = 4):
+    input_str = f"\"{context}\"\n\n{instruction}"
+    if not input_str.endswith("."):
+        input_str += "."
+
+    input_tokens = tokenizer(input_str, return_tensors="pt", padding=True)
+    if use_cuda:
+        for t in input_tokens:
+            if torch.is_tensor(input_tokens[t]):
+                input_tokens[t] = input_tokens[t].to("cuda")
+
+    output = model.generate(
+        input_tokens['input_ids'],
+        num_beams=num_beams,
+        do_sample=False,
+        max_new_tokens=2048,
+        num_return_sequences=1,
+        return_dict_in_generate=True,
+    )
+
+    num_input_tokens = input_tokens["input_ids"].shape[1]
+    output_tokens = output.sequences
+    generated_tokens = output_tokens[:, num_input_tokens:]
+    num_generated_tokens = (generated_tokens != tokenizer.pad_token_id).sum(dim=-1).tolist()[0]
+    prefix_to_add = torch.tensor([[tokenizer("A")["input_ids"][0]]]).to(generated_tokens.device)
+    generated_tokens = torch.cat([prefix_to_add, generated_tokens], dim=1)
+    generated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+    string_output = [i[1:].strip() for i in generated_text][0]
+    json_output = None
+    try:
+        json_output = json.loads(string_output)
+    except json.JSONDecodeError:
+        json_output = {"error": "Unfortunately, there was a JSON decode error in the output, which was very rare in our experiments :("}
+    except Exception as e:
+        raise gr.Error(str(e))
+
+    return num_generated_tokens, string_output, json_output
+
+demo = gr.Interface(
+    fn=inference,
+    inputs=["text", "text", gr.Slider(1, 5, value=4, step=1)],
+    outputs=[
+        gr.Number(label="Number of Generated Tokens"),
+        gr.Textbox(label="Raw String Output"),
+        gr.JSON(label="JSON Output")],
+    examples=examples,
+    examples_per_page=3,
+    title="Instruction-following Open-world Information Extraction",
+    description=description,
+)
+
+demo.launch(
+    share=True,
+    auth=("miega", "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"),
+    show_error=True)
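
For reference, a minimal smoke test of the new inference() helper might look like the sketch below. It assumes app.py's module-level setup (tokenizer, model, use_cuda) has already run; the context and instruction strings are illustrative placeholders, not entries from the repo's examples.json.

    # Hypothetical usage sketch, not part of the commit.
    # The prompt strings below are made up for illustration.
    context = "Barack Obama served as the 44th President of the United States."
    instruction = "List all entities in the context and their types"

    num_tokens, raw_output, json_output = inference(context, instruction, num_beams=4)
    print(f"generated {num_tokens} tokens")
    print(raw_output)   # raw decoded string
    print(json_output)  # parsed JSON, or an error dict if decoding failed

As an aside, the per-tensor device loop in inference() could be collapsed to input_tokens = input_tokens.to("cuda"), since the BatchEncoding returned by the tokenizer supports .to(device) directly.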