Quentin Gallouédec committed on
Commit
c8398fa
·
1 Parent(s): 319c2e7
Files changed (2) hide show
  1. requirements.txt +2 -0
  2. src/evaluation.py +6 -2
requirements.txt CHANGED
@@ -14,9 +14,11 @@ pandas==2.0.0
14
  python-dateutil==2.8.2
15
  requests==2.28.2
16
  rliable==1.0.8
 
17
  torch==2.2.2
18
  tqdm==4.65.0
19
 
 
20
  # Log Visualizer
21
  BeautifulSoup4==4.12.2
22
  lxml==4.9.3
 
14
  python-dateutil==2.8.2
15
  requests==2.28.2
16
  rliable==1.0.8
17
+ --extra-index-url https://download.pytorch.org/whl/cu113
18
  torch==2.2.2
19
  tqdm==4.65.0
20
 
21
+
22
  # Log Visualizer
23
  BeautifulSoup4==4.12.2
24
  lxml==4.9.3
src/evaluation.py CHANGED
@@ -15,6 +15,10 @@ logger = setup_logger(__name__)
15
 
16
  API = HfApi(token=os.environ.get("TOKEN"))
17
 
 
 
 
 
18
  ALL_ENV_IDS = [
19
  "AdventureNoFrameskip-v4",
20
  "AirRaidNoFrameskip-v4",
@@ -338,7 +342,7 @@ def evaluate(model_id, revision):
338
 
339
  # Load the agent
340
  try:
341
- agent = torch.jit.load(agent_path)
342
  except Exception as e:
343
  logger.error(f"Error loading agent: {e}")
344
  return None
@@ -349,7 +353,7 @@ def evaluate(model_id, revision):
349
  observations, _ = envs.reset()
350
  episodic_returns = []
351
  while len(episodic_returns) < 10:
352
- actions = agent(torch.tensor(observations)).numpy()
353
  observations, _, _, _, infos = envs.step(actions)
354
  if "final_info" in infos:
355
  for info in infos["final_info"]:
 
15
 
16
  API = HfApi(token=os.environ.get("TOKEN"))
17
 
18
+ logger.info(f"Is CUDA available: {torch.cuda.is_available()}")
19
+ logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
20
+
21
+
22
  ALL_ENV_IDS = [
23
  "AdventureNoFrameskip-v4",
24
  "AirRaidNoFrameskip-v4",
 
342
 
343
  # Load the agent
344
  try:
345
+ agent = torch.jit.load(agent_path).to("cuda")
346
  except Exception as e:
347
  logger.error(f"Error loading agent: {e}")
348
  return None
 
353
  observations, _ = envs.reset()
354
  episodic_returns = []
355
  while len(episodic_returns) < 10:
356
+ actions = agent(torch.tensor(observations, device="cuda")).cpu().numpy()
357
  observations, _, _, _, infos = envs.step(actions)
358
  if "final_info" in infos:
359
  for info in infos["final_info"]: