mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 22:38:58 +01:00
fix ftype
This commit is contained in:
parent
84288ff9f7
commit
7a83f200d3
@ -5,19 +5,12 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Iterable, Iterator
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -32,22 +25,17 @@ from convert_hf_to_gguf import Model
|
||||
|
||||
logger = logging.getLogger("lora-to-gguf")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
all_models = ", ".join([arch for arch in Model._model_classes.keys()])
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert a huggingface model to a GGML compatible file")
|
||||
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
|
||||
parser.add_argument(
|
||||
"--outfile", type=Path,
|
||||
help="path to write to; default: based on input.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch", type=str,
|
||||
help=f"Arch of the base model, must be one of: {all_models} (default: LlamaForCausalLM)",
|
||||
default="LlamaForCausalLM"
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
@ -73,14 +61,13 @@ if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
# FIXME: outtype is not working
|
||||
ftype_map: dict[str, gguf.LlamaFileType] = {
|
||||
"f32": gguf.LlamaFileType.ALL_F32,
|
||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||
"auto": gguf.LlamaFileType.GUESSED,
|
||||
}
|
||||
ftype = ftype_map[args.outtype]
|
||||
|
||||
dir_base_model = args.base
|
||||
dir_lora = args.lora_path
|
||||
@ -110,7 +97,7 @@ if __name__ == '__main__':
|
||||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
model_instance = model_class(dir_base_model, ftype_map[args.outtype], fname_out, args.bigendian, False, False, None)
|
||||
model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
|
||||
logger.info("Set model parameters")
|
||||
model_instance.set_gguf_parameters()
|
||||
|
||||
@ -140,16 +127,18 @@ if __name__ == '__main__':
|
||||
# overwrite method
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
# TODO: This will not take into account tensor transformations
|
||||
return [(name, data_torch)]
|
||||
|
||||
# overwrite method
|
||||
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||
del name, new_name, bid, n_dims # unused
|
||||
return True
|
||||
return ftype != gguf.LlamaFileType.ALL_F32
|
||||
|
||||
model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
|
||||
model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
|
||||
model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
|
||||
|
||||
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||
logger.info("Exporting model...")
|
||||
model_instance.write()
|
||||
|
Loading…
Reference in New Issue
Block a user