mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
finetune : add -ngl parameter (#3762)
* Add '-ngl' support to finetune.cpp * Add fprintf in ggml_cuda_op_add When I tried CUDA offloading during finetuning following the readme, I got an assert here. This probably isn't an important case because inference later gives a warning saying you should use f16 or f32 instead when using lora * Add 'finetune.sh', which currently fails when using GPU "error: operator (): Finetuning on tensors with type 'f16' is not yet supported" * tweak finetune.sh * Suppress some warnings in ggml.c * Add f16 implementation to ggml_compute_forward_add_f16_f32 * Add an f16 case to ggml_add_cast_impl and llama_build_lora_finetune_graphs * finetune.sh: Edit comments * Add "add_f16_f32_f32_cuda" * Tweak an error message * finetune.sh: Add an optional LLAMA_MODEL_DIR variable * finetune.sh: Add an optional LLAMA_TRAINING_DIR variable * train : minor * tabs to spaces --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
This commit is contained in:
parent
f0e209324a
commit
73bdcb395e
@ -1045,6 +1045,7 @@ struct train_params_common get_default_train_params_common() {
|
|||||||
params.n_batch = 8;
|
params.n_batch = 8;
|
||||||
params.n_gradient_accumulation = 1;
|
params.n_gradient_accumulation = 1;
|
||||||
params.n_epochs = -1;
|
params.n_epochs = -1;
|
||||||
|
params.n_gpu_layers = 0;
|
||||||
|
|
||||||
params.custom_n_ctx = false;
|
params.custom_n_ctx = false;
|
||||||
|
|
||||||
@ -1080,6 +1081,7 @@ struct train_params_common get_default_train_params_common() {
|
|||||||
params.adam_beta2 = 0.999f;
|
params.adam_beta2 = 0.999f;
|
||||||
params.adam_gclip = 1.0f;
|
params.adam_gclip = 1.0f;
|
||||||
params.adam_eps_f = 0.0f;
|
params.adam_eps_f = 0.0f;
|
||||||
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ struct train_params_common {
|
|||||||
int n_batch;
|
int n_batch;
|
||||||
int n_gradient_accumulation;
|
int n_gradient_accumulation;
|
||||||
int n_epochs;
|
int n_epochs;
|
||||||
|
int n_gpu_layers;
|
||||||
|
|
||||||
bool custom_n_ctx;
|
bool custom_n_ctx;
|
||||||
|
|
||||||
|
@ -652,7 +652,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||||||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||||
if (ggml_is_quantized(a->type)) {
|
if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
|
||||||
return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
|
return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
|
||||||
} else if (a->type == GGML_TYPE_F32) {
|
} else if (a->type == GGML_TYPE_F32) {
|
||||||
return ggml_add(ctx, a, b);
|
return ggml_add(ctx, a, b);
|
||||||
@ -1459,6 +1459,17 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
|
|||||||
}
|
}
|
||||||
params->n_rank_w3 = std::stoi(argv[i]);
|
params->n_rank_w3 = std::stoi(argv[i]);
|
||||||
params->custom_n_rank_w3 = true;
|
params->custom_n_rank_w3 = true;
|
||||||
|
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||||
|
params->common.n_gpu_layers = std::stoi(argv[i]);
|
||||||
|
#else
|
||||||
|
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
|
||||||
|
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
train_print_usage(argc, argv, &default_params);
|
train_print_usage(argc, argv, &default_params);
|
||||||
@ -1545,6 +1556,7 @@ int main(int argc, char ** argv) {
|
|||||||
srand(params.common.seed);
|
srand(params.common.seed);
|
||||||
|
|
||||||
struct llama_model_params llama_mparams = llama_model_default_params();
|
struct llama_model_params llama_mparams = llama_model_default_params();
|
||||||
|
llama_mparams.n_gpu_layers = params.common.n_gpu_layers;
|
||||||
llama_mparams.vocab_only = false;
|
llama_mparams.vocab_only = false;
|
||||||
|
|
||||||
printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
|
printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
|
||||||
|
34
examples/finetune/finetune.sh
Normal file
34
examples/finetune/finetune.sh
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
cd `dirname $0`
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
EXE="./finetune"
|
||||||
|
|
||||||
|
if [[ ! $LLAMA_MODEL_DIR ]]; then LLAMA_MODEL_DIR="./models"; fi
|
||||||
|
if [[ ! $LLAMA_TRAINING_DIR ]]; then LLAMA_TRAINING_DIR="."; fi
|
||||||
|
|
||||||
|
# MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2-q8_0.gguf" # This is the model the readme uses.
|
||||||
|
MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2.gguf" # An f16 model. Note in this case with "-g", you get an f32-format .BIN file that isn't yet supported if you use it with "main --lora" with GPU inferencing.
|
||||||
|
|
||||||
|
while getopts "dg" opt; do
|
||||||
|
case $opt in
|
||||||
|
d)
|
||||||
|
DEBUGGER="gdb --args"
|
||||||
|
;;
|
||||||
|
g)
|
||||||
|
EXE="./build/bin/Release/finetune"
|
||||||
|
GPUARG="--gpu-layers 25"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
$DEBUGGER $EXE \
|
||||||
|
--model-base $MODEL \
|
||||||
|
$GPUARG \
|
||||||
|
--checkpoint-in chk-ol3b-shakespeare-LATEST.gguf \
|
||||||
|
--checkpoint-out chk-ol3b-shakespeare-ITERATION.gguf \
|
||||||
|
--lora-out lora-ol3b-shakespeare-ITERATION.bin \
|
||||||
|
--train-data "$LLAMA_TRAINING_DIR\shakespeare.txt" \
|
||||||
|
--save-every 10 \
|
||||||
|
--threads 10 --adam-iter 30 --batch 4 --ctx 64 \
|
||||||
|
--use-checkpointing
|
17
ggml-cuda.cu
17
ggml-cuda.cu
@ -513,6 +513,15 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d
|
|||||||
dst[i] = __hadd(x[i], __float2half(y[i]));
|
dst[i] = __hadd(x[i], __float2half(y[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dst[i] = __half2float(x[i]) + y[i];
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
@ -4693,6 +4702,11 @@ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, co
|
|||||||
add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
|
||||||
|
add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||||
|
}
|
||||||
|
|
||||||
static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
|
static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
|
||||||
const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
|
const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
|
||||||
mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
|
mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
|
||||||
@ -5996,7 +6010,10 @@ inline void ggml_cuda_op_add(
|
|||||||
add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
|
add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||||
add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
|
add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||||
|
add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||||
} else {
|
} else {
|
||||||
|
fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -716,6 +716,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
|
|||||||
__riscv_vse8_v_i8m1(y[i].qs , vs, vl);
|
__riscv_vse8_v_i8m1(y[i].qs , vs, vl);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
UNUSED(nb);
|
||||||
// scalar
|
// scalar
|
||||||
quantize_row_q8_0_reference(x, y, k);
|
quantize_row_q8_0_reference(x, y, k);
|
||||||
#endif
|
#endif
|
||||||
@ -969,6 +970,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
|
|||||||
y[i].s = sum*d;
|
y[i].s = sum*d;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
UNUSED(nb);
|
||||||
// scalar
|
// scalar
|
||||||
quantize_row_q8_1_reference(x, y, k);
|
quantize_row_q8_1_reference(x, y, k);
|
||||||
#endif
|
#endif
|
||||||
|
27
ggml.c
27
ggml.c
@ -3153,7 +3153,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
|
|||||||
// TODO: support less-strict constraint
|
// TODO: support less-strict constraint
|
||||||
// GGML_ASSERT(ggml_can_repeat(b, a));
|
// GGML_ASSERT(ggml_can_repeat(b, a));
|
||||||
GGML_ASSERT(ggml_can_repeat_rows(b, a));
|
GGML_ASSERT(ggml_can_repeat_rows(b, a));
|
||||||
GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
|
GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
@ -6927,9 +6927,15 @@ static void ggml_compute_forward_add_f16_f32(
|
|||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
|
||||||
|
|
||||||
|
if (dst->type == GGML_TYPE_F32) {
|
||||||
|
GGML_ASSERT( nb0 == sizeof(float));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
@ -6940,6 +6946,7 @@ static void ggml_compute_forward_add_f16_f32(
|
|||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
if (nb10 == sizeof(float)) {
|
if (nb10 == sizeof(float)) {
|
||||||
|
if (dst->type == GGML_TYPE_F16) {
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// src0, src1 and dst are same shape => same indices
|
// src0, src1 and dst are same shape => same indices
|
||||||
const int i3 = ir/(ne2*ne1);
|
const int i3 = ir/(ne2*ne1);
|
||||||
@ -6954,6 +6961,22 @@ static void ggml_compute_forward_add_f16_f32(
|
|||||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
|
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
|
// src0, src1 and dst are same shape => same indices
|
||||||
|
const int i3 = ir/(ne2*ne1);
|
||||||
|
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||||
|
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||||
|
|
||||||
|
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||||
|
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
||||||
|
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
|
||||||
|
|
||||||
|
for (int i = 0; i < ne0; i++) {
|
||||||
|
dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// src1 is not contiguous
|
// src1 is not contiguous
|
||||||
|
@ -8003,7 +8003,7 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
|
if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
|
||||||
if (dest_t->type != GGML_TYPE_F16) {
|
if (dest_t->type != GGML_TYPE_F16) {
|
||||||
throw std::runtime_error(format(
|
throw std::runtime_error(format(
|
||||||
"%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
|
"%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models. dest_t->type: %d", __func__, dest_t->type));
|
||||||
}
|
}
|
||||||
offload_func = ggml_cuda_assign_buffers;
|
offload_func = ggml_cuda_assign_buffers;
|
||||||
offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
|
offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
|
||||||
|
Loading…
Reference in New Issue
Block a user