refactor: Further refine functionality, improve user interaction, and streamline vocabulary handling

- Renamed command-line arguments for clarity and consistency.
- Improved path resolution and import adjustments for robustness.
- Thoughtfully handled 'awq-path' and conditional logic for the weighted model.
- Enhanced model and vocabulary loading with the 'VocabFactory' class for structured and adaptable loading.
- Strengthened error handling and user feedback for a more user-friendly experience.
- Structured output file handling with clear conditions and defaults.
- Streamlined and organized the 'main' function for better logic flow.
- Passed 'sys.argv[1:]' to 'main' for adaptability and testability.

These changes solidify the script's functionality, making it more robust, user-friendly, and adaptable. The use of the 'VocabFactory' class is a notable enhancement in efficient vocabulary handling, reflecting a thoughtful and iterative approach to script development.
This commit is contained in:
teleprint-me 2024-01-07 21:54:42 -05:00
parent 226cea270e
commit 0614c338f8
No known key found for this signature in database
GPG Key ID: B0D11345E65C4D48

View File

@ -1555,8 +1555,9 @@ def main(argv: Optional[list[str]] = None) -> None:
args = parser.parse_args(argv) args = parser.parse_args(argv)
if args.awq_path: if args.awq_path:
sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) sys.path.insert(1, str(Path(__file__).resolve().parent / "awq-py"))
from awq.apply_awq import add_scale_weights from awq.apply_awq import add_scale_weights
tmp_model_path = args.model / "weighted_model" tmp_model_path = args.model / "weighted_model"
if tmp_model_path.is_dir(): if tmp_model_path.is_dir():
print(f"{tmp_model_path} exists as a weighted model.") print(f"{tmp_model_path} exists as a weighted model.")
@ -1575,74 +1576,83 @@ def main(argv: Optional[list[str]] = None) -> None:
if not args.vocab_only: if not args.vocab_only:
model_plus = load_some_model(args.model) model_plus = load_some_model(args.model)
else: else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) model_plus = ModelPlus(
model={}, paths=[args.model / "dummy"], format="none", vocab=None
)
if args.dump: if args.dump:
do_dump_model(model_plus) do_dump_model(model_plus)
return return
endianess = gguf.GGUFEndian.LITTLE endianess = gguf.GGUFEndian.LITTLE
if args.bigendian: if args.big_endian:
endianess = gguf.GGUFEndian.BIG endianess = gguf.GGUFEndian.BIG
params = Params.load(model_plus) params = Params.load(model_plus)
if params.n_ctx == -1: if params.n_ctx == -1:
if args.ctx is None: if args.ctx is None:
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" raise Exception(
"The model doesn't have a context size, and you didn't specify one with --ctx\n"
"Please specify one with --ctx:\n" "Please specify one with --ctx:\n"
" - LLaMA v1: --ctx 2048\n" " - LLaMA v1: --ctx 2048\n"
" - LLaMA v2: --ctx 4096\n") " - LLaMA v2: --ctx 4096\n"
)
params.n_ctx = args.ctx params.n_ctx = args.ctx
if args.outtype: if args.out_type:
params.ftype = { params.ftype = {
"f32": GGMLFileType.AllF32, "f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16, "f16": GGMLFileType.MostlyF16,
"q8_0": GGMLFileType.MostlyQ8_0, "q8_0": GGMLFileType.MostlyQ8_0,
}[args.outtype] }[args.out_type]
print(f"params = {params}") print(f"params = {params}")
vocab: Vocab model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
vocab_factory = VocabFactory(vocab_path)
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path)
if args.vocab_only: if args.vocab_only:
if not args.outfile: if not args.out_file:
raise ValueError("need --outfile if using --vocab-only") raise ValueError("need --out-file if using --vocab-only")
# FIXME: Try to respect vocab_dir somehow? out_file = args.out_file
vocab = VocabLoader(params, args.vocab_dir or args.model) OutputFile.write_vocab_only(
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, out_file,
load_merges = True, params,
n_vocab = vocab.vocab_size) vocab,
outfile = args.outfile special_vocab,
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, endianess=endianess,
endianess = endianess, pad_vocab = args.padvocab) pad_vocab=args.pad_vocab,
print(f"Wrote {outfile}") )
print(f"Wrote {out_file}")
return return
if model_plus.vocab is not None and args.vocab_dir is None: if model_plus.vocab is not None and args.vocab_dir is None:
vocab = model_plus.vocab vocab = model_plus.vocab
else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = VocabLoader(params, vocab_dir)
# FIXME: Try to respect vocab_dir somehow?
print(f"Vocab info: {vocab}")
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = True,
n_vocab = vocab.vocab_size)
print(f"Special vocab info: {special_vocab}")
model = model_plus.model model = model_plus.model
model = convert_model_names(model, params) model = convert_model_names(model, params)
ftype = pick_output_type(model, args.outtype) ftype = pick_output_type(model, args.out_type)
model = convert_to_output_type(model, ftype) model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype) out_file = args.out_file or default_output_file(model_plus.paths, ftype)
params.ftype = ftype params.ftype = ftype
print(f"Writing {outfile}, format {ftype}") print(f"Writing {out_file}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, OutputFile.write_all(
concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab) out_file,
print(f"Wrote {outfile}") ftype,
params,
model,
vocab,
special_vocab,
concurrency=args.concurrency,
endianess=endianess,
pad_vocab=args.pad_vocab,
)
print(f"Wrote {out_file}")
if __name__ == '__main__': if __name__ == "__main__":
main() main(sys.argv[1:]) # Exclude the first element (script name) from sys.argv