mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
Add /v1/internal/lora endpoints (#4652)
This commit is contained in:
parent
ef6feedeb2
commit
771e62e476
@ -1,8 +1,9 @@
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.utils import get_available_models
|
||||
from modules.utils import get_available_loras, get_available_models
|
||||
|
||||
|
||||
def get_current_model_info():
|
||||
@ -13,12 +14,17 @@ def get_current_model_info():
|
||||
|
||||
|
||||
def list_models():
|
||||
return {'model_names': get_available_models()[1:]}
|
||||
|
||||
|
||||
def list_dummy_models():
|
||||
result = {
|
||||
"object": "list",
|
||||
"data": []
|
||||
}
|
||||
|
||||
for model in get_dummy_models() + get_available_models()[1:]:
|
||||
# these are expected by so much, so include some here as a dummy
|
||||
for model in ['gpt-3.5-turbo', 'text-embedding-ada-002']:
|
||||
result["data"].append(model_info_dict(model))
|
||||
|
||||
return result
|
||||
@ -33,13 +39,6 @@ def model_info_dict(model_name: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def get_dummy_models() -> list:
|
||||
return [ # these are expected by so much, so include some here as a dummy
|
||||
'gpt-3.5-turbo',
|
||||
'text-embedding-ada-002',
|
||||
]
|
||||
|
||||
|
||||
def _load_model(data):
|
||||
model_name = data["model_name"]
|
||||
args = data["args"]
|
||||
@ -67,3 +66,15 @@ def _load_model(data):
|
||||
logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
|
||||
elif k == 'instruction_template':
|
||||
logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
|
||||
|
||||
|
||||
def list_loras():
|
||||
return {'lora_names': get_available_loras()[1:]}
|
||||
|
||||
|
||||
def load_loras(lora_names):
|
||||
add_lora_to_model(lora_names)
|
||||
|
||||
|
||||
def unload_all_loras():
|
||||
add_lora_to_model([])
|
||||
|
@ -38,10 +38,13 @@ from .typing import (
|
||||
EmbeddingsResponse,
|
||||
EncodeRequest,
|
||||
EncodeResponse,
|
||||
LoadLorasRequest,
|
||||
LoadModelRequest,
|
||||
LogitsRequest,
|
||||
LogitsResponse,
|
||||
LoraListResponse,
|
||||
ModelInfoResponse,
|
||||
ModelListResponse,
|
||||
TokenCountResponse,
|
||||
to_dict
|
||||
)
|
||||
@ -141,7 +144,7 @@ async def handle_models(request: Request):
|
||||
is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
|
||||
|
||||
if is_list:
|
||||
response = OAImodels.list_models()
|
||||
response = OAImodels.list_dummy_models()
|
||||
else:
|
||||
model_name = path[len('/v1/models/'):]
|
||||
response = OAImodels.model_info_dict(model_name)
|
||||
@ -267,6 +270,12 @@ async def handle_model_info():
|
||||
return JSONResponse(content=payload)
|
||||
|
||||
|
||||
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
|
||||
async def handle_list_models():
|
||||
payload = OAImodels.list_models()
|
||||
return JSONResponse(content=payload)
|
||||
|
||||
|
||||
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
|
||||
async def handle_load_model(request_data: LoadModelRequest):
|
||||
'''
|
||||
@ -307,6 +316,27 @@ async def handle_load_model(request_data: LoadModelRequest):
|
||||
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
||||
async def handle_unload_model():
|
||||
unload_model()
|
||||
|
||||
|
||||
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
|
||||
async def handle_list_loras():
|
||||
response = OAImodels.list_loras()
|
||||
return JSONResponse(content=response)
|
||||
|
||||
|
||||
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
|
||||
async def handle_load_loras(request_data: LoadLorasRequest):
|
||||
try:
|
||||
OAImodels.load_loras(request_data.lora_names)
|
||||
return JSONResponse(content="OK")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
|
||||
|
||||
|
||||
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
|
||||
async def handle_unload_loras():
|
||||
OAImodels.unload_all_loras()
|
||||
return JSONResponse(content="OK")
|
||||
|
||||
|
||||
|
@ -122,6 +122,19 @@ class ChatCompletionResponse(BaseModel):
|
||||
usage: dict
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
input: str | List[str]
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
|
||||
encoding_format: str = Field(default="float", description="Can be float or base64.")
|
||||
user: str | None = Field(default=None, description="Unused parameter.")
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
index: int
|
||||
embedding: List[float]
|
||||
object: str = "embedding"
|
||||
|
||||
|
||||
class EncodeRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
@ -166,23 +179,22 @@ class ModelInfoResponse(BaseModel):
|
||||
lora_names: List[str]
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
model_names: List[str]
|
||||
|
||||
|
||||
class LoadModelRequest(BaseModel):
|
||||
model_name: str
|
||||
args: dict | None = None
|
||||
settings: dict | None = None
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
input: str | List[str]
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
|
||||
encoding_format: str = Field(default="float", description="Can be float or base64.")
|
||||
user: str | None = Field(default=None, description="Unused parameter.")
|
||||
class LoraListResponse(BaseModel):
|
||||
lora_names: List[str]
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
index: int
|
||||
embedding: List[float]
|
||||
object: str = "embedding"
|
||||
class LoadLorasRequest(BaseModel):
|
||||
lora_names: List[str]
|
||||
|
||||
|
||||
def to_json(obj):
|
||||
|
Loading…
Reference in New Issue
Block a user