Implement CPU mode

Commit: 0e67ccf607
Parent: f2a548c098
Mirror of https://github.com/oobabooga/text-generation-webui.git
README.md

@@ -15,6 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * Chat mode for conversation and role playing.
 * Load 13b/20b models in 8-bit mode.
 * Load parameter presets from text files.
+* Option to use the CPU instead of the GPU for generation.
 
 ## Installation
 
@@ -89,6 +90,8 @@ Optionally, you can use the following command-line flags:
 
 `--chat`: Launch the webui in chat mode.
 
+`--cpu`: Use the CPU to generate text instead of the GPU.
+
 ## Presets
 
 Inference settings presets can be created under `presets/` as text files. These files are detected automatically at startup.
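As a usage example (not part of this commit): with the new flag, text generation runs entirely on the CPU, so a chat session could be started with `python server.py --chat --cpu`, optionally adding `--model <name>` to preload a specific model. Generation in this mode is expected to be noticeably slower than on a GPU.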
server.py (36 changed lines)
@@ -16,6 +16,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the webui in chat mode.')
+parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 args = parser.parse_args()
 loaded_preset = None
 available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
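For reference, a minimal standalone sketch of how this `store_true` flag behaves (nothing here beyond standard argparse semantics): it defaults to `False`, so the GPU path remains the default unless `--cpu` is passed.

```python
import argparse

# Sketch reproducing only the new flag from the commit.
parser = argparse.ArgumentParser()
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')

print(parser.parse_args([]).cpu)         # False -> default GPU path
print(parser.parse_args(['--cpu']).cpu)  # True  -> CPU path
```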
@@ -26,30 +27,37 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
+    if args.cpu:
+        dtype = torch.float32
+    else:
+        dtype = torch.float16
+
     # Loading the model
-    if Path(f"torch-dumps/{model_name}.pt").exists():
+    if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
         print("Loading in .pt format...")
-        model = torch.load(Path(f"torch-dumps/{model_name}.pt")).cuda()
+        model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
     elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')):
         if any(size in model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
         else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
-    elif model_name in ['gpt-j-6B']:
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
     elif model_name in ['flan-t5', 't5-large']:
-        model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
+        model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
     else:
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
 
     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan'):
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
-    elif model_name in ['flan-t5']:
+    elif model_name in ['flan-t5', 't5-large']:
         tokenizer = T5Tokenizer.from_pretrained(Path(f"models/{model_name}/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
 
+    # Sending to the GPU
+    if not (args.cpu or any(size in model_name.lower() for size in ('13b', '20b', '30b'))):
+        model = model.cuda()
+
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 
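Taken together, the new load path boils down to the following minimal sketch (the directory `models/opt-1.3b` is only a placeholder, not something the commit ships):

```python
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM

cpu = True  # value of args.cpu when the webui is launched with --cpu

# CPU mode keeps the weights in float32 and never moves them to CUDA;
# otherwise the previous behaviour remains: float16 weights sent to the GPU.
dtype = torch.float32 if cpu else torch.float16
model = AutoModelForCausalLM.from_pretrained(
    Path("models/opt-1.3b"), low_cpu_mem_usage=True, torch_dtype=dtype)
if not cpu:
    model = model.cuda()
```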
@@ -76,6 +84,7 @@ def generate_reply(question, temperature, max_length, inference_settings, select
         model_name = selected_model
         model = None
         tokenizer = None
-        torch.cuda.empty_cache()
+        if not args.cpu:
+            torch.cuda.empty_cache()
         model, tokenizer = load_model(model_name)
     if inference_settings != loaded_preset:
@@ -83,16 +92,21 @@ def generate_reply(question, temperature, max_length, inference_settings, select
             preset = infile.read()
         loaded_preset = inference_settings
 
-    torch.cuda.empty_cache()
-    input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
+    if not args.cpu:
+        torch.cuda.empty_cache()
+        input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
+        cuda = ".cuda()"
+    else:
+        input_ids = tokenizer.encode(str(question), return_tensors='pt')
+        cuda = ""
 
     if eos_token is None:
-        output = eval(f"model.generate(input_ids, {preset}).cuda()")
+        output = eval(f"model.generate(input_ids, {preset}){cuda}")
     else:
         n = tokenizer.encode(eos_token, return_tensors='pt')[0][1]
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}).cuda()")
+        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
 
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
     if model_name.lower().startswith('galactica'):
         reply = fix_galactica(reply)
         return reply, reply, 'Only applicable for gpt4chan.'
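Because the preset text is spliced into the `model.generate(...)` call with `eval`, a preset file is just a comma-separated list of generation keyword arguments, and it could also reference `generate_reply`'s own `temperature` and `max_length` parameters. The snippet below sketches that mechanism with a hypothetical preset string; it is not a file shipped by the commit.

```python
# Hypothetical preset contents (what a file under presets/ could hold).
preset = "do_sample=True, temperature=temperature, max_length=max_length, top_p=0.9"
cuda = ""  # in CPU mode no ".cuda()" is appended to the generated output

# generate_reply builds and evaluates a call string like this one:
expression = f"model.generate(input_ids, {preset}){cuda}"
print(expression)
# -> model.generate(input_ids, do_sample=True, temperature=temperature, max_length=max_length, top_p=0.9)
```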