mirror of
https://github.com/ggml-org/whisper.cpp.git
synced 2025-09-15 13:28:35 +08:00
ggml: Add initial WebGPU backend (llama/14521)
ggml-ci
This commit is contained in:
parent
03d6607691
commit
5ed45b2518
54
ggml/src/ggml-webgpu/CMakeLists.txt
Normal file
54
ggml/src/ggml-webgpu/CMakeLists.txt
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# Build script for the ggml WebGPU backend.
#
# 1. Embeds every WGSL shader in wgsl-shaders/ into a generated C++ header
#    (ggml-wgsl-shaders.hpp) by running embed_wgsl.py at build time.
# 2. Declares the ggml-webgpu backend library and makes it depend on that header.
# 3. Wires up the WebGPU implementation: the emdawnwebgpu port under Emscripten,
#    Dawn otherwise.
cmake_minimum_required(VERSION 3.13)

# Only the interpreter is needed: it runs embed_wgsl.py through Python3_EXECUTABLE.
find_package(Python3 REQUIRED COMPONENTS Interpreter)

# Shader locations
set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders")
set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp")
file(MAKE_DIRECTORY "${SHADER_OUTPUT_DIR}")

message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")

# Find all WGSL files. CONFIGURE_DEPENDS re-checks the glob at build time so that
# adding or removing a shader regenerates the header without a manual re-configure.
file(GLOB WGSL_SHADER_FILES CONFIGURE_DEPENDS "${SHADER_DIR}/*.wgsl")

# Generate the header using a Python script
add_custom_command(
    OUTPUT "${SHADER_HEADER}"
    COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders to ggml-wgsl-shaders.hpp"
    COMMAND ${CMAKE_COMMAND} -E make_directory "${SHADER_OUTPUT_DIR}"
    # PYTHONIOENCODING keeps the script's console I/O in UTF-8 regardless of locale.
    COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
        "${Python3_EXECUTABLE}" "${SHADER_DIR}/embed_wgsl.py"
            --input "${SHADER_DIR}"
            --output "${SHADER_HEADER}"
    # Re-run whenever any shader or the embedding script itself changes.
    DEPENDS ${WGSL_SHADER_FILES} "${SHADER_DIR}/embed_wgsl.py"
    VERBATIM
)

add_custom_target(generate_shaders DEPENDS "${SHADER_HEADER}")

ggml_add_backend_library(ggml-webgpu
    ggml-webgpu.cpp
    ${SHADER_HEADER}
    ../../include/ggml-webgpu.h
)

# Make sure the header exists before any source of the backend is compiled.
add_dependencies(ggml-webgpu generate_shaders)

if(EMSCRIPTEN)
    set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")

    # An empty directory would expand to "--use-port=/emdawnwebgpu.port.py" and only
    # fail later with an obscure compiler error, so reject it at configure time.
    if("${EMDAWNWEBGPU_DIR}" STREQUAL "")
        message(FATAL_ERROR
            "EMDAWNWEBGPU_DIR is not set. Pass -DEMDAWNWEBGPU_DIR=<path to emdawnwebgpu_pkg> "
            "when building the WebGPU backend with Emscripten.")
    endif()

    target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
    target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
else()
    find_package(Dawn REQUIRED)
    set(DawnWebGPU_TARGET dawn::webgpu_dawn)
endif()

if(GGML_WEBGPU_DEBUG)
    target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
endif()

# The generated shader header lives in the build tree.
target_include_directories(ggml-webgpu PRIVATE "${SHADER_OUTPUT_DIR}")
# DawnWebGPU_TARGET is only set in the non-Emscripten branch above; under Emscripten
# this call adds nothing unless an enclosing scope defined the variable.
target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})
|
||||||
1190
ggml/src/ggml-webgpu/ggml-webgpu.cpp
Normal file
1190
ggml/src/ggml-webgpu/ggml-webgpu.cpp
Normal file
File diff suppressed because it is too large
Load Diff
60
ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
Normal file
60
ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// The destination buffer below stores f16 values, which requires this extension.
enable f16;

