MadsGalsgaard committed on
Commit
4dbd8e9
1 Parent(s): dd3578f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -186
app.py CHANGED
@@ -442,158 +442,24 @@
442
  ###########new clientkey
443
 
444
 
445
- # import os
446
- # import time
447
- # import spaces
448
- # import torch
449
- # from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
450
- # import gradio as gr
451
- # from threading import Thread
452
-
453
- # MODEL = "THUDM/LongWriter-llama3.1-8b"
454
-
455
- # TITLE = "<h1><center>AreaX LLC-llama3.1-8b</center></h1>"
456
-
457
- # PLACEHOLDER = """
458
- # <center>
459
- # <p>Hi! I'm AreaX AI Agent, capable of generating 10,000+ words. How can I assist you today?</p>
460
- # </center>
461
- # """
462
-
463
- # CSS = """
464
- # .duplicate-button {
465
- # margin: auto !important;
466
- # color: white !important;
467
- # background: black !important;
468
- # border-radius: 100vh !important;
469
- # }
470
- # h3 {
471
- # text-align: center;
472
- # }
473
- # """
474
-
475
- # device = "cuda" if torch.cuda.is_available() else "cpu"
476
-
477
- # tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
478
- # model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
479
- # model = model.eval()
480
-
481
- # @spaces.GPU()
482
- # def stream_chat(
483
- # message: str,
484
- # history: list,
485
- # system_prompt: str,
486
- # temperature: float = 0.5,
487
- # max_new_tokens: int = 32768,
488
- # top_p: float = 1.0,
489
- # top_k: int = 50,
490
- # ):
491
- # print(f'message: {message}')
492
- # print(f'history: {history}')
493
-
494
- # full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
495
- # for prompt, answer in history:
496
- # full_prompt += f"[INST]{prompt}[/INST]{answer}"
497
- # full_prompt += f"[INST]{message}[/INST]"
498
-
499
- # inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
500
- # context_length = inputs.input_ids.shape[-1]
501
-
502
- # streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
503
-
504
- # generate_kwargs = dict(
505
- # inputs=inputs.input_ids,
506
- # max_new_tokens=max_new_tokens,
507
- # do_sample=True,
508
- # top_p=top_p,
509
- # top_k=top_k,
510
- # temperature=temperature,
511
- # num_beams=1,
512
- # streamer=streamer,
513
- # )
514
-
515
- # thread = Thread(target=model.generate, kwargs=generate_kwargs)
516
- # thread.start()
517
-
518
- # buffer = ""
519
- # for new_text in streamer:
520
- # buffer += new_text
521
- # yield buffer
522
-
523
- # chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
524
-
525
- # with gr.Blocks(css=CSS, theme="soft") as demo:
526
- # gr.HTML(TITLE)
527
- # gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
528
- # gr.ChatInterface(
529
- # fn=stream_chat,
530
- # chatbot=chatbot,
531
- # fill_height=True,
532
- # additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
533
- # additional_inputs=[
534
- # gr.Textbox(
535
- # value="You are a helpful assistant capable of generating long-form content.",
536
- # label="System Prompt",
537
- # render=False,
538
- # ),
539
- # gr.Slider(
540
- # minimum=0,
541
- # maximum=1,
542
- # step=0.1,
543
- # value=0.5,
544
- # label="Temperature",
545
- # render=False,
546
- # ),
547
- # gr.Slider(
548
- # minimum=1024,
549
- # maximum=32768,
550
- # step=1024,
551
- # value=32768,
552
- # label="Max new tokens",
553
- # render=False,
554
- # ),
555
- # gr.Slider(
556
- # minimum=0.0,
557
- # maximum=1.0,
558
- # step=0.1,
559
- # value=1.0,
560
- # label="Top p",
561
- # render=False,
562
- # ),
563
- # gr.Slider(
564
- # minimum=1,
565
- # maximum=100,
566
- # step=1,
567
- # value=50,
568
- # label="Top k",
569
- # render=False,
570
- # ),
571
- # ],
572
- # examples=[
573
- # ["Write a 5000-word comprehensive guide on machine learning for beginners."],
574
- # ["Create a detailed 3000-word business plan for a sustainable energy startup."],
575
- # ["Compose a 2000-word short story set in a futuristic underwater city."],
576
- # ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
577
- # ],
578
- # cache_examples=False,
579
- # )
580
-
581
- # if __name__ == "__main__":
582
- # demo.launch()
583
-
584
  import torch
585
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
586
  import gradio as gr
587
  from threading import Thread
588
 
589
- # Model and constants
590
  MODEL = "THUDM/LongWriter-llama3.1-8b"
 
591
  TITLE = "<h1><center>AreaX LLC-llama3.1-8b</center></h1>"
 
592
  PLACEHOLDER = """
593
  <center>
594
  <p>Hi! I'm AreaX AI Agent, capable of generating 10,000+ words. How can I assist you today?</p>
595
  </center>
596
  """
 
597
  CSS = """
598
  .duplicate-button {
599
  margin: auto !important;
@@ -606,61 +472,54 @@ h3 {
606
  }
607
  """
608
 
609
- # Check device
610
  device = "cuda" if torch.cuda.is_available() else "cpu"
611
 
