ThomasBlumet committed on
Commit 72903e4 · 1 Parent(s): 2d4b9ba

change to run on GPU

Files changed (1)
  1. app.py +6 -3
app.py CHANGED
@@ -11,16 +11,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)
 #model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name,device_map="auto",trust_remote_code=False,revision="main")
 
+#transfer model on GPU
+model.to("cuda")
+
 # Generate text using the model and tokenizer
 def generate_text(input_text):
-    input_ids = tokenizer.encode(input_text, return_tensors="pt")
+    input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")
     #attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
     output = model.generate(input_ids, max_new_tokens=512, top_k=50, top_p=0.95, temperature=0.7, do_sample=True)# attention_mask=attention_mask, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95, temperature=0.7, do_sample=True)
     return tokenizer.decode(output[0])
 
 # Example of disabling Exllama backend (if applicable in your configuration)
-config = {"disable_exllama": True}
-model.config.update(config)
+#config = {"disable_exllama": True}
+#model.config.update(config)
 
 # def generate_text(prompt):
 #     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
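A note on the change, not part of the commit: with device_map="auto", accelerate already decides where the model lives, so the explicit model.to("cuda") can be redundant, and it fails outright on machines without CUDA or when the model is sharded across devices. A device-agnostic sketch of the same idea, assuming accelerate is installed and using a placeholder model name (app.py defines its own model_name earlier):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; the real app.py sets model_name above this point

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Ask the model where it actually landed instead of hard-coding "cuda".
device = next(model.parameters()).device

def generate_text(input_text):
    # Inputs must be on the same device as the model before generate().
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_new_tokens=512, top_k=50,
                            top_p=0.95, temperature=0.7, do_sample=True)
    return tokenizer.decode(output[0])

Written this way the function also runs on CPU-only machines, where the hard-coded .to("cuda") would raise an error.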
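On the two ExLlama lines the commit comments out, another assumption rather than something the commit states: mutating model.config after loading generally does not switch the quantization backend, since the kernels are chosen at load time. If the checkpoint were GPTQ-quantized, recent transformers releases accept this setting up front through GPTQConfig; a minimal sketch with a hypothetical checkpoint name:

from transformers import AutoModelForCausalLM, GPTQConfig

# Hypothetical GPTQ checkpoint, for illustration only.
quantized_name = "TheBloke/Llama-2-7B-GPTQ"

# use_exllama=False disables the ExLlama kernels at load time;
# older transformers releases spelled this disable_exllama=True.
gptq_config = GPTQConfig(bits=4, use_exllama=False)

model = AutoModelForCausalLM.from_pretrained(
    quantized_name,
    device_map="auto",
    quantization_config=gptq_config,
)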