mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-24 13:28:59 +01:00
Simplify deepspeed implementation (#40)
This commit is contained in:
parent
d6b2d68527
commit
2583bc5840
36
server.py
36
server.py
@ -83,10 +83,7 @@ if args.deepspeed:
|
||||
from modules.deepspeed_parameters import generate_ds_config
|
||||
|
||||
# Distributed setup
|
||||
if args.local_rank is not None:
|
||||
local_rank = args.local_rank
|
||||
else:
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
local_rank = args.local_rank if args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
torch.cuda.set_device(local_rank)
|
||||
deepspeed.init_distributed()
|
||||
@ -109,15 +106,8 @@ def load_model(model_name):
|
||||
|
||||
# DeepSpeed ZeRO-3
|
||||
elif args.deepspeed:
|
||||
if args.bf16:
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16)
|
||||
model = deepspeed.initialize(model=model,
|
||||
config_params=ds_config,
|
||||
model_parameters=None,
|
||||
optimizer=None,
|
||||
lr_scheduler=None)[0]
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
|
||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||
model.module.eval() # Inference
|
||||
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||
|
||||
@ -183,7 +173,11 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
|
||||
|
||||
if not args.deepspeed:
|
||||
return input_ids
|
||||
else:
|
||||
return input_ids.to(device=local_rank)
|
||||
|
||||
def decode(output_ids):
|
||||
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
@ -226,10 +220,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
|
||||
cuda = "" if args.cpu else ".cuda()"
|
||||
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
||||
if args.deepspeed:
|
||||
input_ids = encode(question, tokens).to(device=local_rank)
|
||||
else:
|
||||
input_ids = encode(question, tokens)
|
||||
|
||||
if stopping_string is not None:
|
||||
# The stopping_criteria code below was copied from
|
||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||
@ -246,11 +238,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
# Generate the entire reply at once
|
||||
if args.no_stream:
|
||||
t0 = time.time()
|
||||
if args.deepspeed:
|
||||
with torch.no_grad():
|
||||
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
||||
else:
|
||||
if not args.deepspeed:
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||
else:
|
||||
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
||||
reply = decode(output[0])
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
||||
@ -263,11 +255,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
yield formatted_outputs(original_question, model_name)
|
||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
|
||||
for i in tqdm(range(tokens//8+1)):
|
||||
if args.deepspeed:
|
||||
with torch.no_grad():
|
||||
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
||||
else:
|
||||
if not args.deepspeed:
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||
else:
|
||||
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
||||
reply = decode(output[0])
|
||||
if not (args.chat or args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
|
Loading…
Reference in New Issue
Block a user