add api endpoint for finetuning

This commit is contained in:
tohrnii 2023-05-25 22:42:32 +01:00
parent acfd876f29
commit 0a17565e53
3 changed files with 61 additions and 2 deletions

18
api-example-finetune.py Normal file
View File

@@ -0,0 +1,18 @@
import json
import requests
# host:port where the text-generation-webui API extension is listening
HOST = 'localhost:5000'
# finetuning endpoint added by this commit to the blocking API
URI = f'http://{HOST}/api/v1/finetune'
def finetune(lora_name, raw_text_file, timeout=None):
    """Ask the API server to finetune a LoRA from a raw text file.

    Args:
        lora_name: name the server will save the trained LoRA under.
        raw_text_file: name of the raw-text training file known to the server.
        timeout: optional seconds to wait for the HTTP response (None = wait forever).

    Returns:
        The ``requests.Response`` object, so callers can inspect status/body.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    request = {
        'lora_name': lora_name,
        'raw_text_file': raw_text_file,
    }
    response = requests.post(URI, json=request, timeout=timeout)
    # Previously the response was discarded, so failures were invisible;
    # surface HTTP errors and hand the response back to the caller.
    response.raise_for_status()
    return response
if __name__ == '__main__':
    # Example run: train a LoRA called 'lora' from the server-side
    # raw text file named 'input'.
    finetune('lora', 'input')

View File

@@ -2,11 +2,11 @@ import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
-from extensions.api.util import build_parameters, try_start_cloudflared
+from extensions.api.util import build_parameters, build_parameters_train, try_start_cloudflared
 from modules import shared
 from modules.chat import generate_chat_reply
 from modules.text_generation import encode, generate_reply
from modules.training import do_train
class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
@@ -78,6 +78,18 @@ class Handler(BaseHTTPRequestHandler):
            self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/finetune':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
generate_params = build_parameters_train(body)
while True:
try:
print(next(do_train(**generate_params)))
except StopIteration:
break
        elif self.path == '/api/v1/token-count':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')

View File

@@ -61,6 +61,35 @@ def build_parameters(body, chat=False):
    return generate_params
def build_parameters_train(body):
    """Build the keyword-argument dict for ``do_train`` from a request body.

    Every value is pulled out of *body* with a default and coerced to the
    type the trainer expects, so malformed/partial JSON still yields a
    fully-populated, consistently-typed parameter dict.

    Args:
        body: decoded JSON request body (a dict).

    Returns:
        Dict of training parameters, one entry per supported option.
    """
    # (key, converter, default) table — one row per training option.
    # NOTE(review): for 'dataset', 'eval_dataset', 'format' and
    # 'raw_text_file' a missing key becomes str(None) == 'None'; presumably
    # the trainer uses the string 'None' as its "not set" sentinel — confirm.
    spec = [
        ('lora_name', str, 'lora'),
        ('always_override', bool, False),
        ('save_steps', int, 0),
        ('micro_batch_size', int, 4),
        ('batch_size', int, 128),
        ('epochs', int, 3),
        ('learning_rate', float, 3e-4),
        ('lr_scheduler_type', str, 'linear'),
        ('lora_rank', int, 32),
        ('lora_alpha', int, 64),
        ('lora_dropout', float, 0.05),
        ('cutoff_len', int, 256),
        ('dataset', str, None),
        ('eval_dataset', str, None),
        ('format', str, None),
        ('eval_steps', int, 100),
        ('raw_text_file', str, None),
        ('overlap_len', int, 128),
        ('newline_favor_len', int, 128),
        ('higher_rank_limit', bool, False),
        ('warmup_steps', int, 100),
        ('optimizer', str, 'adamw_torch'),
        ('hard_cut_string', str, '\\n\\n\\n'),
        ('train_only_after', str, ''),
    ]
    return {key: cast(body.get(key, default)) for key, cast, default in spec}
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
    Thread(target=_start_cloudflared, args=[