mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 22:08:46 +01:00
Add tokenizer test + revert to C++11 (#355)
* Add test-tokenizer-0 to do a few tokenizations - feel free to expand * Added option to convert-pth-to-ggml.py script to dump just the vocabulary * Added ./models/ggml-vocab.bin containing just LLaMA vocab data (used for tests) * Added utility to load vocabulary file from previous point (temporary implementation) * Avoid using std::string_view and drop back to C++11 (hope I didn't break something) * Rename gpt_vocab -> llama_vocab * All CMake binaries go into ./bin/ now
This commit is contained in:
parent
2e664f1ff4
commit
eb34620aec
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
@ -54,6 +54,7 @@ jobs:
|
|||||||
cd build
|
cd build
|
||||||
cmake ..
|
cmake ..
|
||||||
cmake --build . --config Release
|
cmake --build . --config Release
|
||||||
|
ctest --output-on-failure
|
||||||
|
|
||||||
macOS-latest-make:
|
macOS-latest-make:
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
@ -90,6 +91,7 @@ jobs:
|
|||||||
cd build
|
cd build
|
||||||
cmake ..
|
cmake ..
|
||||||
cmake --build . --config Release
|
cmake --build . --config Release
|
||||||
|
ctest --output-on-failure
|
||||||
|
|
||||||
windows-latest-cmake:
|
windows-latest-cmake:
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
@ -106,6 +108,7 @@ jobs:
|
|||||||
cd build
|
cd build
|
||||||
cmake ..
|
cmake ..
|
||||||
cmake --build . --config Release
|
cmake --build . --config Release
|
||||||
|
ctest --output-on-failure
|
||||||
|
|
||||||
- name: Get commit hash
|
- name: Get commit hash
|
||||||
id: commit
|
id: commit
|
||||||
|
@ -1,11 +1,37 @@
|
|||||||
cmake_minimum_required(VERSION 3.12)
|
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
|
||||||
project("llama.cpp" C CXX)
|
project("llama.cpp" C CXX)
|
||||||
|
|
||||||
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
||||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||||
|
|
||||||
|
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
|
set(LLAMA_STANDALONE ON)
|
||||||
|
|
||||||
|
# configure project version
|
||||||
|
# TODO
|
||||||
|
else()
|
||||||
|
set(LLAMA_STANDALONE OFF)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (EMSCRIPTEN)
|
||||||
|
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||||
|
|
||||||
|
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
|
||||||
|
else()
|
||||||
|
if (MINGW)
|
||||||
|
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||||
|
else()
|
||||||
|
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Option list
|
# Option list
|
||||||
#
|
#
|
||||||
@ -34,6 +60,9 @@ option(LLAMA_FMA "llama: enable FMA"
|
|||||||
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
|
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
|
||||||
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
|
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
|
||||||
|
|
||||||
|
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||||
|
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||||
|
|
||||||
#
|
#
|
||||||
# Compile flags
|
# Compile flags
|
||||||
#
|
#
|
||||||
@ -187,17 +216,19 @@ add_executable(llama main.cpp)
|
|||||||
|
|
||||||
add_executable(quantize quantize.cpp)
|
add_executable(quantize quantize.cpp)
|
||||||
|
|
||||||
add_library(ggml OBJECT
|
|
||||||
ggml.c
|
|
||||||
ggml.h)
|
|
||||||
|
|
||||||
add_library(utils OBJECT
|
add_library(utils OBJECT
|
||||||
utils.cpp
|
utils.cpp
|
||||||
utils.h)
|
utils.h)
|
||||||
|
|
||||||
|
target_include_directories(utils PUBLIC .)
|
||||||
|
target_compile_features(utils PUBLIC cxx_std_11) # don't bump
|
||||||
|
|
||||||
|
add_library(ggml OBJECT
|
||||||
|
ggml.c
|
||||||
|
ggml.h)
|
||||||
|
|
||||||
target_include_directories(ggml PUBLIC .)
|
target_include_directories(ggml PUBLIC .)
|
||||||
target_compile_features(ggml PUBLIC c_std_11)
|
target_compile_features(ggml PUBLIC c_std_11) # don't bump
|
||||||
target_compile_features(utils PUBLIC cxx_std_17)
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Linking
|
# Linking
|
||||||
@ -206,3 +237,16 @@ target_compile_features(utils PUBLIC cxx_std_17)
|
|||||||
target_link_libraries(ggml PRIVATE Threads::Threads ${LLAMA_EXTRA_LIBS})
|
target_link_libraries(ggml PRIVATE Threads::Threads ${LLAMA_EXTRA_LIBS})
|
||||||
target_link_libraries(llama PRIVATE ggml utils)
|
target_link_libraries(llama PRIVATE ggml utils)
|
||||||
target_link_libraries(quantize PRIVATE ggml utils)
|
target_link_libraries(quantize PRIVATE ggml utils)
|
||||||
|
|
||||||
|
#
|
||||||
|
# programs, examples and tests
|
||||||
|
#
|
||||||
|
|
||||||
|
if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||||
|
enable_testing()
|
||||||
|
add_subdirectory(tests)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
#if (LLAMA_BUILD_EXAMPLES)
|
||||||
|
# add_subdirectory(examples)
|
||||||
|
#endif()
|
||||||
|
3
Makefile
3
Makefile
@ -30,8 +30,9 @@ endif
|
|||||||
# Compile flags
|
# Compile flags
|
||||||
#
|
#
|
||||||
|
|
||||||
|
# keep standard at C11 and C++11
|
||||||
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++17 -fPIC
|
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||||
LDFLAGS =
|
LDFLAGS =
|
||||||
|
|
||||||
# OS specific
|
# OS specific
|
||||||
|
@ -10,12 +10,10 @@
|
|||||||
# - Name (char[name_length])
|
# - Name (char[name_length])
|
||||||
# - Data (float[n_dims])
|
# - Data (float[n_dims])
|
||||||
#
|
#
|
||||||
# By default, the bigger matrices are converted to 16-bit floats.
|
|
||||||
# This can be disabled by adding the "use-f32" CLI argument.
|
|
||||||
#
|
|
||||||
# At the start of the ggml file we write the model parameters
|
# At the start of the ggml file we write the model parameters
|
||||||
# and vocabulary.
|
# and vocabulary.
|
||||||
#
|
#
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -23,6 +21,7 @@ import json
|
|||||||
import struct
|
import struct
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -30,6 +29,7 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
|
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
|
||||||
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
||||||
parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
|
parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
|
||||||
|
parser.add_argument('vocab_only', type=bool, default=False, help='only write vocab to file')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
def get_n_parts(dim):
|
def get_n_parts(dim):
|
||||||
@ -134,6 +134,27 @@ def main():
|
|||||||
ftype_str = ["f32", "f16"]
|
ftype_str = ["f32", "f16"]
|
||||||
|
|
||||||
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
||||||
|
|
||||||
|
# if only writing vocab to file
|
||||||
|
if args.vocab_only:
|
||||||
|
|
||||||
|
fname_model = f"{dir_model}/consolidated.00.pth"
|
||||||
|
fname_out = f"{dir_model}/ggml-vocab.bin"
|
||||||
|
|
||||||
|
print(f"Extracting only the vocab from '{fname_model}'\n")
|
||||||
|
|
||||||
|
model = torch.load(fname_model, map_location="cpu")
|
||||||
|
|
||||||
|
with open(fname_out, "wb") as fout:
|
||||||
|
fout.write(struct.pack("i", hparams["vocab_size"]))
|
||||||
|
write_tokens(fout, tokenizer)
|
||||||
|
|
||||||
|
del model
|
||||||
|
|
||||||
|
print(f"Done. Output file: {fname_out}\n")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
n_parts = get_n_parts(hparams["dim"])
|
n_parts = get_n_parts(hparams["dim"])
|
||||||
|
|
||||||
for p in range(n_parts):
|
for p in range(n_parts):
|
||||||
@ -151,6 +172,7 @@ def main():
|
|||||||
process_and_write_variables(fout, model, ftype)
|
process_and_write_variables(fout, model, ftype)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
|
|
||||||
print(f"Done. Output file: {fname_out}, (part {p})\n")
|
print(f"Done. Output file: {fname_out}, (part {p})\n")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
26
main.cpp
26
main.cpp
@ -90,7 +90,7 @@ struct llama_model {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// load the model's weights from a file
|
// load the model's weights from a file
|
||||||
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
|
bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
|
||||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||||
|
|
||||||
std::vector<char> f_buf(1024*1024);
|
std::vector<char> f_buf(1024*1024);
|
||||||
@ -544,9 +544,9 @@ bool llama_eval(
|
|||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const int n_threads,
|
const int n_threads,
|
||||||
const int n_past,
|
const int n_past,
|
||||||
const std::vector<gpt_vocab::id> & embd_inp,
|
const std::vector<llama_vocab::id> & embd_inp,
|
||||||
std::vector<float> & embd_w,
|
std::vector<float> & embd_w,
|
||||||
size_t & mem_per_token) {
|
size_t & mem_per_token) {
|
||||||
const int N = embd_inp.size();
|
const int N = embd_inp.size();
|
||||||
|
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
@ -832,7 +832,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
int64_t t_load_us = 0;
|
int64_t t_load_us = 0;
|
||||||
|
|
||||||
gpt_vocab vocab;
|
llama_vocab vocab;
|
||||||
llama_model model;
|
llama_model model;
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
@ -864,13 +864,13 @@ int main(int argc, char ** argv) {
|
|||||||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||||
params.prompt.insert(0, 1, ' ');
|
params.prompt.insert(0, 1, ' ');
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
|
std::vector<llama_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
|
||||||
|
|
||||||
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
|
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
|
||||||
|
|
||||||
// prefix & suffix for instruct mode
|
// prefix & suffix for instruct mode
|
||||||
const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
|
const std::vector<llama_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
|
||||||
const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
|
const std::vector<llama_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
|
||||||
|
|
||||||
// in instruct mode, we inject a prefix and a suffix to each input by the user
|
// in instruct mode, we inject a prefix and a suffix to each input by the user
|
||||||
if (params.instruct) {
|
if (params.instruct) {
|
||||||
@ -879,7 +879,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tokenize the reverse prompt
|
// tokenize the reverse prompt
|
||||||
std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
|
std::vector<std::vector<llama_vocab::id>> antipromptv_inp;
|
||||||
|
|
||||||
for (auto antiprompt : params.antiprompt) {
|
for (auto antiprompt : params.antiprompt) {
|
||||||
antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
|
antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
|
||||||
@ -925,14 +925,14 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||||
fprintf(stderr, "\n\n");
|
fprintf(stderr, "\n\n");
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> embd;
|
std::vector<llama_vocab::id> embd;
|
||||||
|
|
||||||
// determine the required inference memory per token:
|
// determine the required inference memory per token:
|
||||||
size_t mem_per_token = 0;
|
size_t mem_per_token = 0;
|
||||||
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||||
|
|
||||||
int last_n_size = params.repeat_last_n;
|
int last_n_size = params.repeat_last_n;
|
||||||
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
|
std::vector<llama_vocab::id> last_n_tokens(last_n_size);
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
@ -980,7 +980,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const int n_vocab = model.hparams.n_vocab;
|
const int n_vocab = model.hparams.n_vocab;
|
||||||
|
|
||||||
gpt_vocab::id id = 0;
|
llama_vocab::id id = 0;
|
||||||
|
|
||||||
{
|
{
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
@ -1066,7 +1066,7 @@ int main(int argc, char ** argv) {
|
|||||||
} while (another_line);
|
} while (another_line);
|
||||||
if (params.use_color) printf(ANSI_COLOR_RESET);
|
if (params.use_color) printf(ANSI_COLOR_RESET);
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
|
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
|
||||||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||||
|
|
||||||
if (params.instruct) {
|
if (params.instruct) {
|
||||||
|
BIN
models/ggml-vocab.bin
Normal file
BIN
models/ggml-vocab.bin
Normal file
Binary file not shown.
@ -44,7 +44,7 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
gpt_vocab vocab;
|
llama_vocab vocab;
|
||||||
|
|
||||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||||
|
|
||||||
|
4
tests/CMakeLists.txt
Normal file
4
tests/CMakeLists.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
set(TEST_TARGET test-tokenizer-0)
|
||||||
|
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
|
||||||
|
target_link_libraries(${TEST_TARGET} PRIVATE utils)
|
||||||
|
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
|
69
tests/test-tokenizer-0.cpp
Normal file
69
tests/test-tokenizer-0.cpp
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
static const std::map<std::string, std::vector<llama_vocab::id>> k_tests = {
|
||||||
|
{ "Hello World", { 1, 10994, 2787, }, },
|
||||||
|
{ " Hello World", { 1, 15043, 2787, }, },
|
||||||
|
{ " Hello World!", { 1, 15043, 2787, 29991, }, },
|
||||||
|
{ " this is 🦙.cpp", { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
|
||||||
|
{ "w048 7tuijk dsdfhu", { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
|
||||||
|
{ "нещо на Български", { 1, 821, 4851, 665, 1386, 29713, 1305, }, },
|
||||||
|
};
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
if (argc < 2) {
|
||||||
|
fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string fname = argv[1];
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
|
||||||
|
|
||||||
|
llama_vocab vocab;
|
||||||
|
|
||||||
|
if (!llama_vocab_load(fname, vocab)) {
|
||||||
|
fprintf(stderr, "%s : failed to load vocab from: '%s'\n", __func__, fname.c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_vocab = vocab.id_to_token.size();
|
||||||
|
|
||||||
|
if (n_vocab != 32000) {
|
||||||
|
fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & test_kv : k_tests) {
|
||||||
|
const auto res = llama_tokenize(vocab, test_kv.first, true);
|
||||||
|
|
||||||
|
bool correct = res.size() == test_kv.second.size();
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||||
|
if (res[i] != test_kv.second[i]) {
|
||||||
|
correct = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!correct) {
|
||||||
|
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
||||||
|
fprintf(stderr, "%s : expected tokens: ", __func__);
|
||||||
|
for (const auto & t : test_kv.second) {
|
||||||
|
fprintf(stderr, "%6d, ", t);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s : got tokens: ", __func__);
|
||||||
|
for (const auto & t : res) {
|
||||||
|
fprintf(stderr, "%6d, ", t);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
174
utils.cpp
174
utils.cpp
@ -240,61 +240,6 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
|
||||||
std::vector<std::string> words;
|
|
||||||
|
|
||||||
// first split the text into words
|
|
||||||
{
|
|
||||||
std::string str = text;
|
|
||||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
|
||||||
|
|
||||||
std::regex re(pat);
|
|
||||||
std::smatch m;
|
|
||||||
|
|
||||||
while (std::regex_search(str, m, re)) {
|
|
||||||
for (auto x : m) {
|
|
||||||
words.push_back(x);
|
|
||||||
}
|
|
||||||
str = m.suffix();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// find the longest tokens that form the words:
|
|
||||||
std::vector<gpt_vocab::id> tokens;
|
|
||||||
for (const auto & word : words) {
|
|
||||||
if (word.size() == 0) continue;
|
|
||||||
|
|
||||||
int i = 0;
|
|
||||||
int n = word.size();
|
|
||||||
while (i < n) {
|
|
||||||
int j = n;
|
|
||||||
while (j > i) {
|
|
||||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
|
||||||
if (it != vocab.token_to_id.end()) {
|
|
||||||
tokens.push_back(it->second);
|
|
||||||
i = j;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
--j;
|
|
||||||
}
|
|
||||||
if (i == n) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (j == i) {
|
|
||||||
auto sub = word.substr(i, 1);
|
|
||||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
|
||||||
tokens.push_back(vocab.token_to_id.at(sub));
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
|
||||||
}
|
|
||||||
++i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t utf8_len(char src) {
|
static size_t utf8_len(char src) {
|
||||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||||
@ -305,7 +250,8 @@ struct llama_sp_symbol {
|
|||||||
using index = int;
|
using index = int;
|
||||||
index prev;
|
index prev;
|
||||||
index next;
|
index next;
|
||||||
std::string_view text;
|
const char * text;
|
||||||
|
size_t n;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sp_bigram {
|
struct llama_sp_bigram {
|
||||||
@ -322,19 +268,23 @@ struct llama_sp_bigram {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// original implementation:
|
||||||
|
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||||
struct llama_tokenizer {
|
struct llama_tokenizer {
|
||||||
llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {}
|
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
|
||||||
|
|
||||||
void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
// split string into utf8 chars
|
// split string into utf8 chars
|
||||||
int index = 0;
|
int index = 0;
|
||||||
while (!text.empty()) {
|
size_t offs = 0;
|
||||||
|
while (offs < text.size()) {
|
||||||
llama_sp_symbol sym;
|
llama_sp_symbol sym;
|
||||||
size_t char_len = std::min(text.size(), utf8_len(text.data()[0]));
|
size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
|
||||||
sym.text = std::string_view(text.data(), char_len);
|
sym.text = text.c_str() + offs;
|
||||||
|
sym.n = char_len;
|
||||||
|
offs += char_len;
|
||||||
sym.prev = index - 1;
|
sym.prev = index - 1;
|
||||||
text.remove_prefix(char_len);
|
sym.next = offs == text.size() ? -1 : index + 1;
|
||||||
sym.next = text.empty() ? -1 : index + 1;
|
|
||||||
index++;
|
index++;
|
||||||
symbols_.emplace_back(std::move(sym));
|
symbols_.emplace_back(std::move(sym));
|
||||||
}
|
}
|
||||||
@ -353,14 +303,16 @@ struct llama_tokenizer {
|
|||||||
auto & right_sym = symbols_[bigram.right];
|
auto & right_sym = symbols_[bigram.right];
|
||||||
|
|
||||||
// if one of the symbols already got merged, skip it.
|
// if one of the symbols already got merged, skip it.
|
||||||
if (left_sym.text.empty() || right_sym.text.empty() ||
|
if (left_sym.n == 0 || right_sym.n == 0 ||
|
||||||
left_sym.text.size() + right_sym.text.size() != bigram.size) {
|
left_sym.n + right_sym.n != bigram.size) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// merge the right sym into the left one
|
// merge the right sym into the left one
|
||||||
left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size());
|
left_sym.n += right_sym.n;
|
||||||
right_sym.text = std::string_view("");
|
right_sym.n = 0;
|
||||||
|
|
||||||
|
//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
|
||||||
|
|
||||||
// remove the right sym from the chain
|
// remove the right sym from the chain
|
||||||
left_sym.next = right_sym.next;
|
left_sym.next = right_sym.next;
|
||||||
@ -374,13 +326,13 @@ struct llama_tokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i != -1; i = symbols_[i].next) {
|
for (int i = 0; i != -1; i = symbols_[i].next) {
|
||||||
auto& symbol = symbols_[i];
|
auto & symbol = symbols_[i];
|
||||||
auto token = vocab_.token_to_id.find(std::string(symbol.text));
|
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
|
||||||
|
|
||||||
if (token == vocab_.token_to_id.end()) {
|
if (token == vocab_.token_to_id.end()) {
|
||||||
// output any symbols that did not form tokens as bytes.
|
// output any symbols that did not form tokens as bytes.
|
||||||
for (int j = 0; j < symbol.text.size(); ++j) {
|
for (int j = 0; j < (int) symbol.n; ++j) {
|
||||||
gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
||||||
output.push_back(token_id);
|
output.push_back(token_id);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -395,8 +347,8 @@ private:
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size());
|
const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
|
||||||
auto token = vocab_.token_to_id.find(std::string(text));
|
auto token = vocab_.token_to_id.find(text);
|
||||||
|
|
||||||
if (token == vocab_.token_to_id.end()) {
|
if (token == vocab_.token_to_id.end()) {
|
||||||
return;
|
return;
|
||||||
@ -416,14 +368,52 @@ private:
|
|||||||
work_queue_.push(bigram);
|
work_queue_.push(bigram);
|
||||||
}
|
}
|
||||||
|
|
||||||
const gpt_vocab & vocab_;
|
const llama_vocab & vocab_;
|
||||||
std::vector<llama_sp_symbol> symbols_;
|
std::vector<llama_sp_symbol> symbols_;
|
||||||
llama_sp_bigram::queue work_queue_;
|
llama_sp_bigram::queue work_queue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) {
|
// TODO: temporary code duplication with llama.cpp
|
||||||
|
// will resolve after #77 is merged
|
||||||
|
bool llama_vocab_load(const std::string & fname, llama_vocab & vocab) {
|
||||||
|
std::ifstream fin(fname, std::ios::binary);
|
||||||
|
if (!fin.is_open()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_vocab = 0;
|
||||||
|
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
||||||
|
|
||||||
|
std::string word;
|
||||||
|
std::vector<char> tmp(64);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_vocab; i++) {
|
||||||
|
uint32_t len;
|
||||||
|
fin.read((char *) &len, sizeof(len));
|
||||||
|
|
||||||
|
word.resize(len);
|
||||||
|
if (len > 0) {
|
||||||
|
tmp.resize(len);
|
||||||
|
fin.read(tmp.data(), len);
|
||||||
|
word.assign(tmp.data(), len);
|
||||||
|
} else {
|
||||||
|
word.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
float score;
|
||||||
|
fin.read((char *) &score, sizeof(score));
|
||||||
|
|
||||||
|
vocab.token_to_id[word] = i;
|
||||||
|
vocab.id_to_token[i] = word;
|
||||||
|
vocab.score[i] = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
||||||
llama_tokenizer tokenizer(vocab);
|
llama_tokenizer tokenizer(vocab);
|
||||||
std::vector<gpt_vocab::id> output;
|
std::vector<llama_vocab::id> output;
|
||||||
|
|
||||||
if (text.size() == 0) {
|
if (text.size() == 0) {
|
||||||
return output;
|
return output;
|
||||||
@ -437,42 +427,22 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_v
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
|
||||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
|
||||||
|
|
||||||
vocab.token_to_id = ::json_parse(fname);
|
|
||||||
|
|
||||||
for (const auto & kv : vocab.token_to_id) {
|
|
||||||
vocab.id_to_token[kv.second] = kv.first;
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
|
||||||
|
|
||||||
// print the vocabulary
|
|
||||||
//for (auto kv : vocab.token_to_id) {
|
|
||||||
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
|
||||||
//}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
|
|
||||||
// find the top K tokens
|
// find the top K tokens
|
||||||
std::partial_sort(
|
std::partial_sort(
|
||||||
logits_id.begin(),
|
logits_id.begin(),
|
||||||
logits_id.begin() + top_k, logits_id.end(),
|
logits_id.begin() + top_k, logits_id.end(),
|
||||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
[](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
|
||||||
return a.first > b.first;
|
return a.first > b.first;
|
||||||
});
|
});
|
||||||
|
|
||||||
logits_id.resize(top_k);
|
logits_id.resize(top_k);
|
||||||
}
|
}
|
||||||
|
|
||||||
gpt_vocab::id llama_sample_top_p_top_k(
|
llama_vocab::id llama_sample_top_p_top_k(
|
||||||
const gpt_vocab & vocab,
|
const llama_vocab & vocab,
|
||||||
const float * logits,
|
const float * logits,
|
||||||
std::vector<gpt_vocab::id> & last_n_tokens,
|
std::vector<llama_vocab::id> & last_n_tokens,
|
||||||
double repeat_penalty,
|
double repeat_penalty,
|
||||||
int top_k,
|
int top_k,
|
||||||
double top_p,
|
double top_p,
|
||||||
@ -480,7 +450,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
|
|||||||
std::mt19937 & rng) {
|
std::mt19937 & rng) {
|
||||||
int n_logits = vocab.id_to_token.size();
|
int n_logits = vocab.id_to_token.size();
|
||||||
|
|
||||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
std::vector<std::pair<double, llama_vocab::id>> logits_id;
|
||||||
logits_id.reserve(n_logits);
|
logits_id.reserve(n_logits);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
28
utils.h
28
utils.h
@ -60,7 +60,7 @@ std::string gpt_random_prompt(std::mt19937 & rng);
|
|||||||
// Vocab utils
|
// Vocab utils
|
||||||
//
|
//
|
||||||
|
|
||||||
struct gpt_vocab {
|
struct llama_vocab {
|
||||||
using id = int32_t;
|
using id = int32_t;
|
||||||
using token = std::string;
|
using token = std::string;
|
||||||
|
|
||||||
@ -74,34 +74,22 @@ void replace(std::string & str, const std::string & needle, const std::string &
|
|||||||
// poor-man's JSON parsing
|
// poor-man's JSON parsing
|
||||||
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||||
|
|
||||||
// split text into tokens
|
// TODO: temporary until #77 is merged, need this now for some tokenizer tests
|
||||||
//
|
bool llama_vocab_load(const std::string & fname, llama_vocab & vocab);
|
||||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
|
||||||
//
|
|
||||||
// Regex (Python):
|
|
||||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
|
||||||
//
|
|
||||||
// Regex (C++):
|
|
||||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
|
||||||
//
|
|
||||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
|
||||||
|
|
||||||
// TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
|
// TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
|
||||||
// ref: https://github.com/google/sentencepiece
|
// ref: https://github.com/google/sentencepiece
|
||||||
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos);
|
std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos);
|
||||||
|
|
||||||
// load the tokens from encoder.json
|
|
||||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
|
||||||
|
|
||||||
// sample next token given probabilities for each embedding
|
// sample next token given probabilities for each embedding
|
||||||
//
|
//
|
||||||
// - consider only the top K tokens
|
// - consider only the top K tokens
|
||||||
// - from them, consider only the top tokens with cumulative probability > P
|
// - from them, consider only the top tokens with cumulative probability > P
|
||||||
//
|
//
|
||||||
gpt_vocab::id llama_sample_top_p_top_k(
|
llama_vocab::id llama_sample_top_p_top_k(
|
||||||
const gpt_vocab & vocab,
|
const llama_vocab & vocab,
|
||||||
const float * logits,
|
const float * logits,
|
||||||
std::vector<gpt_vocab::id> & last_n_tokens,
|
std::vector<llama_vocab::id> & last_n_tokens,
|
||||||
double repeat_penalty,
|
double repeat_penalty,
|
||||||
int top_k,
|
int top_k,
|
||||||
double top_p,
|
double top_p,
|
||||||
@ -109,7 +97,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
|
|||||||
std::mt19937 & rng);
|
std::mt19937 & rng);
|
||||||
|
|
||||||
// filer to top K tokens from list of logits
|
// filer to top K tokens from list of logits
|
||||||
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
|
void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Quantization
|
// Quantization
|
||||||
|
Loading…
Reference in New Issue
Block a user