mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-22 09:39:08 +01:00
a86ad841f1
Co-authored-by: eugenio.segala <esegala@deloitte.co.uk>
380 lines
12 KiB
Python
380 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# type: ignore[reportUnusedImport]
|
|
|
|
import subprocess
|
|
import os
|
|
import re
|
|
import json
|
|
import sys
|
|
import requests
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
ContextManager,
|
|
Iterable,
|
|
Iterator,
|
|
List,
|
|
Literal,
|
|
Tuple,
|
|
Set,
|
|
)
|
|
from re import RegexFlag
|
|
|
|
|
|
class ServerResponse:
    """Plain attribute holder for one HTTP response from the test server.

    Instances are created empty and filled in attribute by attribute by
    ``ServerProcess.make_request``.
    """

    # response headers, copied into a plain dict
    headers: dict
    # HTTP status code of the response
    status_code: int
    # decoded JSON body, or None when the request method does not parse a body
    body: dict | Any
|
|
|
|
|
|
class ServerProcess:
|
|
# default options
|
|
debug: bool = False
|
|
server_port: int = 8080
|
|
server_host: str = "127.0.0.1"
|
|
model_hf_repo: str = "ggml-org/models"
|
|
model_hf_file: str = "tinyllamas/stories260K.gguf"
|
|
model_alias: str = "tinyllama-2"
|
|
temperature: float = 0.8
|
|
seed: int = 42
|
|
|
|
# custom options
|
|
model_alias: str | None = None
|
|
model_url: str | None = None
|
|
model_file: str | None = None
|
|
model_draft: str | None = None
|
|
n_threads: int | None = None
|
|
n_gpu_layer: int | None = None
|
|
n_batch: int | None = None
|
|
n_ubatch: int | None = None
|
|
n_ctx: int | None = None
|
|
n_ga: int | None = None
|
|
n_ga_w: int | None = None
|
|
n_predict: int | None = None
|
|
n_prompts: int | None = 0
|
|
slot_save_path: str | None = None
|
|
id_slot: int | None = None
|
|
cache_prompt: bool | None = None
|
|
n_slots: int | None = None
|
|
server_continuous_batching: bool | None = False
|
|
server_embeddings: bool | None = False
|
|
server_reranking: bool | None = False
|
|
server_metrics: bool | None = False
|
|
server_slots: bool | None = False
|
|
draft: int | None = None
|
|
api_key: str | None = None
|
|
response_format: str | None = None
|
|
lora_files: List[str] | None = None
|
|
disable_ctx_shift: int | None = False
|
|
draft_min: int | None = None
|
|
draft_max: int | None = None
|
|
no_webui: bool | None = None
|
|
|
|
# session variables
|
|
process: subprocess.Popen | None = None
|
|
|
|
def __init__(self):
|
|
if "N_GPU_LAYERS" in os.environ:
|
|
self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
|
|
if "DEBUG" in os.environ:
|
|
self.debug = True
|
|
if "PORT" in os.environ:
|
|
self.server_port = int(os.environ["PORT"])
|
|
|
|
def start(self, timeout_seconds: int = 10) -> None:
|
|
if "LLAMA_SERVER_BIN_PATH" in os.environ:
|
|
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
|
|
elif os.name == "nt":
|
|
server_path = "../../../build/bin/Release/llama-server.exe"
|
|
else:
|
|
server_path = "../../../build/bin/llama-server"
|
|
server_args = [
|
|
"--host",
|
|
self.server_host,
|
|
"--port",
|
|
self.server_port,
|
|
"--temp",
|
|
self.temperature,
|
|
"--seed",
|
|
self.seed,
|
|
]
|
|
if self.model_file:
|
|
server_args.extend(["--model", self.model_file])
|
|
if self.model_url:
|
|
server_args.extend(["--model-url", self.model_url])
|
|
if self.model_draft:
|
|
server_args.extend(["--model-draft", self.model_draft])
|
|
if self.model_hf_repo:
|
|
server_args.extend(["--hf-repo", self.model_hf_repo])
|
|
if self.model_hf_file:
|
|
server_args.extend(["--hf-file", self.model_hf_file])
|
|
if self.n_batch:
|
|
server_args.extend(["--batch-size", self.n_batch])
|
|
if self.n_ubatch:
|
|
server_args.extend(["--ubatch-size", self.n_ubatch])
|
|
if self.n_threads:
|
|
server_args.extend(["--threads", self.n_threads])
|
|
if self.n_gpu_layer:
|
|
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
|
|
if self.draft is not None:
|
|
server_args.extend(["--draft", self.draft])
|
|
if self.server_continuous_batching:
|
|
server_args.append("--cont-batching")
|
|
if self.server_embeddings:
|
|
server_args.append("--embedding")
|
|
if self.server_reranking:
|
|
server_args.append("--reranking")
|
|
if self.server_metrics:
|
|
server_args.append("--metrics")
|
|
if self.server_slots:
|
|
server_args.append("--slots")
|
|
if self.model_alias:
|
|
server_args.extend(["--alias", self.model_alias])
|
|
if self.n_ctx:
|
|
server_args.extend(["--ctx-size", self.n_ctx])
|
|
if self.n_slots:
|
|
server_args.extend(["--parallel", self.n_slots])
|
|
if self.n_predict:
|
|
server_args.extend(["--n-predict", self.n_predict])
|
|
if self.slot_save_path:
|
|
server_args.extend(["--slot-save-path", self.slot_save_path])
|
|
if self.n_ga:
|
|
server_args.extend(["--grp-attn-n", self.n_ga])
|
|
if self.n_ga_w:
|
|
server_args.extend(["--grp-attn-w", self.n_ga_w])
|
|
if self.debug:
|
|
server_args.append("--verbose")
|
|
if self.lora_files:
|
|
for lora_file in self.lora_files:
|
|
server_args.extend(["--lora", lora_file])
|
|
if self.disable_ctx_shift:
|
|
server_args.extend(["--no-context-shift"])
|
|
if self.api_key:
|
|
server_args.extend(["--api-key", self.api_key])
|
|
if self.draft_max:
|
|
server_args.extend(["--draft-max", self.draft_max])
|
|
if self.draft_min:
|
|
server_args.extend(["--draft-min", self.draft_min])
|
|
if self.no_webui:
|
|
server_args.append("--no-webui")
|
|
|
|
args = [str(arg) for arg in [server_path, *server_args]]
|
|
print(f"bench: starting server with: {' '.join(args)}")
|
|
|
|
flags = 0
|
|
if "nt" == os.name:
|
|
flags |= subprocess.DETACHED_PROCESS
|
|
flags |= subprocess.CREATE_NEW_PROCESS_GROUP
|
|
flags |= subprocess.CREATE_NO_WINDOW
|
|
|
|
self.process = subprocess.Popen(
|
|
[str(arg) for arg in [server_path, *server_args]],
|
|
creationflags=flags,
|
|
stdout=sys.stdout,
|
|
stderr=sys.stdout,
|
|
env={**os.environ, "LLAMA_CACHE": "tmp"},
|
|
)
|
|
server_instances.add(self)
|
|
|
|
print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
|
|
|
|
# wait for server to start
|
|
start_time = time.time()
|
|
while time.time() - start_time < timeout_seconds:
|
|
try:
|
|
response = self.make_request("GET", "/health", headers={
|
|
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
|
|
})
|
|
if response.status_code == 200:
|
|
self.ready = True
|
|
return # server is ready
|
|
except Exception as e:
|
|
pass
|
|
print(f"Waiting for server to start...")
|
|
time.sleep(0.5)
|
|
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
|
|
|
|
def stop(self) -> None:
|
|
if self in server_instances:
|
|
server_instances.remove(self)
|
|
if self.process:
|
|
print(f"Stopping server with pid={self.process.pid}")
|
|
self.process.kill()
|
|
self.process = None
|
|
|
|
def make_request(
|
|
self,
|
|
method: str,
|
|
path: str,
|
|
data: dict | Any | None = None,
|
|
headers: dict | None = None,
|
|
) -> ServerResponse:
|
|
url = f"http://{self.server_host}:{self.server_port}{path}"
|
|
parse_body = False
|
|
if method == "GET":
|
|
response = requests.get(url, headers=headers)
|
|
parse_body = True
|
|
elif method == "POST":
|
|
response = requests.post(url, headers=headers, json=data)
|
|
parse_body = True
|
|
elif method == "OPTIONS":
|
|
response = requests.options(url, headers=headers)
|
|
else:
|
|
raise ValueError(f"Unimplemented method: {method}")
|
|
result = ServerResponse()
|
|
result.headers = dict(response.headers)
|
|
result.status_code = response.status_code
|
|
result.body = response.json() if parse_body else None
|
|
print("Response from server", json.dumps(result.body, indent=2))
|
|
return result
|
|
|
|
def make_stream_request(
|
|
self,
|
|
method: str,
|
|
path: str,
|
|
data: dict | None = None,
|
|
headers: dict | None = None,
|
|
) -> Iterator[dict]:
|
|
url = f"http://{self.server_host}:{self.server_port}{path}"
|
|
if method == "POST":
|
|
response = requests.post(url, headers=headers, json=data, stream=True)
|
|
else:
|
|
raise ValueError(f"Unimplemented method: {method}")
|
|
for line_bytes in response.iter_lines():
|
|
line = line_bytes.decode("utf-8")
|
|
if '[DONE]' in line:
|
|
break
|
|
elif line.startswith('data: '):
|
|
data = json.loads(line[6:])
|
|
print("Partial response from server", json.dumps(data, indent=2))
|
|
yield data
|
|
|
|
|
|
# Registry of ServerProcess objects: ServerProcess.start() adds the instance
# and ServerProcess.stop() removes it again.
server_instances: Set[ServerProcess] = set()
|
|
|
|
|
|
class ServerPreset:
    """Factories that return a configured ServerProcess for a known model.

    Each factory only constructs the object and assigns option attributes;
    it does not launch anything.
    """

    @staticmethod
    def _configured(**options: Any) -> ServerProcess:
        # Create a fresh ServerProcess and assign every given option on it.
        instance = ServerProcess()
        for attribute, value in options.items():
            setattr(instance, attribute, value)
        return instance

    @staticmethod
    def tinyllama2() -> ServerProcess:
        """Preset for tinyllamas/stories260K.gguf."""
        return ServerPreset._configured(
            model_hf_repo="ggml-org/models",
            model_hf_file="tinyllamas/stories260K.gguf",
            model_alias="tinyllama-2",
            n_ctx=256,
            n_batch=32,
            n_slots=2,
            n_predict=64,
            seed=42,
        )

    @staticmethod
    def bert_bge_small() -> ServerProcess:
        """Preset for bert-bge-small with ``server_embeddings`` enabled."""
        return ServerPreset._configured(
            model_hf_repo="ggml-org/models",
            model_hf_file="bert-bge-small/ggml-model-f16.gguf",
            model_alias="bert-bge-small",
            n_ctx=512,
            n_batch=128,
            n_ubatch=128,
            n_slots=2,
            seed=42,
            server_embeddings=True,
        )

    @staticmethod
    def tinyllama_infill() -> ServerProcess:
        """Preset for tinyllamas/stories260K-infill.gguf at temperature 0."""
        return ServerPreset._configured(
            model_hf_repo="ggml-org/models",
            model_hf_file="tinyllamas/stories260K-infill.gguf",
            model_alias="tinyllama-infill",
            n_ctx=2048,
            n_batch=1024,
            n_slots=1,
            n_predict=64,
            temperature=0.0,
            seed=42,
        )

    @staticmethod
    def stories15m_moe() -> ServerProcess:
        """Preset for stories15M_MOE-F16.gguf at temperature 0."""
        return ServerPreset._configured(
            model_hf_repo="ggml-org/stories15M_MOE",
            model_hf_file="stories15M_MOE-F16.gguf",
            model_alias="stories15m-moe",
            n_ctx=2048,
            n_batch=1024,
            n_slots=1,
            n_predict=64,
            temperature=0.0,
            seed=42,
        )

    @staticmethod
    def jina_reranker_tiny() -> ServerProcess:
        """Preset for jina-reranker-v1-tiny-en with ``server_reranking`` enabled."""
        return ServerPreset._configured(
            model_hf_repo="ggml-org/models",
            model_hf_file="jina-reranker-v1-tiny-en/ggml-model-f16.gguf",
            model_alias="jina-reranker",
            n_ctx=512,
            n_batch=512,
            n_slots=1,
            seed=42,
            server_reranking=True,
        )
