Add types to the encode/decode/token-count endpoints

This commit is contained in:
oobabooga 2023-11-07 19:05:36 -08:00
parent f6ca9cfcdc
commit 1b69694fe9
5 changed files with 47 additions and 36 deletions

View File

@ -27,7 +27,12 @@ from .typing import (
ChatCompletionResponse, ChatCompletionResponse,
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
DecodeRequest,
DecodeResponse,
EncodeRequest,
EncodeResponse,
ModelInfoResponse, ModelInfoResponse,
TokenCountResponse,
to_dict to_dict
) )
@ -206,26 +211,21 @@ async def handle_moderations(request: Request):
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/internal/encode") @app.post("/v1/internal/encode", response_model=EncodeResponse)
async def handle_token_encode(request: Request): async def handle_token_encode(request_data: EncodeRequest):
body = await request.json() response = token_encode(request_data.text)
encoding_format = body.get("encoding_format", "")
response = token_encode(body["input"], encoding_format)
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/internal/decode") @app.post("/v1/internal/decode", response_model=DecodeResponse)
async def handle_token_decode(request: Request): async def handle_token_decode(request_data: DecodeRequest):
body = await request.json() response = token_decode(request_data.tokens)
encoding_format = body.get("encoding_format", "") return JSONResponse(response)
response = token_decode(body["input"], encoding_format)
return JSONResponse(response, no_debug=True)
@app.post("/v1/internal/token-count") @app.post("/v1/internal/token-count", response_model=TokenCountResponse)
async def handle_token_count(request: Request): async def handle_token_count(request_data: EncodeRequest):
body = await request.json() response = token_count(request_data.text)
response = token_count(body['prompt'])
return JSONResponse(response) return JSONResponse(response)

View File

@ -3,34 +3,24 @@ from modules.text_generation import decode, encode
def token_count(prompt): def token_count(prompt):
tokens = encode(prompt)[0] tokens = encode(prompt)[0]
return { return {
'results': [{ 'length': len(tokens)
'tokens': len(tokens)
}]
} }
def token_encode(input, encoding_format): def token_encode(input):
# if isinstance(input, list):
tokens = encode(input)[0] tokens = encode(input)[0]
if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
tokens = tokens.tolist()
return { return {
'results': [{
'tokens': tokens, 'tokens': tokens,
'length': len(tokens), 'length': len(tokens),
}]
} }
def token_decode(tokens, encoding_format): def token_decode(tokens):
# if isinstance(input, list): output = decode(tokens)
# if encoding_format == "base64":
# tokens = base64_to_float_list(tokens)
output = decode(tokens)[0]
return { return {
'results': [{
'text': output 'text': output
}]
} }

View File

@ -121,6 +121,27 @@ class ChatCompletionResponse(BaseModel):
usage: dict usage: dict
class EncodeRequest(BaseModel):
text: str
class DecodeRequest(BaseModel):
tokens: List[int]
class EncodeResponse(BaseModel):
tokens: List[int]
length: int
class DecodeResponse(BaseModel):
text: str
class TokenCountResponse(BaseModel):
length: int
class ModelInfoResponse(BaseModel): class ModelInfoResponse(BaseModel):
model_name: str model_name: str
lora_names: List[str] lora_names: List[str]

View File

@ -101,7 +101,7 @@ class LlamaCppModel:
return self.model.tokenize(string) return self.model.tokenize(string)
def decode(self, ids): def decode(self, ids, **kwargs):
return self.model.detokenize(ids).decode('utf-8') return self.model.detokenize(ids).decode('utf-8')
def get_logits(self, tokens): def get_logits(self, tokens):

View File

@ -145,7 +145,7 @@ def decode(output_ids, skip_special_tokens=True):
if shared.tokenizer is None: if shared.tokenizer is None:
raise ValueError('No tokenizer is loaded') raise ValueError('No tokenizer is loaded')
return shared.tokenizer.decode(output_ids, skip_special_tokens) return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
def get_encoded_length(prompt): def get_encoded_length(prompt):