Issues with vmap and GPU memory issues #24099

Open
rruizdeaustri opened this issue Oct 3, 2024 · 0 comments
Labels
bug Something isn't working


rruizdeaustri commented Oct 3, 2024

Description

Hello,

I am trying to convert PyTorch code to JAX. The code implements an algorithm that performs Hamiltonian sampling coupled to a dynamic nested sampling algorithm. The goal is to use vmap to parallelize the procedure on the GPU, but the GPU memory saturates. The code is complex, so I include the three main functions where everything happens. The method is iterative because of the while loop in "find_new_sample_batch". It looks like the shape of some variable grows in each iteration, but I have not been able to find which one.
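You can locate a growing array without allocating it by tracing only shapes. `jax.eval_shape` runs a function on abstract `jax.ShapeDtypeStruct` inputs and returns output shapes and dtypes without touching device memory. It can therefore be called once per pass of an outer loop and the shapes compared. The sketch below uses a hypothetical stand-in (`step`, `counter`), not the code from this issue. It mimics a per-point function that closes over an unbatched counter whose vmapped result is fed back as the next input.

```python
import jax
import jax.numpy as jnp

def step(state):
    """Hypothetical stand-in for one pass of the outer loop."""
    positions, counter = state

    def single(p):
        # `counter` is closed over, so it is not batched by vmap
        return p + 1.0, counter + 1

    return jax.vmap(single)(positions)

# Abstract inputs: eval_shape traces with these and allocates no device buffers.
state = (jax.ShapeDtypeStruct((25, 2), jnp.float32),
         jax.ShapeDtypeStruct((), jnp.int32))

history = []
for it in range(3):
    new_state = jax.eval_shape(step, state)
    shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(new_state)]
    history.append(shapes)
    print(it, shapes)
    state = new_state  # feed the output back in, as the real outer loop does
```

Any leaf whose rank increases from one pass to the next is the one that grows. Here the counter goes from `(25,)` to `(25, 25)` to `(25, 25, 25)`, while the positions stay at `(25, 2)`.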

The vectorization is used to parallelize a loop over a batch of points in 2D. In this case the batch size is 25.
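Here is a minimal sketch of that kind of vectorization. It assumes a per-point function with a calling convention similar to `hamiltonian_slice_sampling_single`: a mapped point, a mapped PRNG key, and a shared scalar parameter. `single` and `scale` are hypothetical names. The sketch checks that `vmap` with `in_axes=(0, 0, None)` gives the same result as an explicit Python loop over 25 points.

```python
import jax
import jax.numpy as jnp
from jax import random

def single(position, key, scale):
    """Hypothetical per-point work on one 2-D point; `scale` is shared by the batch."""
    noise = random.normal(key, position.shape)
    return position + scale * noise, jnp.sum(position ** 2)

batch_size = 25
positions = random.uniform(random.PRNGKey(1), (batch_size, 2))
keys = random.split(random.PRNGKey(0), batch_size)  # one PRNG key per point

# Explicit Python loop over the batch ...
loop_out = [single(positions[i], keys[i], 0.1) for i in range(batch_size)]
# ... and the vectorized equivalent: map over axis 0 of positions and keys only.
new_pos, sq_norm = jax.vmap(single, in_axes=(0, 0, None))(positions, keys, 0.1)

print(new_pos.shape, sq_norm.shape)  # (25, 2) (25,)
```

Every output comes back with a leading axis of 25, including per-point scalars such as `sq_norm`.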

This is the traceback:

  File "/r5/home/rruiz/projects/lisa/gradNS/dynamic.py", line 91, in add_point_batch
    newsample = self.find_new_sample_batch(min_logL, n_points=n_points, labels=labels)
  File "/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py", line 1079, in find_new_sample_batch
    new_x_active, new_loglike_active, out_frac, like_evals = self.hamiltonian_slice_sampling(x_ini[active], velocity[active], min_loglike, self.key, self.dt, self.max_reflections, self.min_reflections)
  File "/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py", line 1020, in hamiltonian_slice_sampling
    pos_out, logl_out, killed, like_evals = vmap(
  File "/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py", line 999, in hamiltonian_slice_sampling_single
    final_state = lax.while_loop(cond_fn, body_fn, initial_state)

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 48828125000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.82GiB
              constant allocation:         0B
        maybe_live_out allocation:   45.47GiB
     preallocated temp allocation:         0B
                 total allocation:   47.29GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 45.47GiB
		Operator: op_name="jit(broadcast_in_dim)/jit(main)/broadcast_in_dim[shape=(25, 25, 25, 25, 25, 25, 25) broadcast_dimensions=(1, 2, 3, 4, 5, 6)]" source_file="/r5/home/rruiz/projects/lisa/gradNS/hamiltonian.py" source_line=999
		XLA Label: fusion
		Shape: s64[25,25,25,25,25,25,25]
		==========================

	Buffer 2:
		Size: 1.82GiB
		Entry Parameter Subshape: s64[25,25,25,25,25,25]
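The numbers in this report are consistent with an int64 (`s64`, 8 bytes per element) array whose axes all have size 25. The entry parameter has six such axes. `broadcast_in_dim` with `broadcast_dimensions=(1, 2, 3, 4, 5, 6)` maps them onto axes 1 to 6 of the output, which adds a new leading axis of size 25, the size of the vmap batch. The arithmetic below (plain Python) reproduces the byte counts in the error:

```python
# Buffer sizes for int64 (s64, 8 bytes per element) arrays whose axes all have size 25.
BYTES_PER_S64 = 8
GIB = 2 ** 30

param_bytes = 25 ** 6 * BYTES_PER_S64   # entry parameter s64[25,25,25,25,25,25]
output_bytes = 25 ** 7 * BYTES_PER_S64  # broadcast_in_dim output with 7 axes of 25

print(param_bytes, round(param_bytes / GIB, 2))    # 1953125000 1.82
print(output_bytes, round(output_bytes / GIB, 2))  # 48828125000 45.47
```

If that reading is right, each pass that adds an axis multiplies the size of this array by 25. That fits the impression that some variable grows on every iteration of the outer loop.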

    def hamiltonian_slice_sampling_single(
            self,
            position,
            velocity,
            min_like,
            key,
            dt,
            max_reflections,
            min_reflections
        ):
        """
        Hamiltonian Slice Sampling algorithm for a single point in JAX.
        """
        # Initialize variables
        n_reflections = 0
        num_steps = 0
        max_trajectory_length = max_reflections * 10  # To store trajectory points
        killed = False
        start_saving = False

        # Pre-allocate tensors to store results
        pos_tensor = jnp.zeros((max_trajectory_length, position.shape[0]))
        logl_tensor = jnp.zeros((max_trajectory_length,))
        mask_tensor = jnp.zeros((max_trajectory_length,), dtype=bool)
        memory = jnp.zeros((3,), dtype=bool)  # To track reflections

        x = position
        current_key = key

        # Add like_evals to the loop state
        like_evals = self.like_evals

        # Define a helper function to update memory without using jnp.roll
        def update_memory(memory, outside):
            """
            Updates the memory by shifting left and appending the new 'outside' value.
            """
            # Shift memory to the left by one
            memory_shifted = memory[1:]
            # Append the new 'outside' value
            memory_updated = jnp.concatenate([
                memory_shifted,
                jnp.array([outside], dtype=bool)
            ])
            return memory_updated

        # Define the condition function for the loop
        def cond_fn(state):
            (
                n_reflections,
                num_steps,
                x,
                velocity,
                pos_tensor,
                logl_tensor,
                mask_tensor,
                start_saving,
                killed,
                key,
                memory,
                like_evals  # Track like_evals in the state
            ) = state
            return jnp.logical_and(
                n_reflections < max_reflections,
                jnp.logical_and(
                    num_steps < (max_reflections * 100),
                    jnp.logical_not(killed)
                )
            )

        # Define the body of the loop
        def body_fn(state):
            (
                n_reflections,
                num_steps,
                x,
                velocity,
                pos_tensor,
                logl_tensor,
                mask_tensor,
                start_saving,
                killed,
                key,
                memory,
                like_evals  # Handle like_evals in the loop body
            ) = state

            num_steps += 1

            x = x + velocity * dt  # Euler step for updating position
            x_reshape = x.reshape(1, -1)  # Add a leading batch axis of size 1

            # Check if the point is inside the prior
            in_prior = self.is_in_prior(x_reshape)

            # Calculate the log-likelihood and its gradient, also updating like_evals
            p_x, grad_p_x, like_evals_increment = self.get_score_hamiltonian(x_reshape)
            p_x = jnp.squeeze(p_x)  # Ensure p_x is scalar
            like_evals += like_evals_increment  # Update like_evals

            # Check if the point is outside the slice
            reflected = p_x <= min_like
            reflected = jnp.squeeze(reflected)  # Ensure scalar
            outside = reflected | (~in_prior)
            outside = jnp.squeeze(outside)  # Ensure scalar
            outside = jnp.asarray(outside, dtype=bool)  # Ensure it's boolean

            # Update memory without using jnp.roll
            memory = update_memory(memory, outside)

            # Determine if the point should be killed
            killed = jnp.sum(memory) == 3

            # Pack all required values into the operand for lax.cond
            operand = (
                x,
                velocity,
                n_reflections,
                pos_tensor,
                logl_tensor,
                mask_tensor,
                start_saving,
                killed,
                key
            )

            # Define functions to handle killed and continue_process
            def return_killed(operand):
                (
                    x,
                    velocity,
                    n_reflections,
                    pos_tensor,
                    logl_tensor,
                    mask_tensor,
                    start_saving,
                    killed,
                    key
                ) = operand
                return (
                    jnp.zeros_like(x).reshape(x.shape),           # Placeholder for x
                    jnp.zeros_like(velocity).reshape(velocity.shape),    # Placeholder for velocity
                    n_reflections,               # Scalar n_reflections
                    pos_tensor,
                    logl_tensor,
                    mask_tensor,
                    start_saving,
                    True,                         # killed = True
                    key
                )

            def continue_process(operand):
                (
                    x,
                    velocity,
                    n_reflections,
                    pos_tensor,
                    logl_tensor,
                    mask_tensor,
                    start_saving,
                    killed,
                    key
                ) = operand

                # Reflect velocity if outside
                normal = grad_p_x / jnp.linalg.norm(grad_p_x)
                delta_velocity = 2 * jnp.vdot(velocity, normal) * normal
                velocity_reflected = jnp.where(outside, velocity - delta_velocity, velocity)

                # Split the random key
                key, subkey = random.split(key)

                # Apply prior perturbation if self.prior is defined
                if self.prior is not None:
                    key, subkey = random.split(key)
                    r = random.normal(subkey, velocity.shape)
                    velocity_reflected = jnp.where(~outside, dt * self.prior(x) + jnp.sqrt(2) * r, velocity_reflected)

                # Split the key again for velocity perturbation
                key, subkey = random.split(key)

                if self.sigma_vel > 0:
                    key, subkey = random.split(key)
                    r = random.normal(subkey, velocity.shape)
                    r /= jnp.linalg.norm(r, axis=-1, keepdims=True)  # Normalize random perturbation
                    velocity_reflected = jnp.where(~outside, velocity_reflected * (1 + self.sigma_vel * r), velocity_reflected)

                # Increment reflections count
                n_reflections_updated = n_reflections + reflected
                n_reflections_updated = jnp.squeeze(n_reflections_updated)  # Ensure scalar

                # Determine if we should save the point
                scalar_pred = jnp.squeeze(n_reflections_updated > min_reflections)

                def process_save_point(_):
                    start_saving = True
                    step_idx = num_steps % max_trajectory_length

                    # Update pos_tensor at step_idx
                    pos_tensor_updated = pos_tensor.at[step_idx].set(x)

                    # Update logl_tensor at step_idx with scalar p_x
                    logl_tensor_updated = logl_tensor.at[step_idx].set(p_x)

                    # Update mask_tensor at step_idx
                    mask_tensor_updated = mask_tensor.at[step_idx].set(~outside)

                    return pos_tensor_updated, logl_tensor_updated, mask_tensor_updated, True

                def skip_save_point(_):
                    return pos_tensor, logl_tensor, mask_tensor, start_saving

                pos_tensor_updated, logl_tensor_updated, mask_tensor_updated, start_saving_updated = lax.cond(
                    scalar_pred,
                    process_save_point,
                    skip_save_point,
                    operand=None
                )

                return (
                    x.reshape(x.shape),
                    velocity_reflected.reshape(velocity.shape),
                    n_reflections_updated,
                    pos_tensor_updated,
                    logl_tensor_updated,
                    mask_tensor_updated,
                    start_saving_updated,
                    False,  # killed = False
                    key
                )

            # Apply lax.cond to decide whether to return early (if killed) or continue processing
            x, velocity, n_reflections, pos_tensor, logl_tensor, mask_tensor, start_saving, killed, key = lax.cond(
                killed,
                return_killed,
                continue_process,
                operand
            )

            return (
                n_reflections,
                num_steps,
                x,
                velocity,
                pos_tensor,
                logl_tensor,
                mask_tensor,
                start_saving,
                killed,
                key,
                memory,
                like_evals  # Track like_evals in the loop state
            )

        # Initialize the state of the loop
        initial_state = (
            n_reflections,
            num_steps,
            x,
            velocity,
            pos_tensor,
            logl_tensor,
            mask_tensor,
            start_saving,
            killed,
            key,
            memory,
            like_evals  # Start with the initial value of like_evals
        )

        with checking_leaks():
            print('before while:loop', x.shape, velocity.shape)
            # Run the loop using `lax.while_loop`
            final_state = lax.while_loop(cond_fn, body_fn, initial_state)

        # Unpack the final state
        _, _, _, _, pos_tensor, logl_tensor, mask_tensor, _, _, key, memory, final_like_evals = final_state
 
        # Perform the JAX-compatible selection
        final_position, final_logl, has_valid = self.select_final_sample(mask_tensor, pos_tensor, logl_tensor, key)

        # Optionally, print additional debug information
        jax.debug.print("Has valid index: {}", has_valid)
        jax.debug.print("Final position: {}", final_position)
        jax.debug.print("Final log-likelihood: {}", final_logl)

        return final_position, final_logl, has_valid, final_like_evals     
    

    def hamiltonian_slice_sampling(self, positions, velocities, min_like, key, dt, max_reflections, min_reflections):
        keys = random.split(key, positions.shape[0])

        jax.debug.print("Keys shape after split: {}", keys.shape)  # Should be (batch_size, 2)

        pos_out, logl_out, killed, like_evals = vmap(
            self.hamiltonian_slice_sampling_single,
            in_axes=(0, 0, None, 0, None, None, None)  # Vectorize over positions, velocities, and keys
        )(positions, velocities, min_like, keys, dt, max_reflections, min_reflections)

        # Remove references to intermediate tensors to allow garbage collection
        del positions, velocities, keys

        # Trigger Python garbage collection to free unused memory
        gc.collect()

        out_frac = jnp.mean(killed)
        return pos_out, logl_out, out_frac, like_evals

    def find_new_sample_batch(self, min_loglike, n_points, labels=None):
        """
        Sample the prior until finding a sample with higher likelihood than a
        given value.

        Parameters
        ----------
        min_loglike : float
            The threshold log-likelihood
        labels : array of shape (nlive_ini // 2,)
            Labels of the live points used as starting points

        Returns
        -------
        newsample : NSPoints
            A new sample
        """
        point = self.live_points.get_samples_from_labels(labels, key=self.key)
        ini_labels = point.get_labels()
        x_ini = point.get_values()  # shape nlive_ini // 2

        active = jnp.ones(x_ini.shape[0], dtype=jnp.bool_)
        new_x = jnp.zeros_like(x_ini)
        new_loglike = jnp.zeros(x_ini.shape[0])

        accepted = False
        count = 0
        while not accepted:
            count += 1
            if count > 10:
                print('finished', count)
                sys.exit()
            keys = random.split(self.key, x_ini.shape[0])  # Split the key into `batch_size` subkeys
            
            assert jnp.min(self.loglike(x_ini)) >= min_loglike, f"min_loglike = {min_loglike}, x_loglike = {self.loglike(x_ini)}"
            assert jnp.all(self.is_in_prior(x_ini)), f"min_loglike = {min_loglike}, x_loglike = {self.loglike(x_ini)}"

            # Generate initial velocities
            velocity = random.normal(self.key, x_ini.shape)
            velocity /= jnp.linalg.norm(velocity, axis=-1, keepdims=True)

            # Parallelize the slice sampling across all active points.
            # A single key is passed; hamiltonian_slice_sampling splits it into one key per point.
            new_x_active, new_loglike_active, out_frac, like_evals = self.hamiltonian_slice_sampling(x_ini[active], velocity[active], min_loglike, self.key, self.dt, self.max_reflections, self.min_reflections)

            self.like_evals = like_evals

            new_x = new_x.at[active].set(new_x_active)
            new_loglike = new_loglike.at[active].set(new_loglike_active)

            if (out_frac > 0.15) and (jnp.sum(active) >= max(2, len(active) // 2)):
                self.dt = jnp.clip(self.dt * 0.9, 1e-5, 10)
                if self.verbose:
                    print("Decreasing dt to ", self.dt, "out_frac = ", out_frac, "active = ", jnp.sum(active))
                active = jnp.ones(x_ini.shape[0], dtype=jnp.bool_)
            elif (out_frac < 0.05) and (jnp.sum(active) >= max(2, len(active) // 2)):
                self.dt = jnp.clip(self.dt * 1.1, 1e-5, 10)
                if self.verbose:
                    print("Increasing dt to ", self.dt, "out_frac = ", out_frac, "active = ", jnp.sum(active))
                active = jnp.ones(x_ini.shape[0], dtype=jnp.bool_)
            else:
                in_prior = self.is_in_prior(new_x)
                active = (new_loglike < min_loglike) | (~in_prior)
                if self.verbose and jnp.sum(active) > 0:
                    print(f"Active: {jnp.sum(active)} / {len(active)}")

            print(new_x, new_loglike, active, jnp.sum(active))       
            #sys.exit()
            accepted = jnp.sum(active) == 0
        sys.exit()
        assert jnp.min(new_loglike) >= min_loglike, f"min_loglike = {min_loglike}, new_loglike = {new_loglike}"
        sample = NSPoints(self.nparams)
        sample.add_samples(values=new_x, logL=new_loglike, logweights=jnp.zeros(new_loglike.shape[0]), labels=ini_labels)

        gc.collect()
        return sample
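One hypothesis fits these shapes, though this report does not confirm it: the likelihood-evaluation counter. `hamiltonian_slice_sampling_single` starts its loop carry from `self.like_evals`, and `vmap` returns every output, including `final_like_evals`, with a leading batch axis of 25. `find_new_sample_batch` then stores that batched result back into `self.like_evals`. On the next call the function closes over a larger array, and `vmap` adds another axis of 25. An integer counter would also explain the `s64` dtype. The toy sketch below reproduces the growth and shows one way to avoid it. It uses a hypothetical `count_steps` loop and a batch of 4 instead of 25. The fix starts each per-point counter at zero and adds only its sum to a scalar running total.

```python
import jax
import jax.numpy as jnp
from jax import lax

def count_steps(x, start):
    """Hypothetical per-point loop: advance x three times, counting one evaluation per step."""
    def cond_fn(state):
        i, x, evals = state
        return jnp.logical_and(i < 3, x < 10.0)

    def body_fn(state):
        i, x, evals = state
        # The increment depends on x, so the counter is batched under vmap
        return i + 1, x + 1.0, evals + (x >= 0).astype(evals.dtype)

    _, _, evals = lax.while_loop(cond_fn, body_fn, (0, x, start))
    return evals

xs = jnp.arange(4.0)  # 4 points instead of 25, to keep the arrays small
sample = jax.vmap(count_steps, in_axes=(0, None))

# Problematic pattern: the per-point result is fed back as the next unbatched start value.
counter = jnp.zeros((), dtype=jnp.int32)
growing_shapes = []
for _ in range(3):
    counter = sample(xs, counter)
    growing_shapes.append(counter.shape)
print(growing_shapes)  # [(4,), (4, 4), (4, 4, 4)]

# Fix: start every per-point counter at zero and keep only a scalar running total.
total = jnp.zeros((), dtype=jnp.int32)
for _ in range(3):
    per_point = sample(xs, jnp.zeros((), dtype=jnp.int32))
    total = total + per_point.sum()
print(total.shape, int(total))  # () 36
```

Applied to the code above, this would mean starting `like_evals` at zero inside `hamiltonian_slice_sampling_single` and adding `jnp.sum(like_evals)` to `self.like_evals` after the `vmap` call. This has not been tested against the real code.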

Thanks!

Best,
Roberto

System info (python version, jaxlib version, accelerator, etc.)

python 3.9.19
jaxlib 0.4.28

jax.print_environment_info()
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='som5.ific.uv.es', release='4.18.0-373.el8.x86_64', version='#1 SMP Tue Mar 22 15:11:47 UTC 2022', machine='x86_64')

$ nvidia-smi
Thu Oct 3 17:42:15 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla V100-PCIE-32GB           Off |   00000000:61:00.0 Off |                    0 |
| N/A   37C    P0             38W /  250W |       0MiB /  32768MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00000000:DB:00.0 Off |                    0 |
| N/A   43C    P0             28W /   70W |     104MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    1   N/A  N/A   3709258      C   python                                        102MiB |
+-----------------------------------------------------------------------------------------+

rruizdeaustri added the bug label on Oct 3, 2024