Mirror of https://github.com/oobabooga/text-generation-webui.git
Add /v1/internal/lora endpoints (#4652)
commit 771e62e476
parent ef6feedeb2
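This commit adds admin-only LoRA management routes to the OpenAI-compatible API: GET /v1/internal/lora/list, POST /v1/internal/lora/load, and POST /v1/internal/lora/unload, plus GET /v1/internal/model/list. The diff below touches extensions/openai/models.py, extensions/openai/script.py, and extensions/openai/typing.py. For orientation, a minimal client sketch against the new routes; the base URL, admin key, and LoRA name are placeholders, and this assumes the webui is running with the API enabled (port 5000 by default) and that an admin key, when configured, is sent as a Bearer token:

import requests

BASE = "http://127.0.0.1:5000"                      # assumed default API port
HEADERS = {"Authorization": "Bearer my-admin-key"}  # placeholder admin key

# List the LoRAs available on disk
print(requests.get(f"{BASE}/v1/internal/lora/list", headers=HEADERS).json())
# -> {"lora_names": [...]}

# Apply one or more LoRAs to the currently loaded model
r = requests.post(f"{BASE}/v1/internal/lora/load", headers=HEADERS,
                  json={"lora_names": ["my-lora"]})  # placeholder LoRA name
print(r.json())  # the handler returns the bare string "OK" on success

# Remove all applied LoRAs
requests.post(f"{BASE}/v1/internal/lora/unload", headers=HEADERS)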
--- a/extensions/openai/models.py
+++ b/extensions/openai/models.py
@@ -1,8 +1,9 @@
 from modules import shared
 from modules.logging_colors import logger
+from modules.LoRA import add_lora_to_model
 from modules.models import load_model, unload_model
 from modules.models_settings import get_model_metadata, update_model_parameters
-from modules.utils import get_available_models
+from modules.utils import get_available_loras, get_available_models
 
 
 def get_current_model_info():
@@ -13,12 +14,17 @@ def get_current_model_info():
 
 
 def list_models():
+    return {'model_names': get_available_models()[1:]}
+
+
+def list_dummy_models():
     result = {
         "object": "list",
         "data": []
     }
 
-    for model in get_dummy_models() + get_available_models()[1:]:
+    # these are expected by so much, so include some here as a dummy
+    for model in ['gpt-3.5-turbo', 'text-embedding-ada-002']:
         result["data"].append(model_info_dict(model))
 
     return result
@@ -33,13 +39,6 @@ def model_info_dict(model_name: str) -> dict:
     }
 
 
-def get_dummy_models() -> list:
-    return [ # these are expected by so much, so include some here as a dummy
-        'gpt-3.5-turbo',
-        'text-embedding-ada-002',
-    ]
-
-
 def _load_model(data):
     model_name = data["model_name"]
     args = data["args"]
@@ -67,3 +66,15 @@ def _load_model(data):
                 logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
             elif k == 'instruction_template':
                 logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
+
+
+def list_loras():
+    return {'lora_names': get_available_loras()[1:]}
+
+
+def load_loras(lora_names):
+    add_lora_to_model(lora_names)
+
+
+def unload_all_loras():
+    add_lora_to_model([])
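The models.py change splits the old list_models() in two: list_models() now returns the real local model names for the new internal endpoint, while list_dummy_models() keeps serving stub OpenAI model ids on /v1/models for clients that hard-code them. The [1:] slices appear to skip the leading placeholder entry that get_available_models() and get_available_loras() return first. Roughly, the two response shapes; the model names are placeholders, and the per-model fields assume model_info_dict() (whose body this diff does not show) returns the usual OpenAI-style entry:

# list_models() -> GET /v1/internal/model/list (names are placeholders)
internal_models = {"model_names": ["MyModel-7B-GPTQ", "another-local-model"]}

# list_dummy_models() -> GET /v1/models; per-model fields assume the usual
# OpenAI-style entry from model_info_dict(), whose body this diff omits
dummy_models = {
    "object": "list",
    "data": [
        {"id": "gpt-3.5-turbo", "object": "model", "created": 0, "owned_by": "user"},
        {"id": "text-embedding-ada-002", "object": "model", "created": 0, "owned_by": "user"},
    ],
}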
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -38,10 +38,13 @@ from .typing import (
     EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
+    LoadLorasRequest,
     LoadModelRequest,
     LogitsRequest,
     LogitsResponse,
+    LoraListResponse,
     ModelInfoResponse,
+    ModelListResponse,
     TokenCountResponse,
     to_dict
 )
@@ -141,7 +144,7 @@ async def handle_models(request: Request):
     is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
 
     if is_list:
-        response = OAImodels.list_models()
+        response = OAImodels.list_dummy_models()
     else:
         model_name = path[len('/v1/models/'):]
         response = OAImodels.model_info_dict(model_name)
@@ -267,6 +270,12 @@ async def handle_model_info():
     return JSONResponse(content=payload)
 
 
+@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
+async def handle_list_models():
+    payload = OAImodels.list_models()
+    return JSONResponse(content=payload)
+
+
 @app.post("/v1/internal/model/load", dependencies=check_admin_key)
 async def handle_load_model(request_data: LoadModelRequest):
     '''
@@ -307,6 +316,27 @@ async def handle_load_model(request_data: LoadModelRequest):
 @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
 async def handle_unload_model():
     unload_model()
 
+
+@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
+async def handle_list_loras():
+    response = OAImodels.list_loras()
+    return JSONResponse(content=response)
+
+
+@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
+async def handle_load_loras(request_data: LoadLorasRequest):
+    try:
+        OAImodels.load_loras(request_data.lora_names)
+        return JSONResponse(content="OK")
+    except:
+        traceback.print_exc()
+        return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
+
+
+@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
+async def handle_unload_loras():
+    OAImodels.unload_all_loras()
+
     return JSONResponse(content="OK")
 
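One behavioral detail in script.py worth noting: on failure, handle_load_loras returns the HTTPException rather than raising it, so the error likely still comes back with a 200 status and an error-shaped body instead of a real 400. A defensive client can therefore check the body, not just the status code (same placeholder base URL and key as above):

import requests

BASE = "http://127.0.0.1:5000"                      # placeholder, as above
HEADERS = {"Authorization": "Bearer my-admin-key"}  # placeholder, as above

r = requests.post(f"{BASE}/v1/internal/lora/load", headers=HEADERS,
                  json={"lora_names": ["nonexistent-lora"]})  # placeholder
body = r.json()
if body != "OK":  # success is the bare string "OK"
    print("LoRA load failed:", r.status_code, body)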
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -122,6 +122,19 @@ class ChatCompletionResponse(BaseModel):
     usage: dict
 
 
+class EmbeddingsRequest(BaseModel):
+    input: str | List[str]
+    model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
+    encoding_format: str = Field(default="float", description="Can be float or base64.")
+    user: str | None = Field(default=None, description="Unused parameter.")
+
+
+class EmbeddingsResponse(BaseModel):
+    index: int
+    embedding: List[float]
+    object: str = "embedding"
+
+
 class EncodeRequest(BaseModel):
     text: str
 
@@ -166,23 +179,22 @@ class ModelInfoResponse(BaseModel):
     lora_names: List[str]
 
 
+class ModelListResponse(BaseModel):
+    model_names: List[str]
+
+
 class LoadModelRequest(BaseModel):
     model_name: str
     args: dict | None = None
     settings: dict | None = None
 
 
-class EmbeddingsRequest(BaseModel):
-    input: str | List[str]
-    model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
-    encoding_format: str = Field(default="float", description="Can be float or base64.")
-    user: str | None = Field(default=None, description="Unused parameter.")
+class LoraListResponse(BaseModel):
+    lora_names: List[str]
 
 
-class EmbeddingsResponse(BaseModel):
-    index: int
-    embedding: List[float]
-    object: str = "embedding"
+class LoadLorasRequest(BaseModel):
+    lora_names: List[str]
 
 
 def to_json(obj):
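The new typing.py classes are ordinary Pydantic models, so the request body for POST /v1/internal/lora/load can be validated client-side with the same two lines. A standalone sketch (Pydantic v2 method names assumed):

from typing import List

from pydantic import BaseModel


class LoadLorasRequest(BaseModel):
    lora_names: List[str]


req = LoadLorasRequest(lora_names=["my-lora"])  # placeholder name
print(req.model_dump())  # {'lora_names': ['my-lora']}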