// Source elements, read as f32 by the entry point.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Destination elements, written as f16 by the entry point.
@group(0) @binding(1) var<storage, read_write> dst: array<f16>;

// Uniform parameters describing the copy.
struct Params {
    ne: u32,         // total number of elements
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements) - may be permuted
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Logical shape (same for both tensors)
    ne0: u32,
    ne1: u32,
    ne2: u32,
    ne3: u32,
};

@group(0) @binding(2) var<uniform> params: Params;

// Pipeline-overridable workgroup size; declared without a default value.
override wg_size: u32;
|
||||||
|
// Copies one element per invocation from `src` (f32) to `dst` (f16).
//
// The flat invocation index is split into 4-D coordinates (i0..i3) using the
// logical shape in `params`; each buffer is then addressed through its own
// strides and element offset. Invocations at or beyond `params.ne` do nothing.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    if (flat >= params.ne) {
        return;
    }

    // Element counts of one i2-slice and of one i3-slice.
    let slice_elems = params.ne1 * params.ne0;
    let block_elems = params.ne2 * slice_elems;

    // Peel off the coordinates from the slowest to the fastest dimension.
    let i3 = flat / block_elems;
    let rem3 = flat % block_elems;

    let i2 = rem3 / slice_elems;
    let rem2 = rem3 % slice_elems;

    let i1 = rem2 / params.ne0;
    let i0 = rem2 % params.ne0;

    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
                  i2 * params.stride_src2 + i3 * params.stride_src3;

    let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
                  i2 * params.stride_dst2 + i3 * params.stride_dst3;

    // Narrowing conversion f32 -> f16 happens on the store.
    dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
}
|
||||||
35
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Executable file
35
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def escape_triple_quotes(wgsl):
    """Return ``wgsl`` with a backslash inserted before every ``\"\"\"`` run.

    Occurrences are matched left to right without overlap, so four consecutive
    quotes become a backslash, three quotes, and one trailing quote.
    """
    # Splitting on the triple quote and re-joining with the escaped form is
    # equivalent to a left-to-right, non-overlapping replace.
    pieces = wgsl.split('"""')
    return '\\"""'.join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def to_cpp_string_literal(varname, content):
    """Render ``content`` as a C++ raw string literal assigned to ``wgsl_<varname>``.

    Args:
        varname: Suffix of the generated C++ variable name.
        content: Text to embed verbatim.

    Returns:
        One line of C++ source (terminated by a newline) of the form
        ``const char* wgsl_<varname> = R"<delim>(<content>)<delim>";``.

    A raw string literal ends at the first ``)<delim>"`` sequence. The plain
    ``R"(...)"`` form is used whenever it is safe, so output is unchanged for
    content that does not contain ``)"``. Otherwise a delimiter that does not
    occur in ``content`` is chosen so the literal cannot terminate early.
    """
    delimiter = ''
    attempt = 0
    # Delimiters stay well under the 16-character limit C++ imposes on them.
    while f'){delimiter}"' in content:
        delimiter = f'wgsl{attempt}'
        attempt += 1
    return f'const char* wgsl_{varname} = R"{delimiter}({content}){delimiter}";\n'
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Embed every ``.wgsl`` file of a directory into a single C++ header.

    Command-line arguments:
        --input   Directory that is scanned (non-recursively) for ``*.wgsl`` files.
        --output  Path of the header to write; it is overwritten.

    Files are processed in sorted name order so the output is deterministic.
    Each shader becomes one ``const char* wgsl_<file stem>`` definition.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    with open(args.output, 'w', encoding='utf-8') as out:
        out.write("// Auto-generated shader embedding \n\n")
        for fname in sorted(os.listdir(args.input)):
            if not fname.endswith('.wgsl'):
                continue
            shader_path = os.path.join(args.input, fname)
            varname = os.path.splitext(fname)[0]
            with open(shader_path, 'r', encoding='utf-8') as f:
                content = f.read()
            # The shader text is embedded in a C++ raw string literal, where a
            # backslash is an ordinary character. Escaping quotes here would put
            # a stray backslash into the shader source, so the text is passed
            # through unmodified.
            out.write(to_cpp_string_literal(varname, content))
            out.write('\n')


