nxphi47 commited on
Commit
f4b3d1c
1 Parent(s): 226d418

Update multipurpose_chatbot/engines/transformers_engine.py

Browse files
multipurpose_chatbot/engines/transformers_engine.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import numpy as np
5
  import argparse
6
  import torch
 
7
  import gradio as gr
8
  from typing import Any, Iterator
9
  from typing import Iterator, List, Optional, Tuple
@@ -427,6 +428,7 @@ class TransformersEngine(BaseEngine):
427
  def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
428
 
429
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
 
430
  with torch.no_grad():
431
  inputs = self.tokenizer(prompt, return_tensors='pt')
432
  num_tokens = inputs.input_ids.size(1)
 
4
  import numpy as np
5
  import argparse
6
  import torch
7
+ import sys
8
  import gradio as gr
9
  from typing import Any, Iterator
10
  from typing import Iterator, List, Optional, Tuple
 
428
  def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
+ import sys
432
  with torch.no_grad():
433
  inputs = self.tokenizer(prompt, return_tensors='pt')
434
  num_tokens = inputs.input_ids.size(1)