File size: 804 Bytes
a986605
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import json
import os
from pathlib import Path

import farconf
from cleanba.config import Args
from cleanba.environments import SokobanConfig

# Single, synchronous Sokoban environment (10x10 room, one box, at most 100
# steps per episode). It is handed to ``count_params`` inside
# ``parameter_count`` below; presumably only its observation/action spaces
# matter for counting parameters — TODO confirm.
soko_env = SokobanConfig(
    max_episode_steps=100,
    num_envs=1,
    dim_room=(10, 10),
    num_boxes=1,
    asynchronous=False,
    tinyworld_obs=True,
).make()


def _first_subdir(directory: Path) -> Path:
    """Return the alphabetically first visible subdirectory of ``directory``.

    Directory listings come back in an arbitrary, filesystem-dependent order,
    so entries are sorted to make the choice reproducible. Plain files and
    dot-entries (e.g. ``.DS_Store``) are skipped because only a directory can
    hold the expected nested layout.

    Raises:
        FileNotFoundError: if ``directory`` contains no such subdirectory.
    """
    subdirs = sorted(
        entry for entry in directory.iterdir() if entry.is_dir() and not entry.name.startswith(".")
    )
    if not subdirs:
        raise FileNotFoundError(f"No subdirectory found in {directory}")
    return subdirs[0]


def parameter_count(root: Path) -> str:
    """Return the formatted parameter count of the network saved under ``root``.

    Expects the layout ``root/<model_dir>/<checkpoint_dir>/cfg.json``. The
    alphabetically first model directory and, inside it, the alphabetically
    first checkpoint directory are used. The JSON's ``"cfg"`` entry is parsed
    into an ``Args`` via ``farconf``, and ``args.net.count_params`` is evaluated
    against the module-level ``soko_env``.

    Args:
        root: Directory containing the model subdirectories (a ``str`` path is
            also accepted).

    Returns:
        The count with thousands separators followed by the value in
        millions, e.g. ``"1,234,567 (1.23M)"``.

    Raises:
        FileNotFoundError: if a required subdirectory or ``cfg.json`` is missing.
    """
    root = Path(root)
    # root/<model_dir>/<checkpoint_dir>: descend two levels deterministically.
    checkpoint_dir = _first_subdir(_first_subdir(root))

    with open(checkpoint_dir / "cfg.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)

    args = farconf.from_dict(cfg["cfg"], Args)
    num = args.net.count_params(soko_env)
    return f"{num:,} ({num / 1_000_000:.2f}M)"


# Print one line per model family: "<label> <count> (<millions>M)".
for _label, _run_dir in (
    ("- DRC(3, 3): ", "drc33"),
    ("- DRC(1, 1): ", "drc11"),
    ("- ResNet: ", "resnet"),
):
    print(_label, parameter_count(Path(_run_dir)))