Skip to content

Commit

Permalink
Fix "too many shader groups called" validation error in llama3 on AMD and Intel GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed Jun 17, 2024
1 parent d63aca3 commit 0a321fc
Show file tree
Hide file tree
Showing 8 changed files with 19,802 additions and 19,336 deletions.
39,093 changes: 19,771 additions & 19,322 deletions ggml-vulkan-shaders.hpp

Large diffs are not rendered by default.

33 changes: 25 additions & 8 deletions ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,6 @@ struct vk_context {
};

struct ggml_tensor_extra_gpu {
ggml_backend_vk_context * backend_ctx;
size_t ctx_idx;

vk_buffer_ref buffer_gpu;
Expand Down Expand Up @@ -2746,9 +2745,6 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
ggml_vk_ensure_sync_staging_buffer(src->device, size);
ggml_vk_ensure_sync_staging_buffer(dst->device, size);

std::lock_guard<std::mutex> src_lock(src->device->mutex);
std::lock_guard<std::mutex> dst_lock(dst->device->mutex);

// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
// memcpy to dst staging buffer
Expand Down Expand Up @@ -3228,18 +3224,30 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}

const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];

uint32_t groups_x = ne01;
uint32_t groups_z = 1;

if (ne01 > max_groups_x) {
groups_z = 64;
groups_x /= groups_z;
}

// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} },
sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
}

static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
Expand Down Expand Up @@ -3740,6 +3748,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}

const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];

uint32_t groups_x = ne01;
uint32_t groups_z = 1;

if (ne01 > max_groups_x) {
groups_z = 64;
groups_x /= groups_z;
}

// compute
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
Expand All @@ -3749,7 +3767,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23}, { d_ids, ids_buf_offset, ids_sz } },
sizeof(vk_mat_vec_id_push_constants), &pc, { (uint32_t)ne01, (uint32_t)nei0, 1 });
sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
}

static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
Expand Down Expand Up @@ -5606,7 +5624,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
}

extra->ctx_idx = ctx->compute_ctx->idx;
extra->backend_ctx = ctx;

#ifdef GGML_VULKAN_CHECK_RESULTS
// Force context reset on each node so that each tensor ends up in its own context
Expand Down
2 changes: 1 addition & 1 deletion vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];

void main() {
const uint row = gl_WorkGroupID.x;
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
const uint tid = gl_LocalInvocationID.x;

uint a_offset, b_offset, d_offset;
Expand Down
2 changes: 1 addition & 1 deletion vulkan-shaders/mul_mat_vec_q2_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];

void main() {
const uint row = gl_WorkGroupID.x;
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
Expand Down
2 changes: 1 addition & 1 deletion vulkan-shaders/mul_mat_vec_q3_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];

void main() {
const uint row = gl_WorkGroupID.x;
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
Expand Down
2 changes: 1 addition & 1 deletion vulkan-shaders/mul_mat_vec_q4_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];

void main() {
const uint row = gl_WorkGroupID.x;
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
Expand Down
2 changes: 1 addition & 1 deletion vulkan-shaders/mul_mat_vec_q5_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];

void main() {
const uint row = gl_WorkGroupID.x;
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
Expand Down
2 changes: 1 addition & 1 deletion vulkan-shaders/mul_mat_vec_q6_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];

void main() {
const uint row = gl_WorkGroupID.x;
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
Expand Down

0 comments on commit 0a321fc

Please sign in to comment.