612
- # Load model and tokenizer
613
  tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
614
- model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto").eval()
 
615
 
 
616
  def stream_chat(
617
  message: str,
618
  history: list,
619
  system_prompt: str,
620
  temperature: float = 0.5,
621
- max_new_tokens: int = 4096, # Lowered max tokens for efficiency
622
  top_p: float = 1.0,
623
  top_k: int = 50,
624
  ):
625
- try:
626
- full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
627
- for prompt, answer in history:
628
- full_prompt += f"[INST]{prompt}[/INST]{answer}"
629
- full_prompt += f"[INST]{message}[/INST]"
630
-
631
- # Tokenize input
632
- inputs = tokenizer(full_prompt, truncation=True, max_length=2048, return_tensors="pt").to(device)
633
- context_length = inputs.input_ids.shape[-1]
634
-
635
- # Setup TextIteratorStreamer for streaming response
636
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
637
-
638
- # Generation parameters
639
- generate_kwargs = dict(
640
- inputs=inputs.input_ids,
641
- max_new_tokens=max_new_tokens,
642
- do_sample=True,
643
- top_p=top_p,
644
- top_k=top_k,
645
- temperature=temperature,
646
- num_beams=1,
647
- streamer=streamer,
648
- )
649
-
650
- # Generate text in a separate thread to avoid blocking
651
- thread = Thread(target=model.generate, kwargs=generate_kwargs)
652
- thread.start()
653
-
654
- # Stream response
655
- buffer = ""
656
- for new_text in streamer:
657
- buffer += new_text
658
- yield buffer
659
-
660
- except Exception as e:
661
- yield f"An error occurred: {str(e)}"
662
-
663
- # Gradio setup
664
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
665
 
666
  with gr.Blocks(css=CSS, theme="soft") as demo:
@@ -687,9 +546,9 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
687
  ),
688
  gr.Slider(
689
  minimum=1024,
690
- maximum=4096, # Reduced to a more manageable value
691
  step=1024,
692
- value=4096,
693
  label="Max new tokens",
694
  render=False,
695
  ),
@@ -710,7 +569,14 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
710
  render=False,
711
  ),
712
  ],
 
 
 
 
 
 
 
713
  )
714
 
715
  if __name__ == "__main__":
716
- demo.launch()
 
442
  ###########new clientkey
443
 
444
 
445
+ import os
446
+ import time
447
+ import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  import torch
449
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
450
  import gradio as gr
451
  from threading import Thread
452
 
 
453
  MODEL = "THUDM/LongWriter-llama3.1-8b"
454
+
455
  TITLE = "<h1><center>AreaX LLC-llama3.1-8b</center></h1>"
456
+
457
  PLACEHOLDER = """
458
  <center>
459
  <p>Hi! I'm AreaX AI Agent, capable of generating 10,000+ words. How can I assist you today?</p>
460
  </center>
461
  """
462
+
463
  CSS = """
464
  .duplicate-button {
465
  margin: auto !important;
 
472
  }
473
  """
474
 
 
475
  device = "cuda" if torch.cuda.is_available() else "cpu"
476
 
 
477
  tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
478
+ model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
479
+ model = model.eval()
480
 
481
+ @spaces.GPU()
482
  def stream_chat(
483
  message: str,
484
  history: list,
485
  system_prompt: str,
486
  temperature: float = 0.5,
487
+ max_new_tokens: int = 32768,
488
  top_p: float = 1.0,
489
  top_k: int = 50,
490
  ):
491
+ print(f'message: {message}')
492
+ print(f'history: {history}')
493
+
494
+ full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
495
+ for prompt, answer in history:
496
+ full_prompt += f"[INST]{prompt}[/INST]{answer}"
497
+ full_prompt += f"[INST]{message}[/INST]"
498
+
499
+ inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
500
+ context_length = inputs.input_ids.shape[-1]
501
+
502
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
503
+
504
+ generate_kwargs = dict(
505
+ inputs=inputs.input_ids,
506
+ max_new_tokens=max_new_tokens,
507
+ do_sample=True,
508
+ top_p=top_p,
509
+ top_k=top_k,
510
+ temperature=temperature,
511
+ num_beams=1,
512
+ streamer=streamer,
513
+ )
514
+
515
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
516
+ thread.start()
517
+
518
+ buffer = ""
519
+ for new_text in streamer:
520
+ buffer += new_text
521
+ yield buffer
522
+
 
 
 
 
 
 
 
523
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
524
 
525
  with gr.Blocks(css=CSS, theme="soft") as demo:
 
546
  ),
547
  gr.Slider(
548
  minimum=1024,
549
+ maximum=32768,
550
  step=1024,
551
+ value=32768,
552
  label="Max new tokens",
553
  render=False,
554
  ),
 
569
  render=False,
570
  ),
571
  ],
572
+ # examples=[
573
+ # ["Write a 5000-word comprehensive guide on machine learning for beginners."],
574
+ # ["Create a detailed 3000-word business plan for a sustainable energy startup."],
575
+ # ["Compose a 2000-word short story set in a futuristic underwater city."],
576
+ # ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
577
+ # ],
578
+ # cache_examples=False,
579
  )
580
 
581
  if __name__ == "__main__":
582
+ demo.launch()