Add /v1/internal/logits endpoint (#4650)

oobabooga 2023-11-18 23:19:31 -03:00 committed by GitHub
parent 8f4f4daf8b
commit 0fa1af296c
4 changed files with 71 additions and 9 deletions

@@ -97,6 +97,29 @@ curl http://127.0.0.1:5000/v1/chat/completions \
}'
```
#### Logits
```
curl -k http://127.0.0.1:5000/v1/internal/logits \
-H "Content-Type: application/json" \
-d '{
"prompt": "Who is best, Asuka or Rei? Answer:",
"use_samplers": false
}'
```
#### Logits after sampling parameters
```
curl -k http://127.0.0.1:5000/v1/internal/logits \
-H "Content-Type: application/json" \
-d '{
"prompt": "Who is best, Asuka or Rei? Answer:",
"use_samplers": true,
"top_k": 3
}'
```
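#### Python logits example

The same request can be sent from Python; a minimal sketch, assuming the `requests` package and the default server address:

```python
import requests

url = "http://127.0.0.1:5000/v1/internal/logits"
payload = {
    "prompt": "Who is best, Asuka or Rei? Answer:",
    "use_samplers": True,
    "top_k": 3
}

# The response body is a JSON object with a single "logits" key mapping
# the top tokens to their probabilities.
response = requests.post(url, json=payload)
for token, prob in response.json()["logits"].items():
    print(f"{prob} - {token!r}")
```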
#### Python chat example
```python

@@ -16,6 +16,7 @@ from sse_starlette import EventSourceResponse
import extensions.openai.completions as OAIcompletions
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
@@ -38,6 +39,8 @@ from .typing import (
    EncodeRequest,
    EncodeResponse,
    LoadModelRequest,
    LogitsRequest,
    LogitsResponse,
    ModelInfoResponse,
    TokenCountResponse,
    to_dict
@@ -242,6 +245,16 @@ async def handle_token_count(request_data: EncodeRequest):
    return JSONResponse(response)


@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
async def handle_logits(request_data: LogitsRequest):
    '''
    Given a prompt, returns the 50 most likely next tokens as a dict.
    The keys are the tokens and the values are their probabilities.
    '''

    response = OAIlogits._get_next_logits(to_dict(request_data))
    return JSONResponse(response)


@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
    stop_everything_event()
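For reference, the response body follows the LogitsResponse schema defined in typing.py: a single `logits` key whose value is the token-to-probability mapping built by `get_next_logits`. An illustrative sketch (tokens and values are placeholders; real output depends on the loaded model):

```python
# Illustrative response shape only -- not real model output
{
    "logits": {
        " Rei": 0.53,
        " Asuka": 0.22,
        " I": 0.08
    }
}
```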

@@ -126,15 +126,15 @@ class EncodeRequest(BaseModel):
    text: str


class EncodeResponse(BaseModel):
    tokens: List[int]
    length: int


class DecodeRequest(BaseModel):
    tokens: List[int]


class DecodeResponse(BaseModel):
    text: str
@@ -143,6 +143,24 @@ class TokenCountResponse(BaseModel):
    length: int


class LogitsRequestParams(BaseModel):
    prompt: str
    use_samplers: bool = False
    frequency_penalty: float | None = 0
    max_tokens: int | None = 16
    presence_penalty: float | None = 0
    temperature: float | None = 1
    top_p: float | None = 1


class LogitsRequest(GenerationOptions, LogitsRequestParams):
    pass


class LogitsResponse(BaseModel):
    logits: dict


class ModelInfoResponse(BaseModel):
    model_name: str
    lora_names: List[str]
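Because LogitsRequest also inherits GenerationOptions, sampling fields such as the `top_k` used in the docs example are accepted alongside the parameters above. A small sketch of how the defaults fill in, assuming these pydantic models are importable as in script.py:

```python
from extensions.openai.typing import LogitsRequest

# Unspecified fields fall back to the defaults declared above
req = LogitsRequest(prompt="Who is best, Asuka or Rei? Answer:", use_samplers=True, top_k=3)
print(req.max_tokens)   # 16
print(req.temperature)  # 1.0
```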

@@ -8,7 +8,7 @@ from modules.text_generation import generate_reply
global_scores = None


def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
    if shared.model is None:
        logger.error("No model is loaded! Select one in the Model tab.")
        return 'Error: No model is loaded! Select one in the Model tab.', previous
@@ -56,6 +56,14 @@ def get_next_logits(prompt, state, use_samplers, previous):
    topk_indices = [i.expand((1, 1)) for i in topk_indices]
    tokens = [shared.tokenizer.decode(i) for i in topk_indices]

    if return_dict:
        output = {}
        for row in list(zip(topk_values, tokens)):
            output[row[1]] = row[0]

        return output
    else:
        output = ''
        for row in list(zip(topk_values, tokens)):
            output += f"{row[0]} - {repr(row[1])}\n"
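The new `return_dict` flag leaves the original plain-text output untouched and adds the dict form that the new endpoint consumes. The two shapes side by side, with placeholder values:

```python
# return_dict=False: newline-separated "probability - repr(token)" rows
"0.53000 - ' Rei'\n0.22000 - ' Asuka'\n"

# return_dict=True: decoded token -> probability
{" Rei": 0.53000, " Asuka": 0.22000}
```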