diff --git a/modules/LLaMA_8bit.py b/modules/LLaMA_8bit.py
new file mode 100644
index 00000000..a339277c
--- /dev/null
+++ b/modules/LLaMA_8bit.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Tuple
+import os
+import sys
+import torch
+import fire
+import time
+import json
+
+from pathlib import Path
+
+from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+from repositories.llama_int8.llama import ModelArgs, Transformer, Tokenizer, LLaMA
+
+
+def setup_model_parallel() -> Tuple[int, int]:
+    local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+    torch.distributed.init_process_group("nccl")
+    initialize_model_parallel(world_size)
+    torch.cuda.set_device(local_rank)
+
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    return local_rank, world_size
+
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
+    start_time = time.time()
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.loads(f.read())
+
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    model_args.vocab_size = tokenizer.n_words
+    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
+    torch.set_default_tensor_type(torch.HalfTensor)
+    print("Creating transformer")
+    model = Transformer(model_args)
+    print("Transformer created")
+
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    # switch the default tensor type back to fp32 before loading the checkpoints
+    torch.set_default_tensor_type(torch.FloatTensor)
+
+    # load the state dict incrementally, to avoid memory problems
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i : size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+        del checkpoint
+
+    # model.load_state_dict(checkpoint, strict=False)
+    model.quantize()
+
+    generator = LLaMA(model, tokenizer)
+    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    return generator
+
+
+class LLaMAModel_8bit:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(cls, path, max_seq_len=2048, max_batch_size=1):
+        tokenizer_path = path / "tokenizer.model"
+        path = os.path.abspath(path)
+        tokenizer_path = os.path.abspath(tokenizer_path)
+
+        generator = load(path, tokenizer_path, max_seq_len, max_batch_size)
+
+        result = cls()
+        result.pipeline = generator
+        return result
+
+    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
+
+        results = self.pipeline.generate(
+            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
+        )
+
+        return results[0]
+
diff --git a/modules/models.py b/modules/models.py
index 904d8ae2..c7b75bb9 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -88,12 +88,20 @@ def load_model(model_name):
 
     # LLaMA model (not on HuggingFace)
    elif shared.is_LLaMA:
-        import modules.LLaMA
-        from modules.LLaMA import LLaMAModel
+        if shared.args.load_in_8bit:
+            import modules.LLaMA_8bit
+            from modules.LLaMA_8bit import LLaMAModel_8bit
 
-        model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+            model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
 
-        return model, None
+            return model, None
+        else:
+            import modules.LLaMA
+            from modules.LLaMA import LLaMAModel
+
+            model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+
+            return model, None
 
     # Custom
     else:
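A minimal usage sketch of the new 8-bit path added above, assuming the llama_int8 code is available under repositories/ and a converted LLaMA checkpoint lives under models/llama-7b (the model directory name and the prompt are only examples):

    from pathlib import Path

    from modules.LLaMA_8bit import LLaMAModel_8bit

    # builds the fp16 Transformer, merges the sharded *.pth checkpoints,
    # quantizes the weights via model.quantize() and wraps the result in a LLaMA generator
    model = LLaMAModel_8bit.from_pretrained(Path("models/llama-7b"))

    # single-prompt generation; token_count is forwarded as max_gen_len
    print(model.generate("Building a website can be done in 10 simple steps:", token_count=256))

In the web UI itself this branch is taken automatically when shared.args.load_in_8bit is set for a LLaMA model, as shown in the modules/models.py hunk above.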