Description

hello,

I am trying to convert PyTorch code to JAX for an algorithm that performs Hamiltonian slice sampling coupled to a dynamic nested sampling algorithm. The goal is to parallelize the procedure on the GPU with vmap, but I am running into GPU memory saturation. The code is complex, so I include below the three main functions where everything happens. The method is iterative due to the while loop in find_new_sample_batch. It looks like the shape of some variable grows on each iteration, but I am not able to find which one.
The vectorisation parallelizes a loop over a batch of points in 2D; in this case the batch size is 25.
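For context, the parallelization follows this pattern (a minimal sketch with a toy loop body, not my actual dynamics): each point runs its own lax.while_loop, and vmap batches the whole thing.

```python
import jax.numpy as jnp
from jax import lax, vmap

def evolve_single(x0):
    # One point's trajectory: iterate until the step counter hits a cap.
    def cond_fn(state):
        _, n = state
        return n < 10

    def body_fn(state):
        x, n = state
        return x + 0.1 * jnp.sin(x), n + 1

    x_final, _ = lax.while_loop(cond_fn, body_fn, (x0, 0))
    return x_final

batch = jnp.zeros((25, 2))        # 25 points in 2D
out = vmap(evolve_single)(batch)  # shape (25, 2)
```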
The trace is this:

```
File "/r5/home/rruiz/projects/lisa/gradNS/dynamic.py", line 91, in add_point_batch
newsample = self.find_new_sample_batch(min_logL, n_points=n_points, labels=labels)
File "/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py", line 1079, in find_new_sample_batch
new_x_active, new_loglike_active, out_frac, like_evals = self.hamiltonian_slice_sampling(x_ini[active], velocity[active], min_loglike, self.key, self.dt, self.max_reflections, self.min_reflections)
File "/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py", line 1020, in hamiltonian_slice_sampling
pos_out, logl_out, killed, like_evals = vmap(
File "/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py", line 999, in hamiltonian_slice_sampling_single
final_state = lax.while_loop(cond_fn, body_fn, initial_state)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 48828125000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.82GiB
constant allocation: 0B
maybe_live_out allocation: 45.47GiB
preallocated temp allocation: 0B
total allocation: 47.29GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 45.47GiB
Operator: op_name="jit(broadcast_in_dim)/jit(main)/broadcast_in_dim[shape=(25, 25, 25, 25, 25, 25, 25) broadcast_dimensions=(1, 2, 3, 4, 5, 6)]" source_file="/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py" source_line=999
XLA Label: fusion
Shape: s64[25,25,25,25,25,25,25]
==========================
Buffer 2:
Size: 1.82GiB
Entry Parameter Subshape: s64[25,25,25,25,25,25]
```

Note that the failing buffer is s64[25,25,25,25,25,25,25]: 25^7 × 8 bytes = 48,828,125,000 bytes, which matches the failed allocation exactly, while the entry parameter already carries six axes of size 25. So each call apparently adds one more batch axis of size 25 to some integer (s64) in the loop state.
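That kind of axis growth is easy to reproduce with a toy example (an assumption about the mechanism, not my actual code): if a value produced by one vmapped call is closed over by the next one, every outer iteration stacks one more batch axis.

```python
import jax.numpy as jnp
from jax import vmap

x = jnp.array(0)  # starts as a scalar
for _ in range(3):
    # the lambda closes over the previous (already batched) x,
    # so each pass adds one more axis of size 25
    x = vmap(lambda k: x + k)(jnp.arange(25))
print(x.shape)  # (25, 25, 25)
```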
The three main functions are:

```python
def hamiltonian_slice_sampling_single(
self,
position,
velocity,
min_like,
key,
dt,
max_reflections,
min_reflections
):
"""
Hamiltonian Slice Sampling algorithm for a single point in JAX.
"""
# Initialize variables
n_reflections = 0
num_steps = 0
max_trajectory_length = max_reflections * 10 # To store trajectory points
killed = False
start_saving = False
# Pre-allocate tensors to store results
pos_tensor = jnp.zeros((max_trajectory_length, position.shape[0]))
logl_tensor = jnp.zeros((max_trajectory_length,))
mask_tensor = jnp.zeros((max_trajectory_length,), dtype=bool)
    memory = jnp.zeros((3,), dtype=bool)  # Track the last three 'outside' flags; three in a row kills the point
x = position
current_key = key
# Add like_evals to the loop state
like_evals = self.like_evals
# Define a helper function to update memory without using jnp.roll
def update_memory(memory, outside):
"""
Updates the memory by shifting left and appending the new 'outside' value.
"""
# Shift memory to the left by one
memory_shifted = memory[1:]
# Append the new 'outside' value
memory_updated = jnp.concatenate([
memory_shifted,
jnp.array([outside], dtype=bool)
])
return memory_updated
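    # Example: update_memory([False, True, True], True) -> [True, True, True];
    # three consecutive 'outside' steps is what triggers 'killed' below.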
# Define the condition function for the loop
def cond_fn(state):
(
n_reflections,
num_steps,
x,
velocity,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key,
memory,
like_evals # Track like_evals in the state
) = state
return jnp.logical_and(
n_reflections < max_reflections,
jnp.logical_and(
num_steps < (max_reflections * 100),
jnp.logical_not(killed)
)
)
# Define the body of the loop
def body_fn(state):
(
n_reflections,
num_steps,
x,
velocity,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key,
memory,
like_evals # Handle like_evals in the loop body
) = state
num_steps += 1
#jax.debug.print("Shape of x before update: {}", x.shape)
#jax.debug.print("Shape of velocity before update: {}", velocity.shape)
x = x + velocity * dt # Euler step for updating position
#jax.debug.print("Shape of x after update: {}", x.shape)
x_reshape = x.reshape(1, -1)
# Debugging print for tensor shapes
# Check if the point is inside the prior
#x = x.reshape(1, -1)
#jax.debug.print("Shape of x after reshape: {}", x.shape)
in_prior = self.is_in_prior(x_reshape)
#in_prior = self.is_in_prior(x)
#sys.exit()
# Calculate the log-likelihood and its gradient, also updating like_evals
p_x, grad_p_x, like_evals_increment = self.get_score_hamiltonian(x_reshape)
#p_x, grad_p_x, like_evals_increment = self.get_score_hamiltonian(x)
p_x = jnp.squeeze(p_x) # Ensure p_x is scalar
like_evals += like_evals_increment # Update like_evals
# Check if the point is outside the slice
reflected = p_x <= min_like
reflected = jnp.squeeze(reflected) # Ensure scalar
outside = reflected | (~in_prior)
outside = jnp.squeeze(outside) # Ensure scalar
outside = jnp.asarray(outside, dtype=bool) # Ensure it's boolean
# Update memory without using jnp.roll
memory = update_memory(memory, outside)
# Determine if the point should be killed
killed = jnp.sum(memory) == 3
# Pack all required values into the operand for lax.cond
operand = (
x,
velocity,
n_reflections,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key
)
# Define functions to handle killed and continue_process
def return_killed(operand):
(
x,
velocity,
n_reflections,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key
) = operand
return (
jnp.zeros_like(x).reshape(x.shape), # Placeholder for x
jnp.zeros_like(velocity).reshape(velocity.shape), # Placeholder for velocity
n_reflections, # Scalar n_reflections
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
True, # killed = True
key
)
def continue_process(operand):
(
x,
velocity,
n_reflections,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key
) = operand
# Reflect velocity if outside
normal = grad_p_x / jnp.linalg.norm(grad_p_x)
delta_velocity = 2 * jnp.vdot(velocity, normal) * normal
velocity_reflected = jnp.where(outside, velocity - delta_velocity, velocity)
# Split the random key
key, subkey = random.split(key)
# Apply prior perturbation if self.prior is defined
if self.prior is not None:
key, subkey = random.split(key)
r = random.normal(subkey, velocity.shape)
velocity_reflected = jnp.where(~outside, dt * self.prior(x) + jnp.sqrt(2) * r, velocity_reflected)
# Split the key again for velocity perturbation
key, subkey = random.split(key)
if self.sigma_vel > 0:
key, subkey = random.split(key)
r = random.normal(subkey, velocity.shape)
r /= jnp.linalg.norm(r, axis=-1, keepdims=True) # Normalize random perturbation
velocity_reflected = jnp.where(~outside, velocity_reflected * (1 + self.sigma_vel * r), velocity_reflected)
# Increment reflections count
n_reflections_updated = n_reflections + reflected
n_reflections_updated = jnp.squeeze(n_reflections_updated) # Ensure scalar
# Determine if we should save the point
scalar_pred = jnp.squeeze(n_reflections_updated > min_reflections)
def process_save_point(_):
start_saving = True
step_idx = num_steps % max_trajectory_length
# Update pos_tensor at step_idx
pos_tensor_updated = pos_tensor.at[step_idx].set(x)
# Update logl_tensor at step_idx with scalar p_x
logl_tensor_updated = logl_tensor.at[step_idx].set(p_x)
# Update mask_tensor at step_idx
mask_tensor_updated = mask_tensor.at[step_idx].set(~outside)
return pos_tensor_updated, logl_tensor_updated, mask_tensor_updated, True
def skip_save_point(_):
return pos_tensor, logl_tensor, mask_tensor, start_saving
pos_tensor_updated, logl_tensor_updated, mask_tensor_updated, start_saving_updated = lax.cond(
scalar_pred,
process_save_point,
skip_save_point,
operand=None
)
return (
x.reshape(x.shape),
velocity_reflected.reshape(velocity.shape),
n_reflections_updated,
pos_tensor_updated,
logl_tensor_updated,
mask_tensor_updated,
start_saving_updated,
False, # killed = False
key
)
# Apply lax.cond to decide whether to return early (if killed) or continue processing
x, velocity, n_reflections, pos_tensor, logl_tensor, mask_tensor, start_saving, killed, key = lax.cond(
killed,
return_killed,
continue_process,
operand
)
return (
n_reflections,
num_steps,
x,
velocity,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key,
memory,
like_evals # Track like_evals in the loop state
)
# Initialize the state of the loop
initial_state = (
n_reflections,
num_steps,
x,
velocity,
pos_tensor,
logl_tensor,
mask_tensor,
start_saving,
killed,
key,
memory,
like_evals # Start with the initial value of like_evals
)
with checking_leaks():
        print('before while loop:', x.shape, velocity.shape)
# Run the loop using `lax.while_loop`
final_state = lax.while_loop(cond_fn, body_fn, initial_state)
# Unpack the final state
_, _, _, _, pos_tensor, logl_tensor, mask_tensor, _, _, key, memory, final_like_evals = final_state
# Perform the JAX-compatible selection
final_position, final_logl, has_valid = self.select_final_sample(mask_tensor, pos_tensor, logl_tensor, key)
# Optionally, print additional debug information
jax.debug.print("Has valid index: {}", has_valid)
jax.debug.print("Final position: {}", final_position)
jax.debug.print("Final log-likelihood: {}", final_logl)
return final_position, final_logl, has_valid, final_like_evals
def hamiltonian_slice_sampling(self, positions, velocities, min_like, key, dt, max_reflections, min_reflections):
keys = random.split(key, positions.shape[0])
jax.debug.print("Keys shape after split: {}", keys.shape) # Should be (batch_size, 2)
pos_out, logl_out, killed, like_evals = vmap(
self.hamiltonian_slice_sampling_single,
in_axes=(0, 0, None, 0, None, None, None) # Vectorize over positions, velocities, and keys
)(positions, velocities, min_like, keys, dt, max_reflections, min_reflections)
# Remove references to intermediate tensors to allow garbage collection
del positions, velocities, keys
# Trigger Python garbage collection to free unused memory
gc.collect()
out_frac = jnp.mean(killed)
return pos_out, logl_out, out_frac, like_evals
def find_new_sample_batch(self, min_loglike, n_points, labels=None):
"""
Sample the prior until finding a sample with higher likelihood than a
given value
Parameters
----------
min_like : float
The threshold log-likelihood
labels : nlive_ini // 2 shape
Returns
-------
newsample : pd.DataFrame
A new sample
"""
point = self.live_points.get_samples_from_labels(labels, key=self.key)
ini_labels = point.get_labels()
    x_ini = point.get_values()  # shape (nlive_ini // 2, n_dims)
active = jnp.ones(x_ini.shape[0], dtype=jnp.bool_)
new_x = jnp.zeros_like(x_ini)
new_loglike = jnp.zeros(x_ini.shape[0])
accepted = False
count = 0
while not accepted:
count += 1
if count > 10:
print('finished', count)
sys.exit()
keys = random.split(self.key, x_ini.shape[0]) # Split the key into `batch_size` subkeys
assert jnp.min(self.loglike(x_ini)) >= min_loglike, f"min_loglike = {min_loglike}, x_loglike = {self.loglike(x_ini)}"
assert jnp.all(self.is_in_prior(x_ini)), f"min_loglike = {min_loglike}, x_loglike = {self.loglike(x_ini)}"
# Generate initial velocities
velocity = random.normal(self.key, x_ini.shape)
velocity /= jnp.linalg.norm(velocity, axis=-1, keepdims=True)
# Parallelize the slice sampling across all active points
#vmapped_hamiltonian_sampling = jax.vmap(self.hamiltonian_slice_sampling, in_axes=(0, 0, 0, None), out_axes=(0, 0))
#new_x_active, new_loglike_active, out_frac = vmapped_hamiltonian_sampling(x_ini[active], velocity[active], keys, min_loglike)
# Instead of passing keys (list of keys), pass a single key.
new_x_active, new_loglike_active, out_frac, like_evals = self.hamiltonian_slice_sampling(x_ini[active], velocity[active], min_loglike, self.key, self.dt, self.max_reflections, self.min_reflections)
self.like_evals = like_evals
#new_x_active, new_loglike_active, out_frac = self.hamiltonian_slice_sampling(position=x_ini[active], velocity=velocity[active], min_like=min_loglike)
new_x = new_x.at[active].set(new_x_active)
#print('new_x', new_x, out_frac)
new_loglike = new_loglike.at[active].set(new_loglike_active)
if (out_frac > 0.15) and (jnp.sum(active) >= max(2, len(active) // 2)):
self.dt = jnp.clip(self.dt * 0.9, 1e-5, 10)
if self.verbose:
print("Decreasing dt to ", self.dt, "out_frac = ", out_frac, "active = ", jnp.sum(active))
active = jnp.ones(x_ini.shape[0], dtype=jnp.bool_)
elif (out_frac < 0.05) and (jnp.sum(active) >= max(2, len(active) // 2)):
self.dt = jnp.clip(self.dt * 1.1, 1e-5, 10)
if self.verbose:
print("Increasing dt to ", self.dt, "out_frac = ", out_frac, "active = ", jnp.sum(active))
active = jnp.ones(x_ini.shape[0], dtype=jnp.bool_)
else:
in_prior = self.is_in_prior(new_x)
active = (new_loglike < min_loglike) | (~in_prior)
if self.verbose and jnp.sum(active) > 0:
print(f"Active: {jnp.sum(active)} / {len(active)}")
print(new_x, new_loglike, active, jnp.sum(active))
#sys.exit()
accepted = jnp.sum(active) == 0
sys.exit()
assert jnp.min(new_loglike) >= min_loglike, f"min_loglike = {min_loglike}, new_loglike = {new_loglike}"
sample = NSPoints(self.nparams)
sample.add_samples(values=new_x, logL=new_loglike, logweights=jnp.zeros(new_loglike.shape[0]), labels=ini_labels)
gc.collect()
    return sample
```
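For completeness, this is the kind of shape check I have been adding while hunting for the growing variable (a sketch; `report_state_shapes` is just a hypothetical helper name, assuming the loop state is a pytree of arrays):

```python
import jax.numpy as jnp
from jax import tree_util

def report_state_shapes(state, tag=""):
    # Print shape and dtype for every leaf of the loop state; the leaf
    # that picks up extra 25-axes between calls is the one to chase.
    for i, leaf in enumerate(tree_util.tree_leaves(state)):
        print(f"{tag} leaf {i}: shape={jnp.shape(leaf)}, dtype={jnp.result_type(leaf)}")
```

called as `report_state_shapes(initial_state, tag='before while_loop')` right before `lax.while_loop`.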
Thanks!
best,
Roberto
System info (python version, jaxlib version, accelerator, etc.)
python 3.9.19
jaxlib 0.4.28
Output of `jax.print_environment_info()`:

```
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='som5.ific.uv.es', release='4.18.0-373.el8.x86_64', version='#1 SMP Tue Mar 22 15:11:47 UTC 2022', machine='x86_64')
```
```
$ nvidia-smi
Thu Oct 3 17:42:15 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla V100-PCIE-32GB Off | 00000000:61:00.0 Off | 0 |
| N/A 37C P0 38W / 250W | 0MiB / 32768MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 Tesla T4 Off | 00000000:DB:00.0 Off | 0 |
| N/A 43C P0 28W / 70W | 104MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 1 N/A N/A 3709258 C python 102MiB |
+-----------------------------------------------------------------------------------------+
```