Add LLaMA support

This commit is contained in:
oobabooga 2023-03-03 14:39:14 -03:00
parent 2bff646130
commit ea5c5eb3da
4 changed files with 110 additions and 2 deletions

96
modules/LLaMA.py Normal file
View File

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
import json
import os
import sys
import time
from pathlib import Path
from typing import Tuple
import fire
import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from llama import LLaMA, ModelArgs, Tokenizer, Transformer
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MP'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '2223'
def setup_model_parallel() -> Tuple[int, int]:
local_rank = int(os.environ.get("LOCAL_RANK", -1))
world_size = int(os.environ.get("WORLD_SIZE", -1))
torch.distributed.init_process_group("gloo")
initialize_model_parallel(world_size)
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(1)
return local_rank, world_size
def load(
ckpt_dir: str,
tokenizer_path: str,
local_rank: int,
world_size: int,
max_seq_len: int,
max_batch_size: int,
) -> LLaMA:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert world_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
ckpt_path = checkpoints[local_rank]
print("Loading")
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False)
generator = LLaMA(model, tokenizer)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return generator
class LLaMAModel:
def __init__(self):
pass
@classmethod
def from_pretrained(self, path, max_seq_len=512, max_batch_size=32):
tokenizer_path = path / "tokenizer.model"
path = os.path.abspath(path)
tokenizer_path = os.path.abspath(tokenizer_path)
local_rank, world_size = setup_model_parallel()
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
generator = load(
path, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
)
result = self()
result.pipeline = generator
return result
def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
results = self.pipeline.generate(
[prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
)
return results[0]

View File

@ -39,9 +39,10 @@ def load_model(model_name):
t0 = time.time()
shared.is_RWKV = model_name.lower().startswith('rwkv-')
shared.is_LLaMA = model_name.lower().startswith('llama-')
# Default settings
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV or shared.is_LLaMA):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
@ -85,6 +86,15 @@ def load_model(model_name):
return model, None
# LLaMA model (not on HuggingFace)
elif shared.is_LLaMA:
import modules.LLaMA
from modules.LLaMA import LLaMAModel
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
return model, None
# Custom
else:
command = "AutoModelForCausalLM.from_pretrained"

View File

@ -6,6 +6,7 @@ model_name = ""
soft_prompt_tensor = None
soft_prompt = False
is_RWKV = False
is_LLaMA = False
# Chat variables
history = {'internal': [], 'visible': []}
@ -42,6 +43,7 @@ settings = {
'default': 'NovelAI-Sphinx Moth',
'pygmalion-*': 'Pygmalion',
'RWKV-*': 'Naive',
'llama-*': 'Naive',
'(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
},
'prompts': {

View File

@ -83,7 +83,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if not shared.args.cpu:
torch.cuda.empty_cache()
if shared.is_RWKV:
if shared.is_RWKV or shared.is_LLaMA:
if shared.args.no_stream:
reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
yield formatted_outputs(reply, shared.model_name)