Mirror of https://github.com/oobabooga/text-generation-webui.git

Commit 1b69694fe9 (parent f6ca9cfcdc)
Add types to the encode/decode/token-count endpoints
extensions/openai/script.py

@@ -27,7 +27,12 @@ from .typing import (
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
+    DecodeRequest,
+    DecodeResponse,
+    EncodeRequest,
+    EncodeResponse,
     ModelInfoResponse,
+    TokenCountResponse,
     to_dict
 )

@@ -206,26 +211,21 @@ async def handle_moderations(request: Request):
     return JSONResponse(response)


-@app.post("/v1/internal/encode")
-async def handle_token_encode(request: Request):
-    body = await request.json()
-    encoding_format = body.get("encoding_format", "")
-    response = token_encode(body["input"], encoding_format)
+@app.post("/v1/internal/encode", response_model=EncodeResponse)
+async def handle_token_encode(request_data: EncodeRequest):
+    response = token_encode(request_data.text)
     return JSONResponse(response)


-@app.post("/v1/internal/decode")
-async def handle_token_decode(request: Request):
-    body = await request.json()
-    encoding_format = body.get("encoding_format", "")
-    response = token_decode(body["input"], encoding_format)
-    return JSONResponse(response, no_debug=True)
+@app.post("/v1/internal/decode", response_model=DecodeResponse)
+async def handle_token_decode(request_data: DecodeRequest):
+    response = token_decode(request_data.tokens)
+    return JSONResponse(response)


-@app.post("/v1/internal/token-count")
-async def handle_token_count(request: Request):
-    body = await request.json()
-    response = token_count(body['prompt'])
+@app.post("/v1/internal/token-count", response_model=TokenCountResponse)
+async def handle_token_count(request_data: EncodeRequest):
+    response = token_count(request_data.text)
     return JSONResponse(response)

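Note: as a usage sketch (not part of the commit), here is how a client would call the retyped endpoints. The requests library and the localhost:5000 base URL are assumptions; the field names come from the EncodeRequest/DecodeRequest models added in typing.py below.

    import requests

    BASE = "http://localhost:5000"  # assumed server address, adjust as needed

    # /v1/internal/encode now expects {"text": ...} (EncodeRequest) instead
    # of the old {"input": ..., "encoding_format": ...} body.
    enc = requests.post(f"{BASE}/v1/internal/encode", json={"text": "Hello world"}).json()
    print(enc["tokens"], enc["length"])

    # /v1/internal/decode now expects {"tokens": [...]} (DecodeRequest).
    dec = requests.post(f"{BASE}/v1/internal/decode", json={"tokens": enc["tokens"]}).json()
    print(dec["text"])

    # /v1/internal/token-count reuses EncodeRequest, so it also takes {"text": ...}.
    cnt = requests.post(f"{BASE}/v1/internal/token-count", json={"text": "Hello world"}).json()
    print(cnt["length"])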
extensions/openai/tokens.py

@@ -3,34 +3,24 @@ from modules.text_generation import decode, encode


 def token_count(prompt):
     tokens = encode(prompt)[0]
-
     return {
-        'results': [{
-            'tokens': len(tokens)
-        }]
+        'length': len(tokens)
     }


-def token_encode(input, encoding_format):
-    # if isinstance(input, list):
+def token_encode(input):
     tokens = encode(input)[0]
+    if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
+        tokens = tokens.tolist()

     return {
-        'results': [{
-            'tokens': tokens,
-            'length': len(tokens),
-        }]
+        'tokens': tokens,
+        'length': len(tokens),
     }


-def token_decode(tokens, encoding_format):
-    # if isinstance(input, list):
-    # if encoding_format == "base64":
-    #     tokens = base64_to_float_list(tokens)
-    output = decode(tokens)[0]
-
+def token_decode(tokens):
+    output = decode(tokens)
     return {
-        'results': [{
-            'text': output
-        }]
+        'text': output
     }
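Note: the new __class__.__name__ check exists because encode() can return a PyTorch Tensor or NumPy ndarray depending on the loader, and neither is JSON-serializable. A standalone sketch of the problem (assumes torch is installed; not code from this repo):

    import json

    import torch

    tokens = torch.tensor([1, 2, 3])
    try:
        json.dumps({'tokens': tokens})
    except TypeError as err:
        print(err)  # Object of type Tensor is not JSON serializable

    # Duck-typed check, as in token_encode above: matching on the class
    # name avoids importing torch/numpy in tokens.py itself.
    if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
        tokens = tokens.tolist()

    print(json.dumps({'tokens': tokens, 'length': len(tokens)}))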
extensions/openai/typing.py

@@ -121,6 +121,27 @@ class ChatCompletionResponse(BaseModel):
     usage: dict


+class EncodeRequest(BaseModel):
+    text: str
+
+
+class DecodeRequest(BaseModel):
+    tokens: List[int]
+
+
+class EncodeResponse(BaseModel):
+    tokens: List[int]
+    length: int
+
+
+class DecodeResponse(BaseModel):
+    text: str
+
+
+class TokenCountResponse(BaseModel):
+    length: int
+
+
 class ModelInfoResponse(BaseModel):
     model_name: str
     lora_names: List[str]
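Note: with these Pydantic models in place, FastAPI validates each request body before the handler runs and rejects malformed input with an automatic 422 response. A standalone sketch of that behavior (illustrative, not part of the commit):

    from typing import List

    from pydantic import BaseModel, ValidationError

    class EncodeRequest(BaseModel):
        text: str

    class DecodeRequest(BaseModel):
        tokens: List[int]

    # A well-formed body parses cleanly...
    print(EncodeRequest(text='Hello'))

    # ...while a malformed one raises ValidationError, which FastAPI
    # converts into a 422 error before the handler is ever called.
    try:
        DecodeRequest(tokens='not a list of ints')
    except ValidationError as err:
        print(err)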
modules/llamacpp_model.py

@@ -101,7 +101,7 @@ class LlamaCppModel:

         return self.model.tokenize(string)

-    def decode(self, ids):
+    def decode(self, ids, **kwargs):
         return self.model.detokenize(ids).decode('utf-8')

     def get_logits(self, tokens):
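Note: the **kwargs addition lets callers pass keyword options such as skip_special_tokens uniformly to every tokenizer; llama.cpp's detokenizer has no such flag, so the extras are accepted and ignored. A toy illustration of the pattern (not the project's real class):

    class KwargsTolerantDecoder:
        def decode(self, ids, **kwargs):
            # Extra keyword options (e.g. skip_special_tokens) are
            # accepted but ignored, mirroring the change above.
            return bytes(ids).decode('utf-8')

    d = KwargsTolerantDecoder()
    print(d.decode([72, 105], skip_special_tokens=True))  # -> 'Hi'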
modules/text_generation.py

@@ -145,7 +145,7 @@ def decode(output_ids, skip_special_tokens=True):
     if shared.tokenizer is None:
         raise ValueError('No tokenizer is loaded')

-    return shared.tokenizer.decode(output_ids, skip_special_tokens)
+    return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)


 def get_encoded_length(prompt):
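Note: this pairs with the llamacpp_model.py change above. For the llama.cpp loader the model object itself serves as shared.tokenizer, and a method defined as decode(self, ids, **kwargs) rejects a positional flag while accepting the same value as a keyword. A minimal demonstration with a toy class (assumed names, not project code):

    class KwargsOnly:
        def decode(self, ids, **kwargs):
            return 'decoded'

    d = KwargsOnly()
    try:
        d.decode([1, 2], True)  # positional flag raises TypeError
    except TypeError as err:
        print(err)  # decode() takes 2 positional arguments but 3 were given

    print(d.decode([1, 2], skip_special_tokens=True))  # keyword flag works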