When `trn.AddOffsets(MD17.energy, add_mean=True, add_atomrefs=False)` is used in `spk.model.NeuralNetworkPotential`, the batched prediction contains only a single value in `energy`. Without this term, the prediction works correctly.
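For reference, with `add_mean=True` on an extensive property such as energy, the expected result is still one energy per molecule: each molecule's prediction is shifted by a per-atom mean scaled by that molecule's atom count. The plain-Python sketch below illustrates that expectation under this assumption; `add_mean_offset` is a hypothetical helper, not SchNetPack's implementation.

```python
from typing import List


def add_mean_offset(
    energies: List[float], n_atoms: List[int], mean_per_atom: float
) -> List[float]:
    """Shift each molecule's predicted energy by an extensive mean offset.

    Hypothetical helper: assumes one entry per molecule in both `energies`
    and `n_atoms`, and that the offset scales with the number of atoms.
    """
    if len(energies) != len(n_atoms):
        raise ValueError("need exactly one atom count per predicted energy")
    return [e + mean_per_atom * n for e, n in zip(energies, n_atoms)]


# two ethanol molecules (C2H6O, 9 atoms each) in one batch
print(add_mean_offset([0.5, -0.25], [9, 9], mean_per_atom=-10.0))  # [-89.5, -90.25]
```

A batch of 256 molecules should therefore come back with 256 energies, not one.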
```python
# example script from the MD17 tutorial
import os

import torch
from ase import Atoms

import schnetpack as spk
import schnetpack.transform as trn

# `forcetut` and `ethanol_data` are defined earlier in the MD17 tutorial

# load model
model_path = os.path.join(forcetut, "best_inference_model")
best_model = torch.load(model_path).to('cpu')

# set up converter
converter = spk.interfaces.AtomsConverter(
    neighbor_list=trn.ASENeighborList(cutoff=5.0), dtype=torch.float32
)

res = {}
for i in range(256):
    # create atoms object from dataset
    structure = ethanol_data.test_dataset[i]
    atoms = Atoms(
        numbers=structure[spk.properties.Z], positions=structure[spk.properties.R]
    )
    # convert atoms to SchNetPack inputs and append them key by key
    inputs = converter(atoms)
    for key in inputs:
        if key in res:
            res[key] = torch.cat([res[key], inputs[key]], 0)
        else:
            res[key] = inputs[key]

res['_pbc'] = res['_pbc'].reshape(-1)  # need to reshape too?

# the converter output has no 'energy' entry, so add a placeholder
res["energy"] = torch.rand(256)
print(res['energy'].shape)

# perform prediction on the manually assembled batch
results = best_model(res)
print(results)
```
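One thing worth checking in the loop of the script: it concatenates every converter output as-is. If, as assumed here, each single-structure output carries its own local index tensors (`_idx_m` mapping atoms to molecules, `_idx_i`/`_idx_j` holding neighbor-pair atom indices, all starting at 0), then plain concatenation assigns every atom to molecule 0 and points all neighbor pairs into the first molecule, which could collapse the batch into a single energy. The pure-Python sketch below shows offset-aware collation under those assumed key names; `collate_batches` is a hypothetical helper, not part of SchNetPack.

```python
from typing import Dict, List

Sample = Dict[str, List[int]]


def collate_batches(samples: List[Sample]) -> Sample:
    """Concatenate per-structure inputs into one batch, shifting index keys.

    Assumed keys: `_n_atoms` (one atom count per structure in the sample),
    `_idx_m` (molecule index per atom) and `_idx_i`/`_idx_j` (neighbor-pair
    atom indices, local to the sample). Other keys are concatenated unchanged.
    """
    batch: Sample = {}
    atom_offset = 0  # atoms already placed in the batch
    mol_offset = 0   # molecules already placed in the batch
    for sample in samples:
        for key, values in sample.items():
            if key == "_idx_m":
                # give this sample's atoms their own molecule indices
                shifted = [v + mol_offset for v in values]
            elif key in ("_idx_i", "_idx_j"):
                # make neighbor indices point at this sample's atoms in the batch
                shifted = [v + atom_offset for v in values]
            else:
                shifted = list(values)
            batch.setdefault(key, []).extend(shifted)
        atom_offset += sum(sample["_n_atoms"])
        mol_offset += len(sample["_n_atoms"])
    return batch


# two 3-atom molecules, each converted on its own (all indices start at 0)
a = {"_n_atoms": [3], "_idx_m": [0, 0, 0], "_idx_i": [0, 1], "_idx_j": [1, 2]}
b = {"_n_atoms": [3], "_idx_m": [0, 0, 0], "_idx_i": [0, 2], "_idx_j": [2, 0]}
batch = collate_batches([a, b])
print(batch["_idx_m"])  # [0, 0, 0, 1, 1, 1]
print(batch["_idx_i"])  # [0, 1, 3, 5]
```

The same shifting would apply to the tensors in `res`; whether this accounts for the `AddOffsets` behavior is not verified here.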