chandrujobs committed on
Commit
c8c940e
·
verified ·
1 Parent(s): 837c657

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -3,12 +3,12 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
  @st.cache_resource
5
  def load_model():
6
- model_name = "Salesforce/codet5-small"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
  return tokenizer, model
10
 
11
- # Load model
12
  tokenizer, model = load_model()
13
 
14
  st.title("Code Generator")
@@ -21,14 +21,25 @@ if st.button("Generate Code"):
21
  if prompt.strip():
22
  with st.spinner("Generating code..."):
23
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
24
- outputs = model.generate(inputs.input_ids, max_length=max_length, num_beams=5, temperature=0.7, early_stopping=True)
25
 
 
 
 
 
 
 
 
 
 
 
26
  st.write("### Debugging: Raw Model Output")
27
- st.json(outputs.tolist()) # Debugging output
 
 
 
28
 
29
- generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
-
31
  st.write("### Generated Code:")
32
  st.code(generated_code, language="python")
 
33
  else:
34
  st.warning("Please enter a prompt!")
 
3
 
4
@st.cache_resource
def load_model(model_name: str = "Salesforce/codet5-base"):
    """Load and cache the CodeT5 tokenizer/model pair.

    Cached by Streamlit via ``st.cache_resource`` so the (large) model is
    downloaded and instantiated only once per process, not on every rerun.

    Args:
        model_name: Hugging Face model ID to load. Defaults to
            ``"Salesforce/codet5-base"`` (the commit switched from
            ``codet5-small`` to ``codet5-base`` for better results);
            parameterized so callers can swap checkpoints without editing
            this function.

    Returns:
        A ``(tokenizer, model)`` tuple ready for seq2seq generation.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
10
 
11
# Fetch the cached tokenizer/model pair once at import time; subsequent
# Streamlit reruns hit the cache_resource cache instead of reloading.
tokenizer, model = load_model()

st.title("Code Generator")
 
21
  if prompt.strip():
22
  with st.spinner("Generating code..."):
23
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
 
24
 
25
+ # Use sampling-based generation for better quality
26
+ outputs = model.generate(
27
+ inputs.input_ids,
28
+ max_length=max_length,
29
+ temperature=0.7,
30
+ top_p=0.95,
31
+ do_sample=True,
32
+ )
33
+
34
+ # Debugging: Show raw token output
35
  st.write("### Debugging: Raw Model Output")
36
+ st.json(outputs.tolist())
37
+
38
+ # Decode tokens properly
39
+ generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
40
 
 
 
41
  st.write("### Generated Code:")
42
  st.code(generated_code, language="python")
43
+
44
  else:
45
  st.warning("Please enter a prompt!")