I just added some code to run BERTScore on Apple Silicon, and it works fine. I think MPS support should be added to the code base for Mac users' convenience.
I tested the MPS backend with: score(candidates, references, lang="ja", device="mps")
Basically, I added MPS detection to every place where CUDA is detected.
In score.py:
if device is None:
    # around line 101: prefer CUDA, then MPS (Apple Silicon), then CPU
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)

# around line 242: same CUDA -> MPS -> CPU fallback
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
In scorer.py:
if device is None:
    # around line 74: prefer CUDA, then MPS (Apple Silicon), then CPU
    self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
else:
    self.device = device
In utils.py:
# around lines 409, 446 and 605: prefer CUDA, then MPS (Apple Silicon), then CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
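Rather than repeating the detection expression in score.py, scorer.py and utils.py, the checks could be centralized in one helper. Below is a minimal sketch, not bert-score's actual API: `pick_device` and `default_device` are hypothetical names, and the priority order (CUDA, then MPS, then CPU) is an assumption. The only PyTorch calls used are `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and torch is imported lazily so the selection logic can be exercised without it.

```python
from typing import Optional


def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name: CUDA first, then MPS, then CPU.

    Hypothetical helper (not part of bert-score); the priority order is an assumption.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def default_device(device: Optional[str] = None) -> str:
    """Return `device` unchanged if the caller passed one, otherwise auto-detect.

    Hypothetical helper; torch is imported only when detection is needed.
    """
    if device is not None:
        # Respect an explicit choice such as device="mps" or device="cpu".
        return device
    import torch  # requires PyTorch; torch.backends.mps is available in torch >= 1.12

    return pick_device(
        torch.cuda.is_available(),
        torch.backends.mps.is_available(),
    )
```

With such a helper, each call site would reduce to `device = default_device(device)` (or `self.device = default_device(device)` in scorer.py) before `model.to(device)`, so a CUDA machine without MPS keeps selecting CUDA.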
My environment:
Apple M1
macOS 14.5
Python 3.12
torch 2.3.0
bert-score 0.3.13