Skip to content

Commit

Permalink
Forward model fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Oct 4, 2024
1 parent 775c32e commit d13fd33
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion atomgpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version number."""

__version__ = "2024.9.18"
__version__ = "2024.9.30"
20 changes: 13 additions & 7 deletions atomgpt/forward_models/forward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import pprint
import sys
import argparse
from alignn.pretrained import get_figshare_model

parser = argparse.ArgumentParser(
description="Atomistic Generative Pre-trained Transformer."
Expand Down Expand Up @@ -287,10 +286,10 @@ def run_atomgpt(config_file="config.json"):
pprint.pprint(config)
id_prop_path = config.id_prop_path
convert = config.convert
if convert:
model = get_figshare_model(
model_name="jv_formation_energy_peratom_alignn"
)
# if convert:
# model = get_figshare_model(
# model_name="jv_formation_energy_peratom_alignn"
# )
if ".zip" in id_prop_path:
zp = zipfile.ZipFile(id_prop_path)
dat = json.loads(zp.read(id_prop_path.split(".zip")[0]))
Expand All @@ -310,7 +309,8 @@ def run_atomgpt(config_file="config.json"):
)
if convert:
atoms = Atoms.from_poscar(pth)
lines = atoms.describe(model=model)[config.desc_type]
lines = atoms.describe()[config.desc_type]
# lines = atoms.describe(model=model)[config.desc_type]
else:

with open(pth, "r") as f:
Expand Down Expand Up @@ -529,7 +529,9 @@ def run_atomgpt(config_file="config.json"):
train_loss = 0
# train_result = []
input_ids = batch[0]["input_ids"].squeeze() # .squeeze(0)
# print('input_ids',input_ids.shape)
if "t5" in model_name:
input_ids = batch[0]["input_ids"].squeeze(1) # .squeeze(0)
predictions = (
model(
input_ids.to(device),
Expand Down Expand Up @@ -571,7 +573,8 @@ def run_atomgpt(config_file="config.json"):
f.write("id,target,predictions\n")
with torch.no_grad():
for batch in val_dataloader:
input_ids = batch[0]["input_ids"].squeeze() # .squeeze(0)
# input_ids = batch[0]["input_ids"].squeeze() # .squeeze(0)
input_ids = batch[0]["input_ids"].squeeze(1) # .squeeze(0)
ids = batch[1]
if "t5" in model_name:
predictions = (
Expand Down Expand Up @@ -645,6 +648,9 @@ def run_atomgpt(config_file="config.json"):
for batch in test_dataloader:
input_ids = batch[0]["input_ids"].squeeze() # .squeeze(0)
if "t5" in model_name:
input_ids = batch[0]["input_ids"].squeeze(
1
) # .squeeze(0)
predictions = (
model(
input_ids.to(device),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="atomgpt",
version="2024.9.18",
version="2024.9.30",
author="Kamal Choudhary",
author_email="[email protected]",
description="atomgpt",
Expand Down

0 comments on commit d13fd33

Please sign in to comment.