From d37a28730dc208ca5fbff50b32818a93a305b5ef Mon Sep 17 00:00:00 2001
From: flurb18 <33769947+flurb18@users.noreply.github.com>
Date: Wed, 24 May 2023 08:38:20 -0400
Subject: [PATCH] Beginning of multi-user support (#2262)

Adds a lock to generate_reply
---
 modules/shared.py          |  1 +
 modules/text_generation.py | 12 +++++++++++-
 server.py                  |  2 ++
 3 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/modules/shared.py b/modules/shared.py
index b809bcc0..9059341f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -6,6 +6,7 @@ import yaml
 
 from modules.logging_colors import logger
 
+generation_lock = None
 model = None
 tokenizer = None
 model_name = "None"
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 904d0d48..d2a77ece 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -1,6 +1,7 @@
 import ast
 import random
 import re
+import threading
 import time
 import traceback
 
@@ -17,6 +18,15 @@ from modules.logging_colors import logger
 from modules.models import clear_torch_cache, local_rank
 
 
+def generate_reply(*args, **kwargs):
+    shared.generation_lock.acquire()
+    try:
+        for result in _generate_reply(*args, **kwargs):
+            yield result
+    finally:
+        shared.generation_lock.release()
+
+
 def get_max_prompt_length(state):
     max_length = state['truncation_length'] - state['max_new_tokens']
     if shared.soft_prompt:
@@ -154,7 +164,7 @@ def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=Non
         yield formatted_outputs(reply, shared.model_name)
 
 
-def generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
+def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
     state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
diff --git a/server.py b/server.py
index f408c6e6..1fbbde68 100644
--- a/server.py
+++ b/server.py
@@ -38,6 +38,7 @@ import zipfile
 from datetime import datetime
 from functools import partial
 from pathlib import Path
+from threading import Lock
 
 import psutil
 import torch
@@ -1075,6 +1076,7 @@ if __name__ == "__main__":
         'instruction_template': shared.settings['instruction_template']
     })
 
+    shared.generation_lock = Lock()
    # Launch the web UI
     create_interface()
     while True:
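
For readers skimming the diff, below is a minimal, self-contained sketch of the serialization pattern the patch introduces. It is illustrative only: the module-level generation_lock, the toy _generate_reply body, and the two-thread demo are stand-ins (in the real code the lock lives on modules.shared and is created in server.py at startup), but the wrapper itself mirrors the patch: the lock is acquired before the underlying generator is iterated and released in a finally block, so it is freed even if a client disconnects mid-stream and the generator is closed early.

    import threading
    import time

    # Stand-in for shared.generation_lock; the patch creates the real one
    # in server.py at startup (shared.generation_lock = Lock()).
    generation_lock = threading.Lock()

    def _generate_reply(prompt):
        # Toy streaming generator standing in for the real _generate_reply,
        # which yields partial replies as tokens are produced.
        for token in prompt.split():
            time.sleep(0.1)  # simulate per-token generation latency
            yield token

    def generate_reply(*args, **kwargs):
        # Same shape as the patched wrapper: only one generation runs at a
        # time; concurrent callers block here until the current one ends.
        generation_lock.acquire()
        try:
            for result in _generate_reply(*args, **kwargs):
                yield result
        finally:
            # Runs on normal exhaustion, on an exception, and on
            # GeneratorExit (caller dropped the generator mid-stream),
            # so the lock cannot leak.
            generation_lock.release()

    if __name__ == "__main__":
        # Two "users" requesting generations at once: the second thread's
        # output only starts after the first generator is fully consumed.
        def worker(name):
            for tok in generate_reply(f"{name} says hello world"):
                print(name, tok)

        threads = [threading.Thread(target=worker, args=(f"user{i}",)) for i in range(2)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

Note the design choice this implies: requests are serialized rather than run in parallel, which makes concurrent access to the single loaded model safe at the cost of queueing; that trade-off matches the "beginning of multi-user support" framing of the commit.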