Spaces:
Runtime error
Runtime error
Quentin Gallouédec
commited on
Commit
·
c8398fa
1
Parent(s):
319c2e7
try cuda
Browse files- requirements.txt +2 -0
- src/evaluation.py +6 -2
requirements.txt
CHANGED
@@ -14,9 +14,11 @@ pandas==2.0.0
|
|
14 |
python-dateutil==2.8.2
|
15 |
requests==2.28.2
|
16 |
rliable==1.0.8
|
|
|
17 |
torch==2.2.2
|
18 |
tqdm==4.65.0
|
19 |
|
|
|
20 |
# Log Visualizer
|
21 |
BeautifulSoup4==4.12.2
|
22 |
lxml==4.9.3
|
|
|
14 |
python-dateutil==2.8.2
|
15 |
requests==2.28.2
|
16 |
rliable==1.0.8
|
17 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
18 |
torch==2.2.2
|
19 |
tqdm==4.65.0
|
20 |
|
21 |
+
|
22 |
# Log Visualizer
|
23 |
BeautifulSoup4==4.12.2
|
24 |
lxml==4.9.3
|
src/evaluation.py
CHANGED
@@ -15,6 +15,10 @@ logger = setup_logger(__name__)
|
|
15 |
|
16 |
API = HfApi(token=os.environ.get("TOKEN"))
|
17 |
|
|
|
|
|
|
|
|
|
18 |
ALL_ENV_IDS = [
|
19 |
"AdventureNoFrameskip-v4",
|
20 |
"AirRaidNoFrameskip-v4",
|
@@ -338,7 +342,7 @@ def evaluate(model_id, revision):
|
|
338 |
|
339 |
# Load the agent
|
340 |
try:
|
341 |
-
agent = torch.jit.load(agent_path)
|
342 |
except Exception as e:
|
343 |
logger.error(f"Error loading agent: {e}")
|
344 |
return None
|
@@ -349,7 +353,7 @@ def evaluate(model_id, revision):
|
|
349 |
observations, _ = envs.reset()
|
350 |
episodic_returns = []
|
351 |
while len(episodic_returns) < 10:
|
352 |
-
actions = agent(torch.tensor(observations)).numpy()
|
353 |
observations, _, _, _, infos = envs.step(actions)
|
354 |
if "final_info" in infos:
|
355 |
for info in infos["final_info"]:
|
|
|
15 |
|
16 |
API = HfApi(token=os.environ.get("TOKEN"))
|
17 |
|
18 |
+
logger.info(f"Is CUDA available: {torch.cuda.is_available()}")
|
19 |
+
logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
20 |
+
|
21 |
+
|
22 |
ALL_ENV_IDS = [
|
23 |
"AdventureNoFrameskip-v4",
|
24 |
"AirRaidNoFrameskip-v4",
|
|
|
342 |
|
343 |
# Load the agent
|
344 |
try:
|
345 |
+
agent = torch.jit.load(agent_path).to("cuda")
|
346 |
except Exception as e:
|
347 |
logger.error(f"Error loading agent: {e}")
|
348 |
return None
|
|
|
353 |
observations, _ = envs.reset()
|
354 |
episodic_returns = []
|
355 |
while len(episodic_returns) < 10:
|
356 |
+
actions = agent(torch.tensor(observations, device="cuda")).cpu().numpy()
|
357 |
observations, _, _, _, infos = envs.step(actions)
|
358 |
if "final_info" in infos:
|
359 |
for info in infos["final_info"]:
|