Skip to content

Commit

Permalink
Merge pull request #20 from columbia-ai-robotics/cchi/eval_script
Browse files Browse the repository at this point in the history
added eval script and documentation
  • Loading branch information
cheng-chi authored Sep 10, 2023
2 parents 68eef44 + a98e748 commit 5c3d54f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image

7 directories, 16 files
```
### 🆕 Evaluate Pre-trained Checkpoints
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).

Run the evaluation script:
```console
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
```

This will generate the following directory structure:
```console
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
data/pusht_eval_output
├── eval_log.json
└── media
├── 1fxtno84.mp4
├── 224l7jqd.mp4
├── 2fo4btlf.mp4
├── 2in4cn7a.mp4
├── 34b3o2qq.mp4
└── 3p7jqn32.mp4

1 directory, 7 files
```

`eval_log.json` contains metrics that is logged to wandb during training:
```console
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
{
"test/mean_score": 0.9150393806777066,
"test/sim_max_reward_4300000": 1.0,
"test/sim_max_reward_4300001": 0.9872969750774386,
...
"train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
}
```

## 🦾 Demo, Training and Eval on a Real Robot
Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
Expand Down
64 changes: 64 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
"""

import sys
# use line-buffering for both stdout and stderr
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

import os
import pathlib
import click
import hydra
import torch
import dill
import wandb
import json
from diffusion_policy.workspace.base_workspace import BaseWorkspace

@click.command()
@click.option('-c', '--checkpoint', required=True)
@click.option('-o', '--output_dir', required=True)
@click.option('-d', '--device', default='cuda:0')
def main(checkpoint, output_dir, device):
if os.path.exists(output_dir):
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

# load checkpoint
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
cfg = payload['cfg']
cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg, output_dir=output_dir)
workspace: BaseWorkspace
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# get policy from workspace
policy = workspace.model
if cfg.training.use_ema:
policy = workspace.ema_model

device = torch.device(device)
policy.to(device)
policy.eval()

# run eval
env_runner = hydra.utils.instantiate(
cfg.task.env_runner,
output_dir=output_dir)
runner_log = env_runner.run(policy)

# dump log to json
json_log = dict()
for key, value in runner_log.items():
if isinstance(value, wandb.sdk.data_types.video.Video):
json_log[key] = value._path
else:
json_log[key] = value
out_path = os.path.join(output_dir, 'eval_log.json')
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)

if __name__ == '__main__':
main()

0 comments on commit 5c3d54f

Please sign in to comment.