vilarin committed
Commit 0a38613 · verified · Parent: 112e28c

Update app.py

Files changed (1)
  1. app.py +4 -4
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import LlamaForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import os
 from threading import Thread
 from polyglot.detect import Detector
@@ -15,7 +15,7 @@ TITLE = "<h1><center>LLaMAX3-8B-Translation</center></h1>"
 
 quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
-model = LlamaForCausalLM.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     torch_dtype=torch.bfloat16,
     device_map="auto",
@@ -61,12 +61,12 @@ def translate(
     print(f'Text is - {source_text}')
 
     prompt = Prompt_template(source_text, source_lang, target_lang)
-    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})
 
     generate_kwargs = dict(
-        inputs.input_ids,
+        input_ids=input_ids,
         streamer=streamer,
         max_length=max_length,
         do_sample=True,
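For context, switching from LlamaForCausalLM to AutoModelForCausalLM lets transformers resolve the architecture from the checkpoint's config instead of hard-coding the model class. A minimal sketch of the loading pattern this commit lands on, assuming MODEL points at a LLaMAX3-8B checkpoint on the Hub (the exact repo id below is an assumption, not taken from the Space):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    MODEL = "LLaMAX/LLaMAX3-8B"  # assumed repo id; the Space defines its own MODEL

    # bitsandbytes loads the weights in 8-bit; device_map="auto" lets
    # accelerate place layers on the available GPU(s) or CPU.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)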
 
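The other two changes fix the generation call: the tokenized prompt is moved onto the model's device, and it is passed to generate() by keyword rather than positionally. A sketch of the streaming pattern the app relies on, where model.generate() runs on a worker thread while TextIteratorStreamer yields partial text; stream_translation and its defaults are hypothetical names, and the whole BatchEncoding is unpacked with **inputs so the attention mask travels along with input_ids (a slight variation on the commit's input_ids=input_ids):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_translation(prompt: str, max_length: int = 512):
        # Tokenize and move the tensors to whichever device the model landed on.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # skip_prompt=True makes the streamer emit only newly generated text.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        generate_kwargs = dict(
            **inputs,              # input_ids (and attention_mask) by keyword
            streamer=streamer,
            max_length=max_length,
            do_sample=True,
        )

        # generate() blocks until completion, so run it on a background
        # thread while the caller consumes tokens from the streamer.
        Thread(target=model.generate, kwargs=generate_kwargs).start()

        output = ""
        for chunk in streamer:
            output += chunk
            yield output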