diff --git a/colbert/parameters.py b/colbert/parameters.py index 60ae0a74..2d21f36f 100644 --- a/colbert/parameters.py +++ b/colbert/parameters.py @@ -1,6 +1,6 @@ import torch -DEVICE = torch.device("cuda") +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 300*1000, 400*1000] SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000]