mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
convert : support loading vocab from fast tokenizer config (#3633)
* Add HFVocab into convert.py * Update convert.py * Update convert.py * add bytes_to_unicode function * change add_meta_vocab fucntion * remove debug code * remove byte_encoder * Add newline between classes * Check tokenizer.json when tokenizer.model is not exist. * Move transformers dependency to local code * Add error context with 'raise from' * Add fast tokenizer option to BpeVocab * Update convert.py * Add VocabLoader and remove *Vocab class * Add transformers dependency * remove added tokens and check newline token to decide spm or bpe * Update convert.py * Add special token type * Update convert.py * Update convert.py * Update convert.py * Fix typo in convert.py * Fix when params.n_vocab < tokenizer vocab size * update vocab class * change funtion name * Remove unused variable/functions, add types to class variable and methods, delete blank liens * fix flake8 warnings * code style cleanup * make mypy happy * change exception --------- Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
0353a18401
commit
873637afc7
315
convert.py
315
convert.py
@ -10,6 +10,7 @@ import itertools
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import mmap
|
import mmap
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import signal
|
import signal
|
||||||
@ -18,15 +19,15 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import zipfile
|
import zipfile
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from collections import OrderedDict
|
||||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
|
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
import os
|
|
||||||
if 'NO_LOCAL_GGUF' not in os.environ:
|
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
||||||
import gguf
|
import gguf
|
||||||
@ -327,127 +328,138 @@ class Params:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
#
|
class VocabLoader:
|
||||||
# vocab
|
def __init__(self, params: Params, fname_tokenizer: Path) -> None:
|
||||||
#
|
try:
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"To use VocabLoader, please install the `transformers` package. "
|
||||||
|
"You can install it with `pip install transformers`."
|
||||||
|
) from e
|
||||||
|
|
||||||
class BpeVocab:
|
try:
|
||||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
|
self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
|
||||||
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
|
except ValueError:
|
||||||
added_tokens: dict[str, int]
|
self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
|
||||||
if fname_added_tokens is not None:
|
|
||||||
# FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
|
self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
|
||||||
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
|
|
||||||
|
for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]):
|
||||||
|
if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.added_tokens_dict[tok] = tokidx
|
||||||
|
|
||||||
|
self.unk_token_id: int = self.tokenizer.unk_token_id
|
||||||
|
self.specials: dict[str, int] = {
|
||||||
|
tok: self.tokenizer.get_vocab()[tok]
|
||||||
|
for tok in self.tokenizer.all_special_tokens
|
||||||
|
}
|
||||||
|
self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
|
||||||
|
self.vocab_size_base: int = self.tokenizer.vocab_size
|
||||||
|
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
|
||||||
|
self.fname_tokenizer: Path = fname_tokenizer
|
||||||
|
|
||||||
|
vocab_file = "tokenizer.model"
|
||||||
|
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
|
||||||
|
if path_candidate is not None:
|
||||||
|
self.spm = SentencePieceProcessor(str(path_candidate))
|
||||||
|
print(self.spm.vocab_size(), self.vocab_size_base)
|
||||||
else:
|
else:
|
||||||
# Fall back to trying to find the added tokens in tokenizer.json
|
self.spm = None
|
||||||
tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
|
|
||||||
if not tokenizer_json_file.is_file():
|
|
||||||
added_tokens = {}
|
|
||||||
else:
|
|
||||||
tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
|
|
||||||
added_tokens = dict(
|
|
||||||
(item['content'], item['id'])
|
|
||||||
for item in tokenizer_json.get('added_tokens', [])
|
|
||||||
# Added tokens here can be duplicates of the main vocabulary.
|
|
||||||
if item['content'] not in self.bpe_tokenizer)
|
|
||||||
|
|
||||||
vocab_size: int = len(self.bpe_tokenizer)
|
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
|
tokenizer = self.tokenizer
|
||||||
actual_ids = sorted(added_tokens.values())
|
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.get_vocab().items()}
|
||||||
if expected_ids != actual_ids:
|
added_tokens_ids = set(self.added_tokens_dict.values())
|
||||||
expected_end_id = vocab_size + len(actual_ids) - 1
|
|
||||||
raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
|
|
||||||
|
|
||||||
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
|
for i in range(self.vocab_size_base):
|
||||||
self.added_tokens_list = [text for (text, idx) in items]
|
if i in added_tokens_ids:
|
||||||
self.vocab_size_base: int = vocab_size
|
continue
|
||||||
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
|
|
||||||
self.fname_tokenizer = fname_tokenizer
|
|
||||||
self.fname_added_tokens = fname_added_tokens
|
|
||||||
|
|
||||||
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
text = reverse_vocab[i].encode("utf-8")
|
||||||
tokenizer = self.bpe_tokenizer
|
yield text, self.get_token_score(i), self.get_token_type(i)
|
||||||
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()}
|
|
||||||
|
|
||||||
for i, _ in enumerate(tokenizer):
|
|
||||||
yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
|
|
||||||
|
|
||||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
|
||||||
for text in self.added_tokens_list:
|
|
||||||
score = -1000.0
|
|
||||||
yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
|
|
||||||
|
|
||||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
|
||||||
yield from self.bpe_tokens()
|
|
||||||
yield from self.added_tokens()
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
|
||||||
|
|
||||||
|
|
||||||
class SentencePieceVocab:
|
|
||||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
|
|
||||||
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
|
|
||||||
added_tokens: dict[str, int]
|
|
||||||
if fname_added_tokens is not None:
|
|
||||||
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
|
|
||||||
else:
|
|
||||||
added_tokens = {}
|
|
||||||
|
|
||||||
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
|
|
||||||
|
|
||||||
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
|
|
||||||
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
|
|
||||||
actual_new_ids = sorted(new_tokens.keys())
|
|
||||||
|
|
||||||
if expected_new_ids != actual_new_ids:
|
|
||||||
raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
|
|
||||||
|
|
||||||
# Token pieces that were added to the base vocabulary.
|
|
||||||
self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
|
|
||||||
self.vocab_size_base = vocab_size
|
|
||||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
|
||||||
self.fname_tokenizer = fname_tokenizer
|
|
||||||
self.fname_added_tokens = fname_added_tokens
|
|
||||||
|
|
||||||
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
|
||||||
tokenizer = self.sentencepiece_tokenizer
|
|
||||||
for i in range(tokenizer.vocab_size()):
|
|
||||||
piece = tokenizer.id_to_piece(i)
|
|
||||||
text: bytes = piece.encode("utf-8")
|
|
||||||
score: float = tokenizer.get_score(i)
|
|
||||||
|
|
||||||
|
def get_token_type(self, token_id: int) -> gguf.TokenType:
|
||||||
toktype = gguf.TokenType.NORMAL
|
toktype = gguf.TokenType.NORMAL
|
||||||
if tokenizer.is_unknown(i):
|
|
||||||
|
if self.spm is not None and token_id < self.spm.vocab_size():
|
||||||
|
if self.spm.is_unknown(token_id):
|
||||||
toktype = gguf.TokenType.UNKNOWN
|
toktype = gguf.TokenType.UNKNOWN
|
||||||
if tokenizer.is_control(i):
|
if self.spm.is_control(token_id):
|
||||||
|
toktype = gguf.TokenType.CONTROL
|
||||||
|
if self.spm.is_unused(token_id):
|
||||||
|
toktype = gguf.TokenType.UNUSED
|
||||||
|
if self.spm.is_byte(token_id):
|
||||||
|
toktype = gguf.TokenType.BYTE
|
||||||
|
else:
|
||||||
|
if token_id == self.unk_token_id:
|
||||||
|
toktype = gguf.TokenType.UNKNOWN
|
||||||
|
if token_id in self.special_ids:
|
||||||
toktype = gguf.TokenType.CONTROL
|
toktype = gguf.TokenType.CONTROL
|
||||||
|
|
||||||
# NOTE: I think added_tokens are user defined.
|
return toktype
|
||||||
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
|
|
||||||
# if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
|
|
||||||
|
|
||||||
if tokenizer.is_unused(i):
|
def get_token_score(self, token_id: int) -> float:
|
||||||
toktype = gguf.TokenType.UNUSED
|
if self.spm is not None and token_id < self.spm.vocab_size():
|
||||||
if tokenizer.is_byte(i):
|
return cast(float, self.spm.get_score(token_id))
|
||||||
toktype = gguf.TokenType.BYTE
|
return 0.0
|
||||||
|
|
||||||
yield text, score, toktype
|
|
||||||
|
|
||||||
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
for text in self.added_tokens_list:
|
|
||||||
|
for text in self.added_tokens_dict:
|
||||||
|
if text in self.specials:
|
||||||
|
|
||||||
|
toktype = self.get_token_type(self.specials[text])
|
||||||
|
score = self.get_token_score(self.specials[text])
|
||||||
|
|
||||||
|
else:
|
||||||
|
toktype = gguf.TokenType.USER_DEFINED
|
||||||
score = -1000.0
|
score = -1000.0
|
||||||
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
|
|
||||||
|
yield text.encode("utf-8"), score, toktype
|
||||||
|
|
||||||
|
def has_newline_token(self) -> bool:
|
||||||
|
return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
|
||||||
|
|
||||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||||
yield from self.sentencepiece_tokens()
|
yield from self.hf_tokens()
|
||||||
yield from self.added_tokens()
|
yield from self.added_tokens()
|
||||||
|
|
||||||
|
def get_vocab_type(self) -> str:
|
||||||
|
path_candidates = []
|
||||||
|
vocab_file = "tokenizer.model"
|
||||||
|
path_candidates.append(vocab_file)
|
||||||
|
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
|
||||||
|
if path_candidate is not None:
|
||||||
|
return "llama"
|
||||||
|
|
||||||
|
vocab_file = "vocab.json"
|
||||||
|
path_candidates.append(vocab_file)
|
||||||
|
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
|
||||||
|
if path_candidate is not None:
|
||||||
|
return "gpt2"
|
||||||
|
|
||||||
|
vocab_file = "tokenizer.json"
|
||||||
|
path_candidates.append(vocab_file)
|
||||||
|
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
|
||||||
|
if path_candidate:
|
||||||
|
if not self.has_newline_token():
|
||||||
|
return "gpt2"
|
||||||
|
return "llama"
|
||||||
|
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; "
|
||||||
|
"if it's in another directory, pass the directory as --vocab-dir"
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"
|
||||||
|
|
||||||
|
|
||||||
Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
|
Vocab: TypeAlias = 'VocabLoader'
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# data loading
|
# data loading
|
||||||
@ -824,20 +836,27 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
|
|||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
|
||||||
def check_vocab_size(params: Params, vocab: Vocab) -> None:
|
def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
|
||||||
if params.n_vocab != vocab.vocab_size:
|
if params.n_vocab != vocab.vocab_size:
|
||||||
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
|
if params.n_vocab == vocab.vocab_size:
|
||||||
if params.n_vocab == vocab.vocab_size_base:
|
|
||||||
print("Ignoring added_tokens.json since model matches vocab size without it.")
|
print("Ignoring added_tokens.json since model matches vocab size without it.")
|
||||||
vocab.added_tokens_list = []
|
vocab.added_tokens_dict = OrderedDict()
|
||||||
vocab.vocab_size = vocab.vocab_size_base
|
vocab.vocab_size = vocab.vocab_size
|
||||||
|
return
|
||||||
|
|
||||||
|
if pad_vocab and params.n_vocab > vocab.vocab_size:
|
||||||
|
pad_count = params.n_vocab - vocab.vocab_size
|
||||||
|
print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
|
||||||
|
for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
|
||||||
|
vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
|
||||||
|
vocab.vocab_size = params.n_vocab
|
||||||
return
|
return
|
||||||
msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
|
msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
|
||||||
if vocab.fname_added_tokens is not None:
|
|
||||||
msg += f" combined with {vocab.fname_added_tokens}"
|
|
||||||
msg += f" has {vocab.vocab_size})."
|
msg += f" has {vocab.vocab_size})."
|
||||||
if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None:
|
if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
|
||||||
msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
|
msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
|
||||||
|
if vocab.vocab_size < params.n_vocab:
|
||||||
|
msg += " Possibly try using the --padvocab option."
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
|
|
||||||
@ -901,12 +920,8 @@ class OutputFile:
|
|||||||
scores.append(score)
|
scores.append(score)
|
||||||
toktypes.append(toktype)
|
toktypes.append(toktype)
|
||||||
|
|
||||||
if isinstance(vocab, SentencePieceVocab):
|
vocab_type = vocab.get_vocab_type()
|
||||||
self.gguf.add_tokenizer_model("llama")
|
self.gguf.add_tokenizer_model(vocab_type)
|
||||||
elif isinstance(vocab, BpeVocab):
|
|
||||||
self.gguf.add_tokenizer_model("gpt2")
|
|
||||||
else:
|
|
||||||
raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab')
|
|
||||||
self.gguf.add_token_list(tokens)
|
self.gguf.add_token_list(tokens)
|
||||||
self.gguf.add_token_scores(scores)
|
self.gguf.add_token_scores(scores)
|
||||||
self.gguf.add_token_types(toktypes)
|
self.gguf.add_token_types(toktypes)
|
||||||
@ -932,8 +947,12 @@ class OutputFile:
|
|||||||
self.gguf.close()
|
self.gguf.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
|
def write_vocab_only(
|
||||||
check_vocab_size(params, vocab)
|
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
|
||||||
|
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
|
||||||
|
pad_vocab: bool = False,
|
||||||
|
) -> None:
|
||||||
|
check_vocab_size(params, vocab, pad_vocab = pad_vocab)
|
||||||
|
|
||||||
of = OutputFile(fname_out, endianess=endianess)
|
of = OutputFile(fname_out, endianess=endianess)
|
||||||
|
|
||||||
@ -960,8 +979,13 @@ class OutputFile:
|
|||||||
return dt.quantize(arr)
|
return dt.quantize(arr)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
|
def write_all(
|
||||||
check_vocab_size(params, vocab)
|
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
|
||||||
|
concurrency: int = DEFAULT_CONCURRENCY,
|
||||||
|
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
|
||||||
|
pad_vocab: bool = False,
|
||||||
|
) -> None:
|
||||||
|
check_vocab_size(params, vocab, pad_vocab = pad_vocab)
|
||||||
|
|
||||||
of = OutputFile(fname_out, endianess=endianess)
|
of = OutputFile(fname_out, endianess=endianess)
|
||||||
|
|
||||||
@ -1119,35 +1143,17 @@ def load_some_model(path: Path) -> ModelPlus:
|
|||||||
return model_plus
|
return model_plus
|
||||||
|
|
||||||
|
|
||||||
def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
|
def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
|
||||||
# Be extra-friendly and accept either a file or a directory. Also, if it's
|
|
||||||
# a directory, it might be the model directory, and tokenizer.model might
|
|
||||||
# be in the parent of that.
|
|
||||||
if path.is_dir():
|
|
||||||
vocab_file = "tokenizer.model"
|
|
||||||
if vocabtype == 'bpe':
|
|
||||||
vocab_file = "vocab.json"
|
|
||||||
path2 = path / vocab_file
|
path2 = path / vocab_file
|
||||||
# Use `.parent` instead of /.. to handle the symlink case better.
|
# Use `.parent` instead of /.. to handle the symlink case better.
|
||||||
path3 = path.parent / vocab_file
|
path3 = path.parent / vocab_file
|
||||||
|
|
||||||
if path2.exists():
|
if path2.exists():
|
||||||
path = path2
|
return path2
|
||||||
elif path3.exists():
|
if path3.exists():
|
||||||
path = path3
|
return path3
|
||||||
else:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Could not find {vocab_file} in {path} or its parent; "
|
|
||||||
"if it's in another directory, pass the directory as --vocab-dir")
|
|
||||||
|
|
||||||
print(f"Loading vocab file '{path}', type '{vocabtype}'")
|
return None
|
||||||
|
|
||||||
added_tokens_path = path.parent / "added_tokens.json"
|
|
||||||
if vocabtype == "bpe":
|
|
||||||
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
|
|
||||||
elif vocabtype == "spm":
|
|
||||||
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
|
|
||||||
|
|
||||||
|
|
||||||
def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
|
def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
|
||||||
@ -1185,11 +1191,11 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
|
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
|
||||||
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
|
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
|
||||||
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
|
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
|
||||||
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
|
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
|
||||||
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
|
|
||||||
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
|
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
|
||||||
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
|
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
|
||||||
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
|
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
|
||||||
|
parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
|
||||||
|
|
||||||
args = parser.parse_args(args_in)
|
args = parser.parse_args(args_in)
|
||||||
if args.dump_single:
|
if args.dump_single:
|
||||||
@ -1232,12 +1238,13 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
if not args.outfile:
|
if not args.outfile:
|
||||||
raise ValueError("need --outfile if using --vocab-only")
|
raise ValueError("need --outfile if using --vocab-only")
|
||||||
# FIXME: Try to respect vocab_dir somehow?
|
# FIXME: Try to respect vocab_dir somehow?
|
||||||
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
|
vocab = VocabLoader(params, args.vocab_dir or args.model)
|
||||||
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
|
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
|
||||||
load_merges = args.vocabtype == 'bpe',
|
load_merges = True,
|
||||||
n_vocab = vocab.vocab_size)
|
n_vocab = vocab.vocab_size)
|
||||||
outfile = args.outfile
|
outfile = args.outfile
|
||||||
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
|
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
|
||||||
|
endianess = endianess, pad_vocab = args.padvocab)
|
||||||
print(f"Wrote {outfile}")
|
print(f"Wrote {outfile}")
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -1245,12 +1252,15 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
vocab = model_plus.vocab
|
vocab = model_plus.vocab
|
||||||
else:
|
else:
|
||||||
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
|
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
|
||||||
vocab = load_vocab(vocab_dir, args.vocabtype)
|
vocab = VocabLoader(params, vocab_dir)
|
||||||
|
|
||||||
# FIXME: Try to respect vocab_dir somehow?
|
# FIXME: Try to respect vocab_dir somehow?
|
||||||
|
print(f"Vocab info: {vocab}")
|
||||||
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
|
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
|
||||||
load_merges = args.vocabtype == 'bpe',
|
load_merges = True,
|
||||||
n_vocab = vocab.vocab_size)
|
n_vocab = vocab.vocab_size)
|
||||||
|
|
||||||
|
print(f"Special vocab info: {special_vocab}")
|
||||||
model = model_plus.model
|
model = model_plus.model
|
||||||
model = convert_model_names(model, params)
|
model = convert_model_names(model, params)
|
||||||
ftype = pick_output_type(model, args.outtype)
|
ftype = pick_output_type(model, args.outtype)
|
||||||
@ -1260,7 +1270,8 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
params.ftype = ftype
|
params.ftype = ftype
|
||||||
print(f"Writing {outfile}, format {ftype}")
|
print(f"Writing {outfile}, format {ftype}")
|
||||||
|
|
||||||
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess)
|
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
|
||||||
|
concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
|
||||||
print(f"Wrote {outfile}")
|
print(f"Wrote {outfile}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
numpy==1.24.4
|
numpy==1.24.4
|
||||||
sentencepiece==0.1.98
|
sentencepiece==0.1.98
|
||||||
|
transformers>=4.34.0
|
||||||
gguf>=0.1.0
|
gguf>=0.1.0
|
||||||
|
Loading…
Reference in New Issue
Block a user