diff --git a/text2vec/sentence_model.py b/text2vec/sentence_model.py index 3273f20..605aee5 100644 --- a/text2vec/sentence_model.py +++ b/text2vec/sentence_model.py @@ -178,6 +178,8 @@ def encode( self.bert.eval() if device is None: device = self.device + self.bert.to(device) + if max_seq_length is None: max_seq_length = self.max_seq_length if convert_to_tensor: