Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metal gpu matrix3D addition test #87

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
4b38bb7
metal gpu matrix3D addition test
DerrickYLJ Jan 15, 2024
cb5419a
metal gpu matrix3D addition test
DerrickYLJ Jan 15, 2024
d23203c
add element-wise add test in swift
DerrickYLJ Jan 16, 2024
3005bb3
metal c++ version inside cpp folder
DerrickYLJ Jan 24, 2024
7cad520
minor fix
RaymondWang0 Jan 25, 2024
69c929d
speedup issue with metal
DerrickYLJ Jan 25, 2024
11fa7b6
minor fix
RaymondWang0 Jan 26, 2024
37ff487
reorganized clear version for metal main
DerrickYLJ Jan 27, 2024
c3e3316
matmul metal
DerrickYLJ Jan 27, 2024
d55be07
matmul correctness pass
DerrickYLJ Jan 27, 2024
ab446e7
matmul work
DerrickYLJ Jan 27, 2024
69407d5
header param
DerrickYLJ Jan 28, 2024
dd9fda7
metal matmul Int4 working
DerrickYLJ Jan 29, 2024
e8331c3
kernel minor change
DerrickYLJ Jan 29, 2024
e79e702
add metal op
RaymondWang0 Jan 30, 2024
bc5cad2
interleave two versions equal
DerrickYLJ Jan 31, 2024
3a3e97d
add matmulInt4_SIMD with different versions
DerrickYLJ Jan 31, 2024
17e90b6
building metal interface, done with kernel ops
DerrickYLJ Feb 1, 2024
a3f8852
update metal-cpp
RaymondWang0 Feb 1, 2024
246744b
everything except for TODO
DerrickYLJ Feb 1, 2024
1b5d027
rms_norm done
DerrickYLJ Feb 1, 2024
c5fb54f
done with softmax, rope and minor fix on opparam struct
DerrickYLJ Feb 2, 2024
189ed98
update for test
DerrickYLJ Feb 2, 2024
3217635
updated matmulf32 + test script
DerrickYLJ Feb 4, 2024
06c472a
new test script on normalization
DerrickYLJ Feb 4, 2024
0c38582
more metal ops (TODO needed)
DerrickYLJ Feb 20, 2024
962266e
All basic operations and matmul are included
DerrickYLJ Feb 26, 2024
cd416a3
update llama matmul parameter and test
DerrickYLJ Mar 11, 2024
67ea0cc
new format for metal in general
DerrickYLJ Mar 22, 2024
2b71f9f
fix matmul
RaymondWang0 Mar 22, 2024
f792a1b
add new kernels and reorganize
DerrickYLJ Mar 26, 2024
2440e9e
Merge branch 'MetalGPU' of https://github.com/mit-han-lab/TinyChatEng…
DerrickYLJ Mar 26, 2024
0b094ad
fix parameters
RaymondWang0 Apr 6, 2024
926dd0d
draft llama.cpp model
DerrickYLJ Apr 16, 2024
3ac955c
update rope
DerrickYLJ Apr 26, 2024
46d9627
update makefile
RaymondWang0 May 1, 2024
0d4cd66
reorganize metal
DerrickYLJ May 12, 2024
562d32a
minor fix
DerrickYLJ May 16, 2024
a269001
add op source
DerrickYLJ May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion kernels/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ typedef half_float::half naive_float16_t;
#include <cuda_fp16.h>
#include <cuda_runtime.h>
typedef half float16_t;
#elif defined(QM_METAL)
typedef half_float::half float16_t;
typedef float16_t half;
#elif defined(__ARM_NEON)
typedef __fp16 float16_t;
#elif defined(__x86_64__)
Expand Down Expand Up @@ -99,6 +102,10 @@ struct thread_args {
int start_i, end_i, blk_size;
};

#ifdef QM_METAL
#include "metal/include/metal_compute.h"
// typedef half_float::half half;
#endif

