nxphi47 commited on
Commit
55ebaa5
1 Parent(s): 41936ab

Update multipurpose_chatbot/engines/transformers_engine.py

Browse files
multipurpose_chatbot/engines/transformers_engine.py CHANGED
@@ -429,7 +429,8 @@ class TransformersEngine(BaseEngine):
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
  import sys
432
- self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
 
433
  with torch.no_grad():
434
  inputs = self.tokenizer(prompt, return_tensors='pt')
435
  num_tokens = inputs.input_ids.size(1)
 
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
  import sys
432
+ # self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
433
+ self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
434
  with torch.no_grad():
435
  inputs = self.tokenizer(prompt, return_tensors='pt')
436
  num_tokens = inputs.input_ids.size(1)