Merge pull request #549 from catalpaaa/lora-and-model-dir

lora-dir, model-dir and login auth
This commit is contained in:
oobabooga 2023-03-27 23:46:47 -03:00 committed by GitHub
commit c188975a01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 23 deletions

View File

@ -198,12 +198,15 @@ Optionally, you can use the following command-line flags:
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
| `--no-stream` | Don't stream the text output in real time. | | `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--listen` | Make the web UI reachable from your local network.| | `--model-dir MODEL_DIR` | Path to directory with all the models |
| `--listen-port LISTEN_PORT` | The listening port that the server will use. | | `--lora-dir LORA_DIR` | Path to directory with all the loras |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | | `--verbose` | Print the prompts to the terminal. |
| `--auto-launch` | Open the web UI in the default browser upon launch. | | `--listen` | Make the web UI reachable from your local network. |
| `--verbose` | Print the prompts to the terminal. | | `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide). Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).

View File

@ -32,7 +32,7 @@ def add_lora_to_model(lora_name):
elif shared.args.load_in_8bit: elif shared.args.load_in_8bit:
params['device_map'] = {'': 0} params['device_map'] = {'': 0}
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params) shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
if not shared.args.load_in_8bit and not shared.args.cpu: if not shared.args.load_in_8bit and not shared.args.cpu:
shared.model.half() shared.model.half()
if not hasattr(shared.model, "hf_device_map"): if not hasattr(shared.model, "hf_device_map"):

View File

@ -46,9 +46,9 @@ def load_model(model_name):
# Default settings # Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else: else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps: if torch.has_mps:
device = torch.device('mps') device = torch.device('mps')
model = model.to(device) model = model.to(device)
@ -76,11 +76,11 @@ def load_model(model_name):
num_bits=4, group_size=64, num_bits=4, group_size=64,
group_dim=2, symmetric=False)) group_dim=2, symmetric=False))
model = OptLM(f"facebook/{shared.model_name}", env, "models", policy) model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)
# DeepSpeed ZeRO-3 # DeepSpeed ZeRO-3
elif shared.args.deepspeed: elif shared.args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
model.module.eval() # Inference model.module.eval() # Inference
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
@ -89,8 +89,8 @@ def load_model(model_name):
elif shared.is_RWKV: elif shared.is_RWKV:
from modules.RWKV import RWKVModel, RWKVTokenizer from modules.RWKV import RWKVModel, RWKVTokenizer
model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
tokenizer = RWKVTokenizer.from_pretrained(Path('models')) tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
return model, tokenizer return model, tokenizer
@ -142,7 +142,7 @@ def load_model(model_name):
if shared.args.disk: if shared.args.disk:
params["offload_folder"] = shared.args.disk_cache_dir params["offload_folder"] = shared.args.disk_cache_dir
checkpoint = Path(f'models/{shared.model_name}') checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}')
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint) config = AutoConfig.from_pretrained(checkpoint)
@ -159,10 +159,10 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Loading the tokenizer # Loading the tokenizer
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists(): if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left' tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")

View File

@ -107,11 +107,14 @@ parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile t
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
args = parser.parse_args() args = parser.parse_args()
# Provisional, this will be deleted later # Provisional, this will be deleted later

View File

@ -33,9 +33,9 @@ if settings_file is not None:
def get_available_models():
    """Return the model names found in the configured model directory.

    The directory is taken from ``shared.args.model_dir``. Only its direct
    entries are inspected.

    * With ``shared.args.flexgen`` set, only entries whose name ends in
      ``-np`` are returned, with that suffix removed.
    * Otherwise every entry is returned except those ending in ``.txt``,
      ``-np``, ``.pt`` or ``.json``; a trailing ``.pth`` extension is removed
      from the name.

    Returns:
        A list of names sorted case-insensitively.
    """
    model_dir = Path(shared.args.model_dir)

    if shared.args.flexgen:
        return sorted(
            (re.sub('-np$', '', item.name) for item in model_dir.glob('*') if item.name.endswith('-np')),
            key=str.lower,
        )

    skipped_suffixes = ('.txt', '-np', '.pt', '.json')
    # The dot is escaped so that only a literal ".pth" extension is stripped;
    # an unescaped '.' would also eat names that merely end in "<any char>pth".
    return sorted(
        (re.sub(r'\.pth$', '', item.name) for item in model_dir.glob('*') if not item.name.endswith(skipped_suffixes)),
        key=str.lower,
    )
def get_available_presets():
    """Return the preset names available in the ``presets`` directory.

    Each ``*.txt`` file directly inside ``presets`` contributes its file name
    with the last dot-separated component (the extension) removed. Duplicate
    names are collapsed and the result is sorted case-insensitively.
    """
    preset_names = set()
    for preset_file in Path('presets').glob('*.txt'):
        # Every matched name contains at least one dot, so splitting once
        # from the right drops exactly the final extension.
        preset_names.add(str(preset_file.name).rsplit('.', 1)[0])
    return sorted(preset_names, key=str.lower)
@ -57,7 +57,7 @@ def get_available_softprompts():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras():
    """Return the selectable LoRA names, with ``'None'`` as the first entry.

    Lists the direct entries of the directory given by
    ``shared.args.lora_dir``, skipping names that end in ``.txt``, ``-np``,
    ``.pt`` or ``.json``. The remaining names are sorted case-insensitively
    and appended after the ``'None'`` placeholder.
    """
    # Pass the configured value itself. Quoting the expression as the literal
    # string 'shared.args.lora_dir' would glob a folder with that literal name
    # and silently ignore the configured directory.
    lora_dir = Path(shared.args.lora_dir)
    skipped_suffixes = ('.txt', '-np', '.pt', '.json')
    lora_names = [item.name for item in lora_dir.glob('*') if not item.name.endswith(skipped_suffixes)]
    return ['None'] + sorted(lora_names, key=str.lower)
def unload_model(): def unload_model():
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
@ -498,12 +498,21 @@ def create_interface():
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
# Authentication
auth = None
if shared.args.gradio_auth_path is not None:
gradio_auth_creds = []
with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
# Launch the interface # Launch the interface
shared.gradio['interface'].queue() shared.gradio['interface'].queue()
if shared.args.listen: if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
create_interface() create_interface()