learned-planner / count_params.py
agaralon's picture
Parameter counts and some explanation
a986605 unverified
raw
history blame
804 Bytes
import json
import os
from pathlib import Path
import farconf
from cleanba.config import Args
from cleanba.environments import SokobanConfig
# Module-level Sokoban environment, built once at import time.
# `parameter_count` hands it to the network's `count_params` method.
soko_env = SokobanConfig(
    max_episode_steps=100,
    num_envs=1,
    dim_room=(10, 10),
    num_boxes=1,
    asynchronous=False,
    tinyworld_obs=True,
).make()
def _first_subdir(parent: Path) -> Path:
    """Return the lexicographically first sub-directory of *parent*.

    Directory listings come back in arbitrary order and may contain plain
    files, so the entries are filtered to directories and sorted to make the
    choice deterministic.

    Raises:
        FileNotFoundError: if *parent* contains no sub-directory.
    """
    subdirs = sorted(entry for entry in parent.iterdir() if entry.is_dir())
    if not subdirs:
        raise FileNotFoundError(f"no sub-directory found in {parent}")
    return subdirs[0]


def parameter_count(root: Path) -> str:
    """Return a human-readable parameter count for the model stored under *root*.

    The configuration is read from ``root/<model_dir>/<cp_dir>/cfg.json``,
    where each of the two intermediate levels is the first sub-directory in
    sorted order. The ``"cfg"`` entry of that JSON document is converted to
    an ``Args`` object with ``farconf.from_dict``, and ``args.net.count_params``
    is called with the module-level ``soko_env``.

    Args:
        root: Directory containing the model directory.

    Returns:
        The count with thousands separators followed by the value in
        millions with two decimals, e.g. ``"1,234,567 (1.23M)"``.

    Raises:
        FileNotFoundError: if a directory level or ``cfg.json`` is missing.
        KeyError: if the JSON document has no ``"cfg"`` entry.
    """
    # Two nested levels: root/<model_dir>/<cp_dir>/cfg.json
    checkpoint_dir = _first_subdir(_first_subdir(root))
    with open(checkpoint_dir / "cfg.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    args = farconf.from_dict(cfg["cfg"], Args)
    num = args.net.count_params(soko_env)
    return f"{num:,} ({num/1_000_000:.2f}M)"
# Print one bullet line per model: the label, then the formatted parameter
# count read from the directory of the same row.
for label, model_root in (
    ("- DRC(3, 3): ", "drc33"),
    ("- DRC(1, 1): ", "drc11"),
    ("- ResNet: ", "resnet"),
):
    print(label, parameter_count(Path(model_root)))