mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-12 05:17:40 +01:00
Simplify deepspeed implementation (#40)
This commit is contained in:
parent
d6b2d68527
commit
2583bc5840
36
server.py
36
server.py
@ -83,10 +83,7 @@ if args.deepspeed:
|
|||||||
from modules.deepspeed_parameters import generate_ds_config
|
from modules.deepspeed_parameters import generate_ds_config
|
||||||
|
|
||||||
# Distributed setup
|
# Distributed setup
|
||||||
if args.local_rank is not None:
|
local_rank = args.local_rank if args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
|
||||||
local_rank = args.local_rank
|
|
||||||
else:
|
|
||||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
|
||||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
torch.cuda.set_device(local_rank)
|
torch.cuda.set_device(local_rank)
|
||||||
deepspeed.init_distributed()
|
deepspeed.init_distributed()
|
||||||
@ -109,15 +106,8 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# DeepSpeed ZeRO-3
|
# DeepSpeed ZeRO-3
|
||||||
elif args.deepspeed:
|
elif args.deepspeed:
|
||||||
if args.bf16:
|
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
|
||||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16)
|
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||||
else:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16)
|
|
||||||
model = deepspeed.initialize(model=model,
|
|
||||||
config_params=ds_config,
|
|
||||||
model_parameters=None,
|
|
||||||
optimizer=None,
|
|
||||||
lr_scheduler=None)[0]
|
|
||||||
model.module.eval() # Inference
|
model.module.eval() # Inference
|
||||||
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||||
|
|
||||||
@ -183,7 +173,11 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||||||
else:
|
else:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
|
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
|
||||||
|
|
||||||
|
if not args.deepspeed:
|
||||||
return input_ids
|
return input_ids
|
||||||
|
else:
|
||||||
|
return input_ids.to(device=local_rank)
|
||||||
|
|
||||||
def decode(output_ids):
|
def decode(output_ids):
|
||||||
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
|
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||||
@ -226,10 +220,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
|
|
||||||
cuda = "" if args.cpu else ".cuda()"
|
cuda = "" if args.cpu else ".cuda()"
|
||||||
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
||||||
if args.deepspeed:
|
|
||||||
input_ids = encode(question, tokens).to(device=local_rank)
|
|
||||||
else:
|
|
||||||
input_ids = encode(question, tokens)
|
input_ids = encode(question, tokens)
|
||||||
|
|
||||||
if stopping_string is not None:
|
if stopping_string is not None:
|
||||||
# The stopping_criteria code below was copied from
|
# The stopping_criteria code below was copied from
|
||||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||||
@ -246,11 +238,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
# Generate the entire reply at once
|
# Generate the entire reply at once
|
||||||
if args.no_stream:
|
if args.no_stream:
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
if args.deepspeed:
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
if not args.deepspeed:
|
||||||
else:
|
|
||||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||||
|
else:
|
||||||
|
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
||||||
reply = decode(output[0])
|
reply = decode(output[0])
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
||||||
@ -263,11 +255,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
yield formatted_outputs(original_question, model_name)
|
yield formatted_outputs(original_question, model_name)
|
||||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
|
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
|
||||||
for i in tqdm(range(tokens//8+1)):
|
for i in tqdm(range(tokens//8+1)):
|
||||||
if args.deepspeed:
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
if not args.deepspeed:
|
||||||
else:
|
|
||||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||||
|
else:
|
||||||
|
output = eval(f"model.generate(input_ids, synced_gpus=True, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset})")
|
||||||
reply = decode(output[0])
|
reply = decode(output[0])
|
||||||
if not (args.chat or args.cai_chat):
|
if not (args.chat or args.cai_chat):
|
||||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user