// Function-like max/min helpers.
// Each argument appears twice in the expansion, so an argument with side
// effects (e.g. MAX(i++, j)) is evaluated more than once; pass plain values.
#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define MIN(A, B) ((A) < (B) ? (A) : (B))
Expand All @@ -121,7 +128,7 @@ class MatmulOperator {
// void mat_mul_accelerator_int8_fast_2x2_omp(const struct matmul_params *params);
// int4
void mat_mul_accelerator_int4_fast(const struct matmul_params *params);
void mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params);
void mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params); //also supported by metal
void mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params);
void naive_mat_mul_int4(const struct matmul_params *params);
void naive_mat_mul_int4_with_offset(const struct matmul_params *params);
Expand All @@ -136,6 +143,26 @@ class MatmulOperator {
void gemm_forward_cuda_half_test(const struct matmul_params *params, int split_k_iters);
//// GEMV
void gemv_forward_cuda(const struct matmul_params *params);
// Metal backend entry points. Each takes the same matmul_params descriptor
// as the other backends declared in this class.
// NOTE(review): judging by the names, "mat_mul" vs "mat_vec" selects a
// matrix-matrix vs matrix-vector kernel and "int4_f32" vs "f32_f32" the
// operand precisions — confirm against the implementations.
void mat_mul_int4_f32_metal(const struct matmul_params *params);
void mat_mul_f32_f32_metal(const struct matmul_params *params);
void mat_vec_int4_f32_metal(const struct matmul_params *params);
void mat_vec_f32_f32_metal(const struct matmul_params *params);
// void batch_add_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z);
// void mat_mul_f32_metal(const struct matmul_params *params);
// void batch_add_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z);
// void relu_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z);
// void silu_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z);
// void gelu_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z);
// void gelu_quick_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z);
// void rms_norm_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z, float eps);
// void soft_max_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z, int64_t scale);
// void soft_max_4_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z, int64_t scale);
// void rope_metal(const struct matmul_params *params, unsigned int m_dim_x, unsigned int m_dim_y, unsigned int m_dim_z,
// int n_past, int n_dims, int mode, int n_orig_ctx, float freq_base, float freq_scale, float ext_factor, float attn_factor,
// float beta_fast, float beta_slow);



private:
float interval_to_us(struct timeval *start, struct timeval *end);
Expand Down
125 changes: 125 additions & 0 deletions kernels/metal/include/metal_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#ifndef METAL_COMPUTE_H
#define METAL_COMPUTE_H

#include "../../matmul.h"
#include "operators.h"
#include "Foundation/Foundation.hpp"
#include "Metal/Metal.hpp"

// Global initialisation flag for the Metal backend; starts out false.
// NOTE(review): presumably flipped once init() has run — confirm in the
// implementation file.
//
// Declared `inline` (C++17) because this header is included from matmul.h and
// therefore from many translation units: a plain definition here would emit a
// separate `has_init` symbol in each of them and fail at link time.
inline bool has_init = false;

/// Holds the compute pipeline state object associated with one kernel slot.
struct metal_kernel {
    MTL::ComputePipelineState* pipeline;
};

// Forward declarations; the complete types are defined further down in this
// header.
struct metal_context;
struct metal_cgraph;

// Global handles to the Metal context and to the compute graph.
// They start out null and must be assigned before being dereferenced.
//
// Declared `inline` (C++17) so that every translation unit including this
// header refers to the same two objects; plain definitions in a header would
// produce duplicate symbols at link time.
inline metal_context * ctx = nullptr;
inline metal_cgraph * mgraph = nullptr;

// Numeric GPU-family identifiers, listed in ascending order of value.
// NOTE(review): the values look like they mirror the raw values of Apple's
// MTLGPUFamily enumeration — confirm before adding new entries.
enum {
    MTLGPUFamilyApple1  = 1001,
    MTLGPUFamilyApple7  = 1007,
    MTLGPUFamilyCommon1 = 3001,
    MTLGPUFamilyMetal3  = 5001,
};

// Identifier for each Metal kernel known to this backend.
// Enumerators are numbered consecutively from zero; the final entry is a
// sentinel equal to the number of real kernels and is used to size
// metal_context::kernels, so new kernels must be inserted before it.
enum metal_kernel_type {
    METAL_KERNEL_FLOAT2HALF = 0,
    METAL_KERNEL_HALF2FLOAT,
    METAL_KERNEL_PREPARE_DECODER_ATTENTION_MASK_HALF,
    METAL_KERNEL_SILUMUL_HALF,
    METAL_KERNEL_ADD_HALF,
    METAL_KERNEL_SHAPE_QKV,
    METAL_KERNEL_UNSHAPE,
    METAL_KERNEL_TRANSPOSE_1_2IDX,
    METAL_KERNEL_CHECK_INF_HALF,
    METAL_KERNEL_EMBEDDING,
    METAL_KERNEL_BATCH_ADD,
    METAL_KERNEL_RELU,
    METAL_KERNEL_SILU,
    METAL_KERNEL_GELU,
    METAL_KERNEL_GELU_QUICK,
    METAL_KERNEL_RMS_NORM,
    METAL_KERNEL_SOFT_MAX,
    METAL_KERNEL_SOFT_MAX_4,
    METAL_KERNEL_ROPE,
    METAL_KERNEL_MUL_MM_INT4_F32,
    METAL_KERNEL_MUL_MV_INT4_F32,
    METAL_KERNEL_MUL_MM_F32_F32,
    METAL_KERNEL_MUL_MV_F32_F32,

    METAL_KERNEL_TYPE_COUNT  // sentinel: keep last
};

/// Result code of a Metal operation; this is the return type of
/// metal_graph_compute().
enum status {
    STATUS_SUCCESS = 0,
    STATUS_FAILED  = 1
};

// Context struct holding Metal related objects and state
struct metal_context {
int n_cb;
MTL::Device * device;
MTL::CommandQueue * queue;
static std::unordered_map<void *, MTL::Buffer *> _mumap;
metal_kernel kernels[METAL_KERNEL_TYPE_COUNT];
bool support_simdgroup_reduction;
bool support_simdgroup_mm;
bool should_capture_next_compute;
// dispatch_queue_t d_queue;
};

/// Scalar constants consumed by individual kernels.
struct metal_constants {
    /// rms_norm
    float eps;
    /// softmax
    float scale;
    /// embed
    int embed_dim;
};

// Description of one operation handed to the Metal backend: operand matrices,
// quantisation metadata, the kernel to run (`op`) and that kernel's constants.
// Pointers to these records are what add_node() accepts and what
// metal_cgraph::mm_nodes stores.
struct metal_params {
    struct matrix A, B, C, D, bias;
    struct optimization_params opt_params;
    float alpha, beta;
    float16_t half_alpha;
    // batch_size
    int bs;
    // for int4
    float *scales, *offset, *zero_point;
    float16_t *half_scales;
    naive_float16_t *fp16_scales;
    int *int32_zero_point;
    int block_size;
    // for int8 activation
    float *A_scales;
    int8_t A_zero_point;
    // op
    metal_kernel_type op;
    // consts
    float eps;      // rms_norm
    float scale;    // softmax
    int embed_dim;  // embed
    // rope (integer parameters)
    int n_orig_ctx;
    int n_past;
    int n_dims;
    int mode;
    // rope (floating-point parameters)
    // These were previously declared `int`, which silently truncated
    // fractional values; the RoPE interface sketched in matmul.h takes all
    // six as `float`, so they are stored as `float` here.
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;
    // NOTE(review): shape / dispatch parameters; their exact meaning is not
    // documented in this header — confirm against the kernels that read them.
    int sqlen, past_sqlen, num_heads, head_dim, input_m_dim_z, tgz;
};

/// Array-backed list of operation descriptors.
struct metal_cgraph {
    /// NOTE(review): presumably the number of slots allocated in mm_nodes — confirm.
    int capacity;
    /// NOTE(review): presumably the number of slots currently in use — confirm.
    int n_nodes;
    /// matmul ops (A, B, C)
    const struct metal_params** mm_nodes;
};

// Returns a pointer to a block of `size` bytes.
// NOTE(review): the name suggests CPU/GPU shared storage backed by a Metal
// buffer — confirm in the implementation, including how it is released.
void *allocateSharedMem(size_t size);
// Sets up the Metal backend. NOTE(review): presumably populates the global
// context and sets has_init — confirm.
void init();
// Releases the given context.
// NOTE(review): `static` on a declaration in a header gives every including
// translation unit its own internal-linkage declaration; unless each of them
// also defines it, this only works for the one file that does. Consider moving
// this declaration into the implementation file.
static void metal_free(struct metal_context * ctx);
// Runs the operations recorded in `metal_data` and reports the outcome as a
// `status` code.
enum status metal_graph_compute(
struct metal_cgraph * metal_data);
// Registers `new_node` for later execution. NOTE(review): presumably appends
// to the global graph's mm_nodes; the node is taken by pointer, so confirm how
// long it must stay alive.
void add_node(const struct metal_params * new_node);

#endif
8 changes: 0 additions & 8 deletions kernels/metal/include/opParams.h

This file was deleted.

Loading