if __name__ == '__main__':
    main()
|
||||||
40
ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
Normal file
40
ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Buffer being filled, addressed as 32-bit words.
@group(0) @binding(0) var<storage, read_write> output_buffer: array<u32>;

// Uniform parameters describing the byte range to fill and the fill pattern.
struct Params {
    offset: u32, // in bytes
    size: u32,   // in bytes
    // Four packed 8-bit values: either one byte repeated (memset_tensor) or
    // four distinct bytes (cleaning up unaligned set_tensor operations).
    value: u32,
};

@group(0) @binding(1) var<uniform> params: Params;

// Pipeline-overridable constants; both are declared without default values.
override wg_size: u32;
override bytes_per_thread: u32;
|
||||||
|
|
||||||
|
// Fills `params.size` bytes of `output_buffer`, starting at byte `params.offset`,
// with the four bytes packed in `params.value`.
//
// Each invocation covers the `bytes_per_thread` bytes that begin at
// `gid.x * bytes_per_thread` relative to the start of the range. Positions where
// a full 4-byte word still fits before the end are written with a whole-word
// store; the remaining tail bytes are merged into the existing word one byte at
// a time so that bytes past the end of the range are preserved.
//
// NOTE(review): the whole-word store writes the entire word containing
// `byte_index`, so this presumably expects `params.offset` to be 4-byte
// aligned - confirm against the host code.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x * bytes_per_thread;
    let start = params.offset;
    let end = params.offset + params.size;

    for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
        let byte_index = start + i + j;
        if (byte_index + 4u <= end) {
            output_buffer[(byte_index >> 2u)] = params.value;
        } else {
            // Handle tail (unaligned)
            for (var k: u32 = 0u; k < 4u; k = k + 1u) {
                let idx = byte_index + k;
                if (idx < end) {
                    let word_idx = idx >> 2u;
                    let bit_offset = (idx & 3u) * 8u;
                    let byte_mask = 0xffu << bit_offset;
                    let existing = output_buffer[word_idx];
                    // Take the byte of `value` that sits at the same position
                    // within the word, matching what the whole-word store above
                    // writes. Always using the lowest byte would be wrong when
                    // the four packed bytes differ.
                    output_buffer[word_idx] = (existing & ~byte_mask) | (params.value & byte_mask);
                }
            }
        }
    }
}
|
||||||
56
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
Normal file
56
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
// Uniform parameters for the matrix multiplication entry point.
struct MulMatParams {
    // Matrix dimensions: dst is m x n, the reduction runs over k.
    m: u32,
    n: u32,
    k: u32,

    // all strides are in elements
    // (stride_0x belongs to src0, stride_1x to src1; x is the dimension)
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,

    // Batch sizes of src0 in dimensions 2 and 3, and how many dst batches
    // share one src0 batch in each of those dimensions.
    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};

// N rows, K columns
@group(0) @binding(0) var<storage, read_write> src0: array<f32>;

// M rows, K columns (transposed)
@group(0) @binding(1) var<storage, read_write> src1: array<f32>;

// M rows, N columns
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;

