Skip to content

Commit

Permalink
Update Particle Container to Pure SoA
Browse files Browse the repository at this point in the history
Transition particle containers to pure SoA layouts.
  • Loading branch information
ax3l committed Jan 31, 2024
1 parent 4e2a6e5 commit 0e8b676
Show file tree
Hide file tree
Showing 43 changed files with 775 additions and 697 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/tooling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ jobs:
ccache-openmp-clangsan-
- name: build ImpactX
env: {CC: mpicc, CXX: mpic++, OMPI_CC: clang, OMPI_CXX: clang++, CXXFLAGS: -Werror}
env:
CC: mpicc
CXX: mpic++
OMPI_CC: clang
OMPI_CXX: clang++
CXXFLAGS: "-Werror -Wno-error=pass-failed"
run: |
export LDFLAGS="${LDFLAGS} -fsanitize=address,undefined -shared-libsan"
export CXXFLAGS="${CXXFLAGS} -fsanitize=address,undefined -shared-libsan"
Expand Down
6 changes: 3 additions & 3 deletions cmake/dependencies/ABLASTR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,18 @@ set(ImpactX_openpmd_src ""
"Local path to openPMD-api source directory (preferred if set)")

# Git fetcher
set(ImpactX_ablastr_repo "https://github.com/ECP-WarpX/WarpX.git"
set(ImpactX_ablastr_repo "https://github.com/ax3l/WarpX.git"
CACHE STRING
"Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)")
set(ImpactX_ablastr_branch "24.01"
set(ImpactX_ablastr_branch "topic-soa-reintro"
CACHE STRING
"Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)")

# AMReX is transitively pulled through ABLASTR
set(ImpactX_amrex_repo "https://github.com/AMReX-Codes/amrex.git"
CACHE STRING
"Repository URI to pull and build AMReX from if(ImpactX_amrex_internal)")
set(ImpactX_amrex_branch ""
set(ImpactX_amrex_branch "689144d157a0106faf3d0ae89f8d90b0250cf975"
CACHE STRING
"Repository branch for ImpactX_amrex_repo if(ImpactX_amrex_internal)")

Expand Down
2 changes: 1 addition & 1 deletion examples/epac2004_benchmarks/input_fodo_rf_SC.in
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ geometry.prob_relative = 4.0
###############################################################################
# Diagnostics
###############################################################################
diag.slice_step_diagnostics = true
diag.slice_step_diagnostics = false
20 changes: 10 additions & 10 deletions examples/fodo/run_fodo_programmable.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ def my_drift(pge, pti, refpart):

else:
array = np.array
# access AoS data such as positions and cpu/id
aos = pti.aos()
aos_arr = array(aos, copy=False)

# access SoA data such as momentum
# access particle attributes
soa = pti.soa()
real_arrays = soa.GetRealData()
px = array(real_arrays[0], copy=False)
py = array(real_arrays[1], copy=False)
pt = array(real_arrays[2], copy=False)
x = array(real_arrays[0], copy=False)
y = array(real_arrays[1], copy=False)
t = array(real_arrays[2], copy=False)
px = array(real_arrays[3], copy=False)
py = array(real_arrays[4], copy=False)
pt = array(real_arrays[5], copy=False)

# length of the current slice
slice_ds = pge.ds / pge.nslice
Expand All @@ -97,9 +97,9 @@ def my_drift(pge, pti, refpart):
betgam2 = pt_ref**2 - 1.0

# advance position and momentum (drift)
aos_arr[:]["x"] += slice_ds * px[:]
aos_arr[:]["y"] += slice_ds * py[:]
aos_arr[:]["z"] += (slice_ds / betgam2) * pt[:]
x[:] += slice_ds * px[:]
y[:] += slice_ds * py[:]
t[:] += (slice_ds / betgam2) * pt[:]


def my_ref_drift(pge, refpart):
Expand Down
70 changes: 49 additions & 21 deletions examples/pytorch_surrogate_model/run_ml_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from urllib import request

import numpy as np

try:
import cupy as cp

cupy_available = True
except ImportError:
cupy_available = False

from surrogate_model_definitions import surrogate_model

try:
Expand All @@ -20,10 +28,11 @@
sys.exit(0)

from impactx import (
Config,
CoordSystem,
ImpactX,
ImpactXParIter,
RefPart,
TransformationDirection,
coordinate_transformation,
distribution,
elements,
Expand Down Expand Up @@ -79,7 +88,22 @@ def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
self.ds = surrogate_length

def surrogate_push(self, pc, step):
array = np.array
# CPU/GPU logic
if Config.have_gpu:
if cupy_available:
array = cp.array
stack = cp.stack
device = torch.device("cuda")
else:
print("Warning: GPU found but cupy not available! Try managed...")
array = np.array
stack = np.stack
device = torch.device("cpu")
if Config.gpu_backend == "SYCL":
print("Warning: SYCL GPU backend not yet implemented for Python")

else:
array = np.array

ref_part = pc.ref_particle()
ref_z_i = ref_part.z
Expand All @@ -100,26 +124,30 @@ def surrogate_push(self, pc, step):
ref_part_final = torch.tensor([0, 0, ref_z_f, 0, 0, ref_uz_f])

# transform
coordinate_transformation(pc, TransformationDirection.to_fixed_t)
coordinate_transformation(pc, direction=CoordSystem.t)

for lvl in range(pc.finest_level + 1):
for pti in ImpactXParIter(pc, level=lvl):
aos = pti.aos()
aos_arr = array(aos, copy=False)

soa = pti.soa()
real_arrays = soa.GetRealData()
px = array(real_arrays[0], copy=False)
py = array(real_arrays[1], copy=False)
pt = array(real_arrays[2], copy=False)
data_arr = (
torch.tensor(
np.vstack(
[aos_arr["x"], aos_arr["y"], aos_arr["z"], real_arrays[:3]]
)
)
.float()
.T
x = array(real_arrays[0], copy=False)
y = array(real_arrays[1], copy=False)
t = array(real_arrays[2], copy=False)
px = array(real_arrays[3], copy=False)
py = array(real_arrays[4], copy=False)
pt = array(real_arrays[5], copy=False)
data_arr = torch.as_tensor(
stack(

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'stack' may be used before it is initialized.
[
x,
y,
t,
px,
py,
py,
]
),
device=device,

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'device' may be used before it is initialized.
)

data_arr[:, 0] += ref_part.x
Expand Down Expand Up @@ -147,9 +175,9 @@ def surrogate_push(self, pc, step):
data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final

aos_arr["x"] = data_arr_post_model[:, 0]
aos_arr["y"] = data_arr_post_model[:, 1]
aos_arr["z"] = data_arr_post_model[:, 2]
x[:] = data_arr_post_model[:, 0]
y[:] = data_arr_post_model[:, 1]
t[:] = data_arr_post_model[:, 2]
px[:] = data_arr_post_model[:, 3]
py[:] = data_arr_post_model[:, 4]
pt[:] = data_arr_post_model[:, 5]
Expand All @@ -174,7 +202,7 @@ def surrogate_push(self, pc, step):
# ref_part.s += pge1.ds
# ref_part.t += pge1.ds / ref_beta

coordinate_transformation(pc, TransformationDirection.to_fixed_s)
coordinate_transformation(pc, direction=CoordSystem.s)
## Done!


Expand Down
9 changes: 4 additions & 5 deletions src/particles/CollectLost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ namespace impactx
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;

AMREX_GPU_HOST_DEVICE
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
dst.m_aos[dst_ip] = src.m_aos[src_ip];

void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept
{
dst.m_idcpu[dst_ip] = src.m_idcpu[src_ip];
for (int j = 0; j < SrcData::NAR; ++j)
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
for (int j = 0; j < src.m_num_runtime_real; ++j)
Expand Down Expand Up @@ -141,8 +141,7 @@ namespace impactx
// move down
int const new_index = ip - n_removed;

ptile_src_data.m_aos[new_index] = ptile_src_data.m_aos[ip];

ptile_src_data.m_idcpu[new_index] = ptile_src_data.m_idcpu[ip];
for (int j = 0; j < SrcData::NAR; ++j)
ptile_src_data.m_rdata[j][new_index] = ptile_src_data.m_rdata[j][ip];
for (int j = 0; j < ptile_src_data.m_num_runtime_real; ++j)
Expand Down
81 changes: 24 additions & 57 deletions src/particles/ImpactXParticleContainer.H
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <AMReX_MultiFab.H>
#include <AMReX_ParIter.H>
#include <AMReX_Particles.H>
#include <AMReX_ParticleTile.H>

#include <AMReX_IntVect.H>
#include <AMReX_Vector.H>
Expand All @@ -35,43 +36,16 @@ namespace impactx
t ///< fixed t as the independent variable
};

/** AMReX pre-defined Real attributes
*
* These are the AMReX pre-defined struct indexes for the Real attributes
* stored in an AoS in ImpactXParticleContainer. We document this here,
* because we change the meaning of these "positions" depending on the
* coordinate system we are currently in.
*/
struct RealAoS
{
enum
{
x, ///< position in x [m] (at fixed s OR fixed t)
y, ///< position in y [m] (at fixed s OR fixed t)
t, ///< c * time-of-flight [m] (at fixed s)
nattribs ///< the number of attributes above (always last)
};

// at fixed t, the third component represents the position z
enum {
z = t ///< position in z [m] (at fixed t)
};

//! named labels for fixed s
static constexpr auto names_s = { "position_x", "position_y", "position_t" };
//! named labels for fixed t
static constexpr auto names_t = { "position_x", "position_y", "position_z" };
static_assert(names_s.size() == nattribs);
static_assert(names_t.size() == nattribs);
};

/** This struct indexes the additional Real attributes
/** This struct indexes the Real attributes
* stored in an SoA in ImpactXParticleContainer
*/
struct RealSoA
{
enum
{
x, ///< position in x [m] (at fixed s or t)
y, ///< position in y [m] (at fixed s or t)
t, ///< time-of-flight ct [m] (at fixed s)
px, ///< momentum in x, scaled by the magnitude of the reference momentum [unitless] (at fixed s or t)
py, ///< momentum in y, scaled by the magnitude of the reference momentum [unitless] (at fixed s or t)
pt, ///< energy deviation, scaled by speed of light * the magnitude of the reference momentum [unitless] (at fixed s)
Expand All @@ -80,27 +54,28 @@ namespace impactx
nattribs ///< the number of attributes above (always last)
};

// at fixed t, the third component represents the momentum in z
// at fixed t, the third component represents the position z, the 6th component represents the momentum in z
enum {
z = t, ///< position in z [m] (at fixed t)
pz = pt ///< momentum in z, scaled by the magnitude of the reference momentum [unitless] (at fixed t)
};

//! named labels for fixed s
static constexpr auto names_s = { "momentum_x", "momentum_y", "momentum_t", "qm", "weighting" };
static constexpr auto names_s = { "position_x", "position_y", "position_t", "momentum_x", "momentum_y", "momentum_t", "qm", "weighting" };
//! named labels for fixed t
static constexpr auto names_t = { "momentum_x", "momentum_y", "momentum_z", "qm", "weighting" };
static constexpr auto names_t = { "position_x", "position_y", "position_z", "momentum_x", "momentum_y", "momentum_z", "qm", "weighting" };
static_assert(names_s.size() == nattribs);
static_assert(names_t.size() == nattribs);
};

/** This struct indexes the additional Integer attributes
/** This struct indexes the Integer attributes
* stored in an SoA in ImpactXParticleContainer
*/
struct IntSoA
{
enum
{
nattribs ///< the number of particles above (always last)
nattribs ///< the number of attributes above (always last)
};
};

Expand All @@ -109,46 +84,46 @@ namespace impactx
* We subclass here to change the default threading strategy, which is
* `static` in AMReX, to `dynamic` in ImpactX.
*/
class ParIter
: public amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>
class ParIterSoA
: public amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>
{
public:
using amrex::ParIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>::ParIter;
using amrex::ParIterSoA<RealSoA::nattribs, IntSoA::nattribs>::ParIterSoA;

ParIter (ContainerType& pc, int level);
ParIterSoA (ContainerType& pc, int level);

ParIter (ContainerType& pc, int level, amrex::MFItInfo& info);
ParIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info);
};

/** Const AMReX iterator for particle boxes - data is read only.
*
* We subclass here to change the default threading strategy, which is
* `static` in AMReX, to `dynamic` in ImpactX.
*/
class ParConstIter
: public amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>
class ParConstIterSoA
: public amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>
{
public:
using amrex::ParConstIter<0, 0, RealSoA::nattribs, IntSoA::nattribs>::ParConstIter;
using amrex::ParConstIterSoA<RealSoA::nattribs, IntSoA::nattribs>::ParConstIterSoA;

ParConstIter (ContainerType& pc, int level);
ParConstIterSoA (ContainerType& pc, int level);

ParConstIter (ContainerType& pc, int level, amrex::MFItInfo& info);
ParConstIterSoA (ContainerType& pc, int level, amrex::MFItInfo& info);
};

/** Beam Particles in ImpactX
*
* This class stores particles, distributed over MPI ranks.
*/
class ImpactXParticleContainer
: public amrex::ParticleContainer<0, 0, RealSoA::nattribs, IntSoA::nattribs>
: public amrex::ParticleContainerPureSoA<RealSoA::nattribs, IntSoA::nattribs>
{
public:
//! amrex iterator for particle boxes
using iterator = impactx::ParIter;
using iterator = impactx::ParIterSoA;

//! amrex constant iterator for particle boxes (read-only)
using const_iterator = impactx::ParConstIter;
using const_iterator = impactx::ParConstIterSoA;

//! Construct a new particle container
ImpactXParticleContainer (initialization::AmrCoreData* amr_core);
Expand Down Expand Up @@ -276,10 +251,6 @@ namespace impactx
DepositCharge (std::unordered_map<int, amrex::MultiFab> & rho,
amrex::Vector<amrex::IntVect> const & ref_ratio);

/** Get the name of each Real AoS component */
std::vector<std::string>
RealAoS_names () const;

/** Get the name of each Real SoA component */
std::vector<std::string>
RealSoA_names () const;
Expand Down Expand Up @@ -311,10 +282,6 @@ namespace impactx

}; // ImpactXParticleContainer

/** Get the name of each Real AoS component */
std::vector<std::string>
get_RealAoS_names ();

/** Get the name of each Real SoA component
*
* @param num_real_comps number of compile-time + runtime arrays
Expand Down
Loading

0 comments on commit 0e8b676

Please sign in to comment.