Spaces:
Build error
Build error
from typing import List, Union | |
from copy import deepcopy | |
from collections import namedtuple | |
from pathlib import Path | |
import argparse | |
from argparse import RawDescriptionHelpFormatter | |
import yaml | |
from pydantic import BaseModel as _Base | |
import os | |
class BaseConf(_Base): | |
class Config: | |
validate_all = True | |
allow_mutation = True | |
extra = "ignore" | |
def SingleOrList(inner_type): | |
return Union[inner_type, List[inner_type]] | |
def optional_load_config(fname="config.yml"): | |
cfg = {} | |
conf_fname = Path.cwd() / fname | |
if conf_fname.is_file(): | |
with conf_fname.open("r") as f: | |
raw = f.read() | |
print("loaded config\n ") | |
print(raw) # yaml raw itself is well formatted | |
cfg = yaml.safe_load(raw) | |
return cfg | |
# def optional_load_config(path="config.yml"): | |
# cfg = {} | |
# # conf_fname = Path.cwd() / fname | |
# conf_fname = Path(path) | |
# if conf_fname.is_file(): | |
# with conf_fname.open("r") as f: | |
# raw = f.read() | |
# print("loaded config\n ") | |
# print(raw) # yaml raw itself is well formatted | |
# cfg = yaml.safe_load(raw) | |
# return cfg | |
def write_full_config(cfg_obj, fname="full_config.yml"): | |
cfg = cfg_obj.dict() | |
cfg = _dict_to_yaml(cfg) | |
print(f"\n--- full config ---\n\n{cfg}\n") | |
with (Path.cwd() / fname).open("w") as f: | |
f.write(cfg) | |
def argparse_cfg_template(curr_cfgs): | |
parser = argparse.ArgumentParser( | |
description='Manual spec of configs', | |
epilog=f'curr cfgs:\n\n{_dict_to_yaml(curr_cfgs)}', | |
formatter_class=RawDescriptionHelpFormatter | |
) | |
_, args = parser.parse_known_args() | |
clauses = [] | |
for i in range(0, len(args), 2): | |
assert args[i][:2] == "--", "please start args with --" | |
clauses.append({args[i][2:]: args[i+1]}) | |
maker = ConfigMaker(curr_cfgs) | |
for clu in clauses: | |
maker.execute_clause(clu) | |
final = maker.state.copy() | |
return final | |
def _dict_to_yaml(arg): | |
return yaml.safe_dump(arg, sort_keys=False, allow_unicode=True) | |
def dispatch(module): | |
cfg = optional_load_config(fname="gradio_init.yml") | |
cfg = module(**cfg).dict() | |
cfg = argparse_cfg_template(cfg) # cmdline takes priority | |
mod = module(**cfg) | |
exp_path = os.path.join(cfg['exp_dir'],cfg['initial']) | |
os.makedirs(exp_path, exist_ok=True) | |
write_full_config(mod, os.path.join(exp_path,"full_config.yml")) | |
mod.run() | |
def dispatch_gradio(module, prompt, keyword, ti_step, pt_step, seed): | |
cfg = optional_load_config("gradio_init.yml") | |
cfg['sd']['prompt'] = prompt | |
cfg['sd']['dir'] = os.path.join(cfg['exp_dir'],keyword,'lora/final_lora.safetensors') | |
cfg['ti_step'] = ti_step | |
cfg['pt_step'] = pt_step | |
cfg['initial'] = keyword | |
cfg['random_seed'] = seed | |
cfg = module(**cfg).dict() | |
mod = module(**cfg) | |
return mod | |
# below are some support tools | |
class ConfigMaker(): | |
CMD = namedtuple('cmd', field_names=['sub', 'verb', 'objs']) | |
VERBS = ('add', 'replace', 'del') | |
def __init__(self, base_node): | |
self.state = base_node | |
self.clauses = [] | |
def clone(self): | |
return deepcopy(self) | |
def execute_clause(self, raw_clause): | |
cls = self.__class__ | |
assert isinstance(raw_clause, (str, dict)) | |
if isinstance(raw_clause, dict): | |
assert len(raw_clause) == 1, \ | |
"a clause can only have 1 statement: {} clauses in {}".format( | |
len(raw_clause), raw_clause | |
) | |
cmd = list(raw_clause.keys())[0] | |
arg = raw_clause[cmd] | |
else: | |
cmd = raw_clause | |
arg = None | |
cmd = self.parse_clause_cmd(cmd) | |
tracer = NodeTracer(self.state) | |
tracer.advance_pointer(path=cmd.sub) | |
if cmd.verb == cls.VERBS[0]: | |
tracer.add(cmd.objs, arg) | |
elif cmd.verb == cls.VERBS[1]: | |
tracer.replace(cmd.objs, arg) | |
elif cmd.verb == cls.VERBS[2]: | |
assert isinstance(raw_clause, str) | |
tracer.delete(cmd.objs) | |
self.state = tracer.state | |
def parse_clause_cmd(cls, input): | |
""" | |
Args: | |
input: a string to be parsed | |
1. First test whether a verb is present | |
2. If not present, then str is a single subject, and verb is replace | |
This is a syntactical sugar that makes writing config easy | |
3. If a verb is found, whatever comes before is a subject, and after the | |
objects. | |
4. Handle the edge cases properly. Below are expected parse outputs | |
input sub verb obj | |
--- No verb | |
'' '' replace [] | |
'a.b' 'a.b' replace [] | |
'add' '' add [] | |
'P Q' err: 2 subjects | |
--- Verb present | |
'T add' 'T' add [] | |
'T del a b' 'T' del [a, b] | |
'P Q add a' err: 2 subjects | |
'P add del b' err: 2 verbs | |
""" | |
assert isinstance(input, str) | |
input = input.split() | |
objs = [] | |
sub = '' | |
verb, verb_inx = cls.scan_for_verb(input) | |
if verb is None: | |
assert len(input) <= 1, "no verb present; more than 1 subject: {}"\ | |
.format(input) | |
sub = input[0] if len(input) == 1 else '' | |
verb = cls.VERBS[1] | |
else: | |
assert not verb_inx > 1, 'verb {} at inx {}; more than 1 subject in: {}'\ | |
.format(verb, verb_inx, input) | |
sub = input[0] if verb_inx == 1 else '' | |
objs = input[verb_inx + 1:] | |
cmd = cls.CMD(sub=sub, verb=verb, objs=objs) | |
return cmd | |
def scan_for_verb(cls, input_list): | |
assert isinstance(input_list, list) | |
counts = [ input_list.count(v) for v in cls.VERBS ] | |
presence = [ cnt > 0 for cnt in counts ] | |
if sum(presence) == 0: | |
return None, -1 | |
elif sum(presence) > 1: | |
raise ValueError("multiple verbs discovered in {}".format(input_list)) | |
if max(counts) > 1: | |
raise ValueError("verbs repeated in cmd: {}".format(input_list)) | |
# by now, there is 1 verb that has occured exactly 1 time | |
verb = cls.VERBS[presence.index(1)] | |
inx = input_list.index(verb) | |
return verb, inx | |
class NodeTracer(): | |
def __init__(self, src_node): | |
""" | |
A src node can be either a list or dict | |
""" | |
assert isinstance(src_node, (list, dict)) | |
# these are movable pointers | |
self.child_token = "_" # init token can be anything | |
self.parent = {self.child_token: src_node} | |
# these are permanent pointers at the root | |
self.root_child_token = self.child_token | |
self.root = self.parent | |
def state(self): | |
return self.root[self.root_child_token] | |
def pointed(self): | |
return self.parent[self.child_token] | |
def advance_pointer(self, path): | |
if len(path) == 0: | |
return | |
path_list = list( | |
map(lambda x: int(x) if str.isdigit(x) else x, path.split('.')) | |
) | |
for i, token in enumerate(path_list): | |
self.parent = self.pointed | |
self.child_token = token | |
try: | |
self.pointed | |
except (IndexError, KeyError): | |
raise ValueError( | |
"During the tracing of {}, {}-th token '{}'" | |
" is not present in node {}".format( | |
path, i, self.child_token, self.state | |
) | |
) | |
def replace(self, objs, arg): | |
assert len(objs) == 0 | |
val_type = type(self.parent[self.child_token]) | |
# this is such an unfortunate hack | |
# turn everything to string, so that eval could work | |
# some of the clauses come from cmdline, some from yaml files for sow. | |
arg = str(arg) | |
if val_type == str: | |
pass | |
else: | |
arg = eval(arg) | |
assert type(arg) == val_type, \ | |
f"require {val_type.__name__}, given {type(arg).__name__}" | |
self.parent[self.child_token] = arg | |