-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: Add utilities for working with CuH2
Currently for xtensor only
- Loading branch information
Showing
5 changed files
with
199 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// MIT License | ||
// Copyright 2023--present Rohit Goswami <HaoZeke> | ||
// clang-format off | ||
#include <algorithm> | ||
#include <limits> | ||
// clang-format on | ||
#include "rgpot/CuH2/cuh2Utils.hpp" | ||
#include "include/ReadCon.hpp" | ||
using rgpot::types::AtomMatrix; | ||
|
||
#ifdef WITH_XTENSOR | ||
namespace rgpot::cuh2::utils::xts { | ||
|
||
xt::xtensor<double, 2> | ||
extract_positions(const yodecon::types::ConFrameVec &frame) { | ||
size_t n_atoms = frame.x.size(); | ||
std::array<size_t, 2> shape = {static_cast<size_t>(n_atoms), 3}; | ||
|
||
xt::xtensor<double, 2> positions = xt::empty<double>(shape); | ||
for (size_t i = 0; i < n_atoms; ++i) { | ||
positions(i, 0) = frame.x[i]; | ||
positions(i, 1) = frame.y[i]; | ||
positions(i, 2) = frame.z[i]; | ||
} | ||
|
||
return positions; | ||
} | ||
|
||
xt::xtensor<double, 2> | ||
peturb_positions(const xt::xtensor<double, 2> &base_positions, | ||
const xt::xtensor<int, 1> &atmNumVec, double hcu_dist, | ||
double hh_dist) { | ||
xt::xtensor<double, 2> positions = base_positions; | ||
std::vector<size_t> hIndices, cuIndices; | ||
|
||
for (size_t i = 0; i < atmNumVec.size(); ++i) { | ||
if (atmNumVec(i) == 1) { // Hydrogen atom | ||
hIndices.push_back(i); | ||
} else if (atmNumVec(i) == 29) { // Copper atom | ||
cuIndices.push_back(i); | ||
} else { | ||
throw std::runtime_error("Unexpected atomic number"); | ||
} | ||
} | ||
|
||
if (hIndices.size() != 2) { | ||
throw std::runtime_error("Expected exactly two hydrogen atoms"); | ||
} | ||
|
||
// Compute the midpoint of the hydrogens | ||
auto hMidpoint = | ||
(xt::row(positions, hIndices[0]) + xt::row(positions, hIndices[1])) / 2; | ||
|
||
// TODO(rg): This is buggy in cuh2vizR!! (maybe) | ||
// Compute the HH direction | ||
xt::xtensor<double, 1> hh_direction; | ||
size_t h1_idx, h2_idx; | ||
if (positions(hIndices[0], 0) < positions(hIndices[1], 0)) { | ||
hh_direction = | ||
xt::row(positions, hIndices[1]) - xt::row(positions, hIndices[0]); | ||
ensure_normalized(hh_direction); | ||
h1_idx = hIndices[0]; | ||
h2_idx = hIndices[1]; | ||
} else { | ||
hh_direction = | ||
xt::row(positions, hIndices[0]) - xt::row(positions, hIndices[1]); | ||
ensure_normalized(hh_direction); | ||
h1_idx = hIndices[1]; | ||
h2_idx = hIndices[0]; | ||
} | ||
|
||
// Set the new position of the hydrogens using the recorded indices | ||
xt::row(positions, h1_idx) = hMidpoint - (0.5 * hh_dist) * hh_direction; | ||
xt::row(positions, h2_idx) = hMidpoint + (0.5 * hh_dist) * hh_direction; | ||
|
||
// Find the z-coordinate of the topmost Cu layer | ||
double maxCuZ = std::numeric_limits<double>::lowest(); | ||
for (auto cuIndex : cuIndices) { | ||
maxCuZ = std::max(maxCuZ, positions(cuIndex, 2)); | ||
} | ||
|
||
// Compute the new z-coordinate for the H atoms | ||
double new_z = maxCuZ + hcu_dist; | ||
|
||
// Update the z-coordinates of the H atoms | ||
for (auto hIndex : hIndices) { | ||
positions(hIndex, 2) = new_z; | ||
} | ||
|
||
return positions; | ||
} | ||
|
||
std::pair<double, double> | ||
calculateDistances(const xt::xtensor<double, 2> &positions, | ||
const xt::xtensor<int, 1> &atmNumVec) { | ||
std::vector<size_t> hIndices, cuIndices; | ||
for (size_t i = 0; i < atmNumVec.size(); ++i) { | ||
if (atmNumVec(i) == 1) { // Hydrogen atom | ||
hIndices.push_back(i); | ||
} else if (atmNumVec(i) == 29) { // Copper atom | ||
cuIndices.push_back(i); | ||
} else { | ||
throw std::runtime_error("Unexpected atomic number"); | ||
} | ||
} | ||
|
||
if (hIndices.size() != 2) { | ||
throw std::runtime_error("Expected exactly two hydrogen atoms"); | ||
} | ||
|
||
// Calculate the distance between Hydrogen atoms | ||
double hDistance = | ||
xt::linalg::norm(xt::view(positions, hIndices[0], xt::all()) - | ||
xt::view(positions, hIndices[1], xt::all())); | ||
|
||
// Calculate the midpoint of Hydrogen atoms | ||
xt::xtensor<double, 1> hMidpoint = | ||
(xt::view(positions, hIndices[0], xt::all()) + | ||
xt::view(positions, hIndices[1], xt::all())) / | ||
2.0; | ||
|
||
// Find the z-coordinate of the topmost Cu layer | ||
double maxCuZ = std::numeric_limits<double>::lowest(); | ||
for (size_t cuIndex : cuIndices) { | ||
maxCuZ = std::max(maxCuZ, positions(cuIndex, 2)); | ||
} | ||
|
||
double cuSlabDist = positions(hIndices[0], 2) - maxCuZ; | ||
|
||
return std::make_pair(hDistance, cuSlabDist); | ||
} | ||
|
||
} // namespace rgpot::cuh2::utils::xts | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#pragma once | ||
// MIT License | ||
// Copyright 2024--present Rohit Goswami <HaoZeke> | ||
// clang-format off | ||
#include <utility> | ||
#include <vector> | ||
// clang-format on | ||
#include "include/ReadCon.hpp" | ||
#include "rgpot/types/AtomMatrix.hpp" | ||
#ifdef WITH_XTENSOR | ||
#include "xtensor-blas/xlinalg.hpp" | ||
#include "xtensor/xarray.hpp" | ||
#include "xtensor/xview.hpp" | ||
#endif | ||
using rgpot::types::AtomMatrix; | ||
|
||
namespace rgpot { | ||
namespace cuh2 { | ||
namespace utils { | ||
#ifdef WITH_XTENSOR | ||
namespace xts { | ||
xt::xtensor<double, 2> | ||
extract_positions(const yodecon::types::ConFrameVec &frame); | ||
xt::xtensor<double, 2> | ||
peturb_positions(const xt::xtensor<double, 2> &base_positions, | ||
const xt::xtensor<int, 1> &atmNumVec, double hcu_dist, | ||
double hh_dist); | ||
std::pair<double, double> | ||
calculateDistances(const xt::xtensor<double, 2> &positions, | ||
const xt::xtensor<int, 1> &atmNumVec); | ||
|
||
// TODO(rg): This is duplicated from xts::func !! | ||
template <class E, class ScalarType = double> | ||
void ensure_normalized(E &&vector, bool is_normalized = false, | ||
ScalarType tol = static_cast<ScalarType>(1e-6)) { | ||
if (!is_normalized) { | ||
auto norm = xt::linalg::norm(vector, 2); | ||
if (std::abs(norm - static_cast<ScalarType>(1.0)) >= tol) { | ||
vector /= norm; | ||
} else { | ||
throw std::runtime_error( | ||
"Cannot normalize a vector whose norm is smaller than tol"); | ||
} | ||
} | ||
} | ||
|
||
} // namespace xts | ||
#endif | ||
} // namespace utils | ||
|
||
} // namespace cuh2 | ||
|
||
} // namespace rgpot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[wrap-git] | ||
directory=readcon | ||
url=https://github.com/HaoZeke/readCon.git | ||
revision=73bdb0cba065a63c9820a3238e799bdd1379d7c9 |