Spaces:
Running
on
Zero
Running
on
Zero
Update multipurpose_chatbot/engines/transformers_engine.py
Browse files
multipurpose_chatbot/engines/transformers_engine.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
|
|
|
2 |
import os
|
3 |
import numpy as np
|
4 |
import argparse
|
@@ -420,7 +421,8 @@ class TransformersEngine(BaseEngine):
|
|
420 |
self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
421 |
print(self._model)
|
422 |
print(f"{self.max_position_embeddings=}")
|
423 |
-
|
|
|
424 |
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
425 |
|
426 |
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
@@ -428,7 +430,7 @@ class TransformersEngine(BaseEngine):
|
|
428 |
inputs = self.tokenizer(prompt, return_tensors='pt')
|
429 |
num_tokens = inputs.input_ids.size(1)
|
430 |
|
431 |
-
inputs = inputs.to(self.
|
432 |
|
433 |
generator = self._model.generate(
|
434 |
**inputs,
|
|
|
1 |
|
2 |
+
import spaces
|
3 |
import os
|
4 |
import numpy as np
|
5 |
import argparse
|
|
|
421 |
self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
422 |
print(self._model)
|
423 |
print(f"{self.max_position_embeddings=}")
|
424 |
+
|
425 |
+
@spaces.GPU
|
426 |
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
427 |
|
428 |
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
|
|
430 |
inputs = self.tokenizer(prompt, return_tensors='pt')
|
431 |
num_tokens = inputs.input_ids.size(1)
|
432 |
|
433 |
+
inputs = inputs.to(self._model.device)
|
434 |
|
435 |
generator = self._model.generate(
|
436 |
**inputs,
|