Skip to content

Commit

Permalink
Merge pull request #296 from ACEsuit/develop
Browse files Browse the repository at this point in the history
Change stress input + update version
  • Loading branch information
ilyes319 committed Jan 16, 2024
2 parents 44cc610 + 9395221 commit 6df8827
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mace/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.3"
__version__ = "0.3.4"
4 changes: 3 additions & 1 deletion mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def from_config(
else None
)
virials = (
torch.tensor(config.virials, dtype=torch.get_default_dtype()).unsqueeze(0)
voigt_to_matrix(
torch.tensor(config.virials, dtype=torch.get_default_dtype())
).unsqueeze(0)
if config.virials is not None
else None
)
Expand Down
4 changes: 2 additions & 2 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
Vector = np.ndarray # [3,]
Positions = np.ndarray # [..., 3]
Forces = np.ndarray # [..., 3]
Stress = np.ndarray # [6, ]
Virials = np.ndarray # [3,3]
Stress = np.ndarray # [6, ], [3,3], [9, ]
Virials = np.ndarray # [6, ], [3,3], [9, ]
Charges = np.ndarray # [..., 1]
Cell = np.ndarray # [3,3]
Pbc = tuple # (3,)
Expand Down
6 changes: 4 additions & 2 deletions mace/tools/torch_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def cartesian_to_spherical(t: torch.Tensor):
def voigt_to_matrix(t: torch.Tensor):
"""
Convert voigt notation to matrix notation
:param t: (6,) tensor or (3, 3) tensor
:param t: (6,) tensor or (3, 3) tensor or (9,) tensor
:return: (3, 3) tensor
"""
if t.shape == (3, 3):
Expand All @@ -121,9 +121,11 @@ def voigt_to_matrix(t: torch.Tensor):
],
dtype=t.dtype,
)
if t.shape == (9,):
return t.view(3, 3)

raise ValueError(
f"Stress tensor must be of shape (6,) or (3, 3), but has shape {t.shape}"
f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}"
)


Expand Down

0 comments on commit 6df8827

Please sign in to comment.