mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Make it actually work
This commit is contained in:
parent
804486214b
commit
001e638b47
@ -42,7 +42,7 @@ def load_model(model_name):
|
||||
shared.is_RWKV = model_name.lower().startswith('rwkv-')
|
||||
|
||||
# Default settings
|
||||
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.llama_bits>0 or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
|
||||
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.llama_bits>0 or shared.args.load_in_4bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
|
||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user