# google_sheet.py
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from typing import Dict


class SheetCRUDRepository:
    def __init__(self, worksheet):
        self.worksheet = worksheet
        self.titles = self.worksheet.row_values(1)  # assuming titles are in the first row
        assert len(set(self.titles)) == len(self.titles), \
            f"Failed to init {SheetCRUDRepository.__name__}, titles: {self.titles} contain duplicated values!"
    def create(self, data: Dict):
        values = [data.get(title, '') for title in self.titles]
        self.worksheet.append_row(values)

    def read(self, row_index: int) -> Dict:
        """
        Return {} if the row is empty.
        """
        values = self.worksheet.row_values(row_index)
        return {title: value for title, value in zip(self.titles, values)}

    def update(self, row_index: int, data: Dict):
        values = [data.get(title, '') for title in self.titles]
        # assumes at most 26 columns (A..Z)
        self.worksheet.update(f"A{row_index}:Z{row_index}", [values])

    def delete(self, row_index: int):
        self.worksheet.delete_row(row_index)

    def find(self, search_dict):
        # search by the first key in search_dict that matches a column title
        for col_title, value in search_dict.items():
            if col_title in self.titles:
                col_index = self.titles.index(col_title) + 1  # +1 to match gspread's 1-based indexing
                cell = self.worksheet.find(value, in_column=col_index)
                if cell is None:
                    break
                row_number = cell.row
                return row_number, self.read(row_number)
        return None
def create_repositories():
    scope = [
        'https://www.googleapis.com/auth/spreadsheets',
        'https://www.googleapis.com/auth/drive'
    ]
    creds = ServiceAccountCredentials.from_json_keyfile_name('credentials.json', scope)
    client = gspread.authorize(creds)
    # sheet_url = "https://docs.google.com/spreadsheets/d/17OxKF0iP_aJJ0HCgJkwFsH762EUrtcEIYcPmyiiKnaM"
    sheet_url = "https://docs.google.com/spreadsheets/d/1KzUYgWwbvYXGfyehOTyZCCQf0udZiwVXxaxpmkXEe3E/edit?usp=sharing"
    sheet = client.open_by_url(sheet_url)
    config_repository = SheetCRUDRepository(sheet.get_worksheet(0))
    log_repository = SheetCRUDRepository(sheet.get_worksheet(1))
    return config_repository, log_repository


conf_repo, log_repo = create_repositories()
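# Usage sketch (illustrative only; the column titles "name" and "status" below are
# hypothetical, they just need to match the worksheet's first row):
#   repo = SheetCRUDRepository(sheet.get_worksheet(0))
#   repo.create({"name": "job-1", "status": "queued"})      # appends a new row
#   row, data = repo.find({"name": "job-1"})                 # -> (row_number, row_as_dict)
#   repo.update(row, {"name": "job-1", "status": "done"})
#   repo.delete(row)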
# logger.py
import platform, socket, re, uuid, json, psutil, logging
from datetime import datetime as dt

from google_sheet import log_repo

version = "v1.0.0"


def get_sys_info():
    try:
        info = {}
        info['platform'] = platform.system()
        info['platform-release'] = platform.release()
        info['platform-version'] = platform.version()
        info['architecture'] = platform.machine()
        info['hostname'] = socket.gethostname()
        info['ip-address'] = socket.gethostbyname(socket.gethostname())
        info['mac-address'] = ':'.join(re.findall('..', '%012x' % uuid.getnode()))
        info['processor'] = platform.processor()
        info['ram'] = str(round(psutil.virtual_memory().total / (1024.0 ** 3))) + " GB"
        return json.dumps(info)
    except Exception as e:
        logging.exception(e)
class SheetLogger:
    def __init__(self, log_repo):
        self.log_repo = log_repo

    def log(self, log='', nb='', username=''):
        self.log_repo.create({
            "time": str(dt.now()),
            "notebook_name": nb,
            "kaggle_username": username,
            "log": log,
            "device": str(get_sys_info()),
            "version": version
        })


sheet_logger = SheetLogger(log_repo)
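# Usage sketch (illustrative only; the notebook slug and username are hypothetical):
#   sheet_logger.log(log="Scheduled rerun", nb="someuser/some-notebook-123456", username="someuser")
# Each call appends one row to the log worksheet, whose first row is expected to hold the
# titles: time, notebook_name, kaggle_username, log, device, version.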
# kaggle_service.py
import json
import os
from typing import Callable, Dict, List, Optional, Tuple

# fake default account so that kaggle.api.kaggle_api_extended can be imported
os.environ['KAGGLE_USERNAME'] = ''
os.environ['KAGGLE_KEY'] = ''

from kaggle.api.kaggle_api_extended import KaggleApi
from kaggle.rest import ApiException
import shutil
import time
import threading
import copy

from logger import sheet_logger


def get_api():
    api = KaggleApi()
    api.authenticate()
    return api
class KaggleApiWrapper(KaggleApi):
    """
    Override KaggleApi.read_config_environment to use a username and secret
    without relying on environment variables.
    """
    def __init__(self, username, secret):
        super().__init__()
        self.username = username
        self.secret = secret

    def read_config_environment(self, config_data=None, quiet=False):
        config = super().read_config_environment(config_data, quiet)
        config['username'] = self.username
        config['key'] = self.secret
        return config

    def __del__(self):
        # todo: fix bug when delete api
        pass
class KaggleNotebook:
    def __init__(self, api: KaggleApi, kernel_slug: str, container_path: str = "./tmp"):
        """
        :param api: KaggleApi
        :param kernel_slug: Notebook id; you can find it in the url of the notebook,
            for example `username/notebook-name-123456`
        :param container_path: Path to the local folder where the notebook will be downloaded
        """
        self.api = api
        self.kernel_slug = kernel_slug
        self.container_path = container_path

    def status(self) -> Optional[str]:
        """
        :return: one of
            "running"
            "cancelAcknowledged"
            "queued": waiting for run
            "error": an exception was raised inside the notebook
        :raises: an exception when the status request fails
        """
        res = self.api.kernels_status(self.kernel_slug)
        print(f"Status: {res}")
        if res is None:
            return None
        return res['status']

    def _get_local_nb_path(self) -> str:
        return os.path.join(self.container_path, self.kernel_slug)
    def pull(self, path=None) -> Optional[str]:
        """
        :param path: local destination folder; defaults to the notebook's container path
        :return: the pull result, or None if the kernel metadata was not found
        :raises: ApiException if the notebook is not found or not shared with the user
        """
        self._clean()
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        res = self.api.kernels_pull(self.kernel_slug, path=path, metadata=True, quiet=False)
        if not os.path.exists(metadata_path):
            print(f"Warn: Not found {metadata_path}. Clean {path}")
            self._clean()
            return None
        return res

    def push(self, path=None) -> Optional[str]:
        status = self.status()
        if status in ['queued', 'running']:
            print("Warn: Notebook is " + status + ". Skip push notebook!")
            return None
        self.api.kernels_push(path or self._get_local_nb_path())
        time.sleep(1)
        status = self.status()
        return status

    def _clean(self) -> None:
        if os.path.exists(self._get_local_nb_path()):
            shutil.rmtree(self._get_local_nb_path())

    def get_metadata(self, path=None):
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        if not os.path.exists(metadata_path):
            return None
        return json.loads(open(metadata_path).read())
    def check_nb_permission(self) -> Tuple[bool, Optional[str]]:
        try:
            status = self.status()
            if status is None:
                return False, status
            return True, status
        except ApiException as e:
            print(f"Error: {e.status} {e.reason} with notebook {self.kernel_slug}")
            return False, None

    def check_datasets_permission(self) -> bool:
        meta = self.get_metadata()
        if meta is None:
            print("Warn: cannot get metadata. Pull and try again?")
            return False
        dataset_sources = meta['dataset_sources']
        for dataset in dataset_sources:
            try:
                self.api.dataset_status(dataset)
            except ApiException as e:
                print(f"Error: {e.status} {e.reason} with dataset {dataset} in notebook {self.kernel_slug}")
                return False
        return True
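# Usage sketch (illustrative only; the username, secret and kernel slug are hypothetical):
#   api = KaggleApiWrapper("someuser", "some-kaggle-api-secret")
#   api.authenticate()
#   nb = KaggleNotebook(api, "someuser/some-notebook-123456")
#   if nb.pull() is not None and nb.check_datasets_permission():
#       nb.push()            # re-submit the notebook for execution
#       print(nb.status())   # e.g. "queued" or "running"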
class AccountTransactionManager:
    def __init__(self, acc_secret_dict: dict = None):
        """
        :param acc_secret_dict: {username: secret}
        """
        self.acc_secret_dict = acc_secret_dict
        if self.acc_secret_dict is None:
            self.acc_secret_dict = {}
        # self.api_dict = {username: KaggleApiWrapper(username, secret) for username, secret in acc_secret_dict.items()}
        # lock for each account to avoid concurrent use of the api
        self.lock_dict = {username: False for username in self.acc_secret_dict.keys()}
        self.state_lock = threading.Lock()

    def _get_api(self, username: str) -> KaggleApiWrapper:
        # return self.api_dict[username]
        return KaggleApiWrapper(username, self.acc_secret_dict[username])

    def _get_lock(self, username: str) -> bool:
        return self.lock_dict[username]

    def _set_lock(self, username: str, lock: bool) -> None:
        self.lock_dict[username] = lock

    def add_account(self, username, secret):
        if username not in self.acc_secret_dict.keys():
            self.state_lock.acquire()
            self.acc_secret_dict[username] = secret
            self.lock_dict[username] = False
            self.state_lock.release()

    def remove_account(self, username):
        if username in self.acc_secret_dict.keys():
            self.state_lock.acquire()
            del self.acc_secret_dict[username]
            del self.lock_dict[username]
            self.state_lock.release()
        else:
            print(f"Warn: try to remove account not in the list: {username}, list: {self.acc_secret_dict.keys()}")
    def get_unlocked_api_unblocking(self, username_list: List[str]) -> Tuple[KaggleApiWrapper, Callable[[], None]]:
        """
        Block until one of the given accounts is free, then lock it and return its api.
        :param username_list: list of usernames
        :return: (api, release) where release is a function that unlocks the account
        """
        while True:
            print("get_unlocked_api_unblocking " + str(username_list))
            for username in username_list:
                self.state_lock.acquire()
                if not self._get_lock(username):
                    self._set_lock(username, True)
                    api = self._get_api(username)

                    def release():
                        self.state_lock.acquire()
                        self._set_lock(username, False)
                        api.__del__()
                        self.state_lock.release()

                    self.state_lock.release()
                    return api, release
                self.state_lock.release()
            time.sleep(1)
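# Usage sketch (illustrative only; usernames and secrets are hypothetical):
#   manager = AccountTransactionManager({"user_a": "secret_a", "user_b": "secret_b"})
#   api, release = manager.get_unlocked_api_unblocking(["user_a", "user_b"])
#   try:
#       api.authenticate()
#       ...  # use the api while the account is locked
#   finally:
#       release()  # unlock the account when done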
class NbJob:
    def __init__(self, acc_dict: dict, nb_slug: str, rerun_stt: List[str] = None, not_rerun_stt: List[str] = None):
        """
        :param acc_dict: {username: secret} of the accounts allowed to run this notebook
        :param nb_slug: notebook slug, e.g. `username/notebook-name-123456`
        :param rerun_stt: statuses that trigger a rerun
        :param not_rerun_stt: If notebook status is in this list, do not rerun it.
            (Note: do not remove "queued" and "running" from this list)
        """
        self.rerun_stt = rerun_stt
        if self.rerun_stt is None:
            self.rerun_stt = ['complete']
        self.not_rerun_stt = not_rerun_stt
        if self.not_rerun_stt is None:
            self.not_rerun_stt = ['queued', 'running', 'cancelAcknowledged']
        assert "queued" in self.not_rerun_stt
        assert "running" in self.not_rerun_stt
        self.acc_dict = acc_dict
        self.nb_slug = nb_slug

    def get_acc_dict(self):
        return self.acc_dict

    def get_username_list(self):
        return list(self.acc_dict.keys())

    def is_valid_with_acc(self, api):
        notebook = KaggleNotebook(api, self.nb_slug)
        try:
            notebook.pull()
        except ApiException:
            return False
        stt, _ = notebook.check_nb_permission()
        if not stt:
            return False
        stt = notebook.check_datasets_permission()
        if not stt:
            return False
        return True

    def is_valid(self):
        for username in self.acc_dict.keys():
            secrets = self.acc_dict[username]
            api = KaggleApiWrapper(username=username, secret=secrets)
            api.authenticate()
            if not self.is_valid_with_acc(api):
                return False
        return True
    def acc_check_and_rerun_if_need(self, api: KaggleApi) -> bool:
        """
        :return:
            True if the rerun succeeded or the notebook is already running
            False if the user does not have enough gpu quota
        :raises:
            Exception if there is a setup error
        """
        notebook = KaggleNotebook(api, self.nb_slug, "./tmp")  # todo: change hardcoded container_path here
        notebook.pull()
        assert notebook.check_datasets_permission(), f"User {api.username} does not have permission on datasets of notebook {self.nb_slug}"
        success, status1 = notebook.check_nb_permission()
        assert success, f"User {api.username} does not have permission on notebook {self.nb_slug}"
        if status1 in self.rerun_stt:
            status2 = notebook.push()
            time.sleep(10)
            status3 = notebook.status()
            # if the status did not change across the push -> the account is likely out of quota
            if status1 == status2 == status3:
                sheet_logger.log(username=api.username, nb=self.nb_slug, log="Try but no effect. Seems the account is out of quota")
                return False
            if status3 in self.not_rerun_stt:
                sheet_logger.log(username=api.username, nb=self.nb_slug, log=f"Notebook is in ignore status list {self.not_rerun_stt}, do nothing!")
                return True
            if status3 not in ["queued", "running"]:
                # return False  # todo: check when user is out of quota
                print(f"Error: status is {status3}")
                raise Exception("Setup exception")
            sheet_logger.log(username=api.username, nb=self.nb_slug,
                             log=f"Schedule notebook successfully. Current status is '{status3}'")
            return True
        sheet_logger.log(username=api.username, nb=self.nb_slug, log=f"Notebook status '{status1}' is not in {self.rerun_stt}, do nothing!")
        return True
    @staticmethod
    def from_dict(obj: dict):
        return NbJob(acc_dict=obj['accounts'], nb_slug=obj['slug'], rerun_stt=obj.get('rerun_status'), not_rerun_stt=obj.get('not_rerun_stt'))
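# Config sketch (illustrative only; the slug, usernames and secrets are hypothetical).
# This is the JSON structure NbJob.from_dict expects, i.e. what each "config" cell of the
# config worksheet should contain:
#   {
#       "slug": "someuser/some-notebook-123456",
#       "accounts": {"someuser": "kaggle-api-secret", "backupuser": "another-secret"},
#       "rerun_status": ["complete", "error"],
#       "not_rerun_stt": ["queued", "running", "cancelAcknowledged"]
#   }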
class KernelRerunService:
    def __init__(self):
        self.jobs: Dict[str, NbJob] = {}
        self.acc_manager = AccountTransactionManager()
        self.username2jobid = {}
        self.jobid2username = {}

    def add_job(self, nb_job: NbJob):
        if nb_job.nb_slug in self.jobs.keys():
            print("Warn: nb_job already in job list")
            return
        self.jobs[nb_job.nb_slug] = nb_job
        self.jobid2username[nb_job.nb_slug] = nb_job.get_username_list()
        for username in nb_job.get_username_list():
            if username not in self.username2jobid.keys():
                self.username2jobid[username] = []
                self.acc_manager.add_account(username, nb_job.acc_dict[username])
            self.username2jobid[username].append(nb_job.nb_slug)

    def remove_job(self, nb_job):
        if nb_job.nb_slug not in self.jobs.keys():
            print("Warn: try to remove nb_job not in list")
            return
        username_list = self.jobid2username[nb_job.nb_slug]
        # only drop accounts that are used by this job alone
        username_list = [username for username in username_list if len(self.username2jobid[username]) == 1]
        for username in username_list:
            del self.username2jobid[username]
            self.acc_manager.remove_account(username)
        del self.jobs[nb_job.nb_slug]
        del self.jobid2username[nb_job.nb_slug]
    def validate_all(self):
        for username in self.acc_manager.acc_secret_dict.keys():
            api, release = self.acc_manager.get_unlocked_api_unblocking([username])
            api.authenticate()
            print(f"Using username: {api.username}")
            for job in self.jobs.values():
                if username in job.get_username_list():
                    print(f"Validate user: {username}, job: {job.nb_slug}")
                    if not job.is_valid_with_acc(api):
                        print("Error: not valid")
                        msg = f"Setup error: {username} does not have permission on notebook {job.nb_slug} or related datasets"
                        raise Exception(msg)
            release()
        return True

    def status_all(self):
        for job in self.jobs.values():
            print(f"Job: {job.nb_slug}")
            api, release = self.acc_manager.get_unlocked_api_unblocking(job.get_username_list())
            api.authenticate()
            print(f"Using username: {api.username}")
            notebook = KaggleNotebook(api, job.nb_slug)
            print(f"Notebook: {notebook.kernel_slug}")
            print(notebook.status())
            release()
    def run(self, nb_job: NbJob):
        username_list = copy.copy(nb_job.get_username_list())
        while len(username_list) > 0:
            api, release = self.acc_manager.get_unlocked_api_unblocking(username_list)
            api.authenticate()
            print(f"Using username: {api.username}")
            try:
                result = nb_job.acc_check_and_rerun_if_need(api)
                if result:
                    # release the account once the notebook has been scheduled
                    release()
                    return True
            except Exception as e:
                print(e)
                release()
                break
            if api.username in username_list:
                username_list.remove(api.username)
                release()
            else:
                release()
                raise Exception(f"Unexpected account {api.username}: not in {username_list}")
        return False

    def run_all(self):
        for job in self.jobs.values():
            success = self.run(job)
            print(f"Job: {job.nb_slug} {success}")
# main script (entry point)
import json

from google_sheet import conf_repo
from kaggle_service import KernelRerunService, NbJob
from logger import sheet_logger

if __name__ == "__main__":
    configs = []
    try:
        # config rows start at row 2 (row 1 holds the column titles)
        for i in range(2, 1000):
            rs = conf_repo.read(i)
            if not rs:
                break
            cfg = json.loads(rs['config'])
            configs.append(cfg)
            print(cfg)
    except Exception as e:
        sheet_logger.log(log=f"Get config failed!! {e}")

    service = KernelRerunService()
    for config in configs:
        service.add_job(NbJob.from_dict(config))

    try:
        service.validate_all()
        service.status_all()
        service.run_all()
    except Exception as e:
        sheet_logger.log(log=str(e))