import json
import requests

HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/finetune'


def finetune(lora_name, raw_text_file):
    """Kick off a LoRA fine-tuning run via the blocking API.

    Parameters
    ----------
    lora_name : str
        Name under which the trained LoRA will be saved on the server.
    raw_text_file : str
        Name of the raw text file (already present in the server's
        datasets folder) to train on.

    Returns
    -------
    requests.Response
        The raw HTTP response from the server, so callers can inspect
        status and body.
    """
    request = {
        'lora_name': lora_name,
        'raw_text_file': raw_text_file,
    }

    response = requests.post(URI, json=request)
    # Fix: the original discarded the response entirely, leaving the user
    # with no indication of whether the request was even accepted.
    print(f'HTTP {response.status_code}: {response.text}')
    return response


if __name__ == '__main__':
    lora_name = 'lora'
    raw_text_file = 'input'
    finetune(lora_name, raw_text_file)
def build_parameters_train(body):
    """Translate a JSON request body into keyword arguments for do_train.

    Every value is coerced to the type the training code expects; missing
    keys fall back to the same defaults the training UI uses.

    Parameters
    ----------
    body : dict
        Parsed JSON body of a POST to /api/v1/finetune.

    Returns
    -------
    dict
        Keyword arguments ready to be splatted into do_train(**params).
    """
    # NOTE: 'None' (the string) is the sentinel the training code treats as
    # "no selection". The original wrote str(body.get(key, None)), which
    # only produced 'None' by the accident that str(None) == 'None'; the
    # sentinel is now an explicit default. Output is identical.
    generate_params = {
        'lora_name': str(body.get('lora_name', 'lora')),
        'always_override': bool(body.get('always_override', False)),
        'save_steps': int(body.get('save_steps', 0)),
        'micro_batch_size': int(body.get('micro_batch_size', 4)),
        'batch_size': int(body.get('batch_size', 128)),
        'epochs': int(body.get('epochs', 3)),
        'learning_rate': float(body.get('learning_rate', 3e-4)),
        'lr_scheduler_type': str(body.get('lr_scheduler_type', 'linear')),
        'lora_rank': int(body.get('lora_rank', 32)),
        'lora_alpha': int(body.get('lora_alpha', 64)),
        'lora_dropout': float(body.get('lora_dropout', 0.05)),
        'cutoff_len': int(body.get('cutoff_len', 256)),
        'dataset': str(body.get('dataset', 'None')),
        'eval_dataset': str(body.get('eval_dataset', 'None')),
        'format': str(body.get('format', 'None')),
        'eval_steps': int(body.get('eval_steps', 100)),
        'raw_text_file': str(body.get('raw_text_file', 'None')),
        'overlap_len': int(body.get('overlap_len', 128)),
        'newline_favor_len': int(body.get('newline_favor_len', 128)),
        'higher_rank_limit': bool(body.get('higher_rank_limit', False)),
        'warmup_steps': int(body.get('warmup_steps', 100)),
        'optimizer': str(body.get('optimizer', 'adamw_torch')),
        'hard_cut_string': str(body.get('hard_cut_string', '\\n\\n\\n')),
        'train_only_after': str(body.get('train_only_after', ''))
    }

    return generate_params