NickyNicky commited on
Commit
e868007
·
verified ·
1 Parent(s): 1ead06e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ # !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
4
+ import json
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig
7
+ # from google.colab import userdata
8
+ import os
9
+
10
+ model_id = "somosnlp/gemma-1.1-2b-it_ColombiaRAC_FullyCurated_format_chatML_V1"
11
+ bnb_config = BitsAndBytesConfig(
12
+ load_in_4bit=True,
13
+ bnb_4bit_quant_type="nf4",
14
+ bnb_4bit_compute_dtype=torch.bfloat16
15
+ )
16
+ max_seq_length=400
17
+
18
+ # if torch.cuda.get_device_capability()[0] >= 8:
19
+ # # print("Flash Attention")
20
+ # attn_implementation="flash_attention_2"
21
+ # else:
22
+ # attn_implementation=None
23
+ attn_implementation=None
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(model_id,
26
+ max_length = max_seq_length)
27
+ model = AutoModelForCausalLM.from_pretrained(model_id,
28
+ # quantization_config=bnb_config,
29
+ device_map = {"":0},
30
+ attn_implementation = attn_implementation, # A100 o H100
31
+ ).eval()
32
+
33
+
34
+
35
+ class ListOfTokensStoppingCriteria(StoppingCriteria):
36
+ """
37
+ Clase para definir un criterio de parada basado en una lista de tokens específicos.
38
+ """
39
+ def __init__(self, tokenizer, stop_tokens):
40
+ self.tokenizer = tokenizer
41
+ # Codifica cada token de parada y guarda sus IDs en una lista
42
+ self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens]
43
+
44
+ def __call__(self, input_ids, scores, **kwargs):
45
+ # Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada
46
+ for stop_token_ids in self.stop_token_ids_list:
47
+ len_stop_tokens = len(stop_token_ids)
48
+ if len(input_ids[0]) >= len_stop_tokens:
49
+ if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids:
50
+ return True
51
+ return False
52
+
53
+ # Uso del criterio de parada personalizado
54
+ stop_tokens = ["<end_of_turn>"] # Lista de tokens de parada
55
+
56
+ # Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada
57
+ stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
58
+
59
+ # Añade tu criterio de parada a una StoppingCriteriaList
60
+ stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
61
+
62
+ def generate_text(prompt, max_length=2100):
63
+ # prompt="""What were the main contributions of Eratosthenes to the development of mathematics in ancient Greece?"""
64
+ prompt=prompt.replace("\n", "").replace("¿","").replace("?","")
65
+
66
+
67
+ #EXAMPLE
68
+ input_text = f'''<bos><start_of_turn>system
69
+ You are a helpful AI assistant.
70
+ Responde en formato json.
71
+ Eres un agente experto en la normativa aeronautica Colombiana.<end_of_turn>
72
+ <start_of_turn>user
73
+ ¿{prompt}?<end_of_turn>
74
+ <start_of_turn>model
75
+ '''
76
+
77
+ inputs = tokenizer.encode(input_text,
78
+ return_tensors="pt",
79
+ add_special_tokens=False).to("cuda:0")
80
+ max_new_tokens=max_length
81
+ generation_config = GenerationConfig(
82
+ max_new_tokens=max_new_tokens,
83
+ temperature=0.32,
84
+ #top_p=0.9,
85
+ top_k=50, # 45
86
+ repetition_penalty=1.04, #1.1
87
+ do_sample=True,
88
+ )
89
+ outputs = model.generate(generation_config=generation_config,
90
+ input_ids=inputs,
91
+ stopping_criteria=stopping_criteria_list,)
92
+ return tokenizer.decode(outputs[0], skip_special_tokens=False) #True
93
+
94
+
95
+
96
+ def mostrar_respuesta(pregunta):
97
+ try:
98
+ res= generate_text(pregunta, max_length=500)
99
+ inicio_json = res.find('{')
100
+ fin_json = res.rfind('}') + 1
101
+ json_str = res[inicio_json:fin_json]
102
+ json_obj = json.loads(json_str)
103
+ # print(json_obj)
104
+ return json_obj["Respuesta"], json_obj["Pagina"], json_obj["Rac"]
105
+ except:
106
+ json_obj={}
107
+ json_obj['Respuesta']='Error'
108
+ json_obj['Pagina']='Error'
109
+ json_obj['Rac']='Error'
110
+ return json_obj
111
+
112
+ # Ejemplos de preguntas
113
+ ejemplos = [
114
+ ["¿Cuál fue la fecha de publicación del RAC 1 en el Diario Oficial?"],
115
+ ["¿Qué se incorpora a los Reglamentos Aeronáuticos de Colombia?"],
116
+ ["Cuál fue la fecha de publicación del RAC 1 en el Diario Oficial?"],
117
+ ]
118
+
119
+ iface = gr.Interface(
120
+ fn=mostrar_respuesta,
121
+ inputs=gr.Textbox(label="Pregunta"),
122
+ outputs=[
123
+ gr.Textbox(label="Respuesta", lines=2),
124
+ gr.Textbox(label="Pagina", lines=1),
125
+ gr.Textbox(label="Rac", lines=1)
126
+ ],
127
+ title="Consultas Normativa Aeronáutica Colombiana",
128
+ description="Introduce tu pregunta sobre la normativa aeronáutica colombiana para obtener una respuesta.",
129
+ examples=ejemplos,
130
+ )
131
+
132
+ iface.queue(max_size=14).launch() # share=True,debug=True