@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||||
|
|
||||||
|
// Computes one element of `dst` per invocation as the dot product of a row of
// `src0` and a row of `src1`, both of length `params.k`.
//
// The flat invocation index is decomposed into (batch3, batch2, row, col) of the
// output. `src0` may be broadcast along dimensions 2 and 3, so its batch indices
// are the output batch indices divided by the broadcast factors; `src1` uses the
// output batch indices directly.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
    let flat = global_id.x;
    if (flat >= total) {
        return;
    }

    // Element counts of one output matrix and of one dimension-3 batch.
    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;

    // Dimension 3.
    let dst3_idx = flat / dst3_stride;
    let dst3_rem = flat % dst3_stride;
    let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
    let src13_idx = dst3_idx;                     // src1 is not broadcast

    // Dimension 2.
    let dst2_idx = dst3_rem / dst2_stride;
    let dst2_rem = dst3_rem % dst2_stride;
    let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
    let src12_idx = dst2_idx;                     // src1 is not broadcast

    let row = dst2_rem / params.n; // output row
    let col = dst2_rem % params.n; // output column

    // Offsets of the first element of the two rows being multiplied; only the
    // position along k changes inside the loop.
    let src0_row = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
    let src1_row = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;

    var acc = 0.0;
    for (var kk: u32 = 0u; kk < params.k; kk = kk + 1u) {
        acc = acc + src0[src0_row + kk] * src1[src1_row + kk];
    }

    dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = acc;
}
|
||||||
82
ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
Normal file
82
ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
// The destination buffer below stores f16 values, which requires this extension.
enable f16;

// Source rows, read as f32.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Row indices, read by the entry point as pairs of consecutive 32-bit words.
@group(0) @binding(1) var<storage, read_write> idx: array<u32>;

// Destination rows, written as f16.
@group(0) @binding(2) var<storage, read_write> dst: array<f16>;

// Error flag; the entry point stores 1 here when it meets an index whose
// second 32-bit word is non-zero.
@group(0) @binding(3) var<storage, read_write> error: atomic<u32>;

// Uniform parameters describing the three tensors.
struct Params {
    offset_src: u32, // in elements
    offset_idx: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_idx0: u32,
    stride_idx1: u32,
    stride_idx2: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Shape of src
    ne0: u32,
    n_rows: u32,
    ne2: u32,
    ne3: u32,

    // Shape of idx
    idx1: u32,
    idx2: u32,
};

@group(0) @binding(4) var<uniform> params: Params;

// Pipeline-overridable workgroup size; declared without a default value.
override wg_size: u32;
|
||||||
|
// Copies one source row per invocation into the destination row selected by
// the matching entry of `idx`, converting f32 to f16.
//
// The flat invocation index is split into (i_src3, i_src2, i_src1) over the
// source shape. The index tensor is addressed with i_src2 / i_src3 wrapped by
// its own extents (`idx1`, `idx2`). Each index occupies two consecutive u32
// words (presumably a little-endian 64-bit integer - confirm against the host
// code): the first word is used as the destination row, and the second word
// must be zero. If it is not, the error flag is set and the row is skipped.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
        return;
    }
    var i = gid.x;
    let i_src3 = i / (params.ne2 * params.n_rows);

    i = i % (params.ne2 * params.n_rows);
    let i_src2 = i / params.n_rows;
    let i_src1 = i % params.n_rows;

    let i_idx2 = i_src3 % params.idx2;
    let i_idx1 = i_src2 % params.idx1;
    let i_idx0 = i_src1;

    // Position of the first of the two 32-bit words that make up this index.
    let idx_word = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;

    let idx_low_val = idx[idx_word];      // word used as the destination row
    let idx_high_val = idx[idx_word + 1]; // must be zero

    if (idx_high_val != 0) {
        // Upper bits of index are not zero, output will be incorrect
        atomicStore(&error, 1);
        return;
    }

    let i_dst_row = params.offset_dst + idx_low_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;

    // Elements within a row are addressed contiguously in both buffers.
    for (var col: u32 = 0; col < params.ne0; col++) {
        dst[i_dst_row + col] = f16(src[i_src_row + col]);
    }
}
|
||||||
Loading…
Reference in New Issue
Block a user