Upload random_utils.py with huggingface_hub
Browse files- random_utils.py +39 -0
random_utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import random as python_random
|
3 |
+
import string
|
4 |
+
import threading
|
5 |
+
|
6 |
+
__default_seed__ = 42
|
7 |
+
_thread_local = threading.local()
|
8 |
+
_thread_local.seed = __default_seed__
|
9 |
+
_thread_local.random = python_random.Random()
|
10 |
+
random = _thread_local.random
|
11 |
+
|
12 |
+
|
13 |
+
def set_seed(seed):
|
14 |
+
_thread_local.random.seed(seed)
|
15 |
+
_thread_local.seed = seed
|
16 |
+
|
17 |
+
|
18 |
+
def get_seed():
|
19 |
+
return _thread_local.seed
|
20 |
+
|
21 |
+
|
22 |
+
def get_random_string(length):
|
23 |
+
letters = string.ascii_letters
|
24 |
+
result_str = "".join(random.choice(letters) for _ in range(length))
|
25 |
+
return result_str
|
26 |
+
|
27 |
+
|
28 |
+
@contextlib.contextmanager
|
29 |
+
def nested_seed(sub_seed=None):
|
30 |
+
state = _thread_local.random.getstate()
|
31 |
+
old_global_seed = get_seed()
|
32 |
+
sub_seed = sub_seed or get_random_string(10)
|
33 |
+
new_global_seed = str(old_global_seed) + "/" + sub_seed
|
34 |
+
set_seed(new_global_seed)
|
35 |
+
try:
|
36 |
+
yield _thread_local.random
|
37 |
+
finally:
|
38 |
+
set_seed(old_global_seed)
|
39 |
+
_thread_local.random.setstate(state)
|