|
|
|
|
|
|
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
    """
    Invoke every ``(function, args)`` pair on a thread pool and collect the return values.

    The returned list has one slot per entry of ``function_list``, in the same
    order as the calls. A call that raises an ``Exception`` leaves ``None`` in
    its slot; its message is printed after all calls have finished and nothing
    is re-raised.

    Example usage:

    results = parallel_function_calls([
        (func1, (arg1, arg2)),
        (func2, (arg3, arg4)),
    ])
    """
    outcomes: List[Any] = [None] * len(function_list)
    failures: List[Tuple[int, str]] = []

    def run_one(slot, fn, fn_args):
        # Store the return value in its slot, or record the failure message.
        try:
            outcomes[slot] = fn(*fn_args)
        except Exception as err:
            failures.append((slot, str(err)))

    with ThreadPoolExecutor() as pool:
        pending = [
            pool.submit(run_one, slot, fn, fn_args)
            for slot, (fn, fn_args) in enumerate(function_list)
        ]
        # Block until every submitted call has finished
        for _ in as_completed(pending):
            pass

    # Report the failures that were collected, if any
    if failures:
        print("Exceptions occurred:")
        for slot, message in failures:
            print(f"Function at index {slot}: {message}")

    return outcomes
|
|
|
|
|
|
def match_regex(regex: str, text: str) -> bool:
    """Return True when ``regex`` matches anywhere in ``text``.

    The search is case-insensitive, ``^``/``$`` match at line boundaries and
    ``.`` also matches newlines.
    """
    search_flags = RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
    found = re.search(regex, text, flags=search_flags)
    return found is not None
|
|
|
|
def is_slow_test_allowed():
    """Return True when the SLOW_TESTS environment variable is exactly "1" or "ON"."""
    setting = os.environ.get("SLOW_TESTS")
    return setting in ("1", "ON")
|