mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
add api endpoint for finetuning
This commit is contained in:
parent
acfd876f29
commit
0a17565e53
18
api-example-finetune.py
Normal file
18
api-example-finetune.py
Normal file
@ -0,0 +1,18 @@
|
||||
import json
|
||||
import requests
|
||||
|
||||
# Address (host:port) of the local text-generation-webui API server.
HOST = 'localhost:5000'
# Finetune endpoint introduced by this commit on the blocking API.
URI = f'http://{HOST}/api/v1/finetune'
|
||||
|
||||
def finetune(lora_name, raw_text_file):
    """Request a LoRA finetuning run from the API server.

    Args:
        lora_name: Name the resulting LoRA will be saved under.
        raw_text_file: Name of the raw-text training file on the server side.

    Returns:
        The ``requests.Response`` from the finetune endpoint, so callers can
        inspect the status code (the original discarded it and returned None;
        returning it is backward-compatible).
    """
    request = {
        'lora_name': lora_name,
        'raw_text_file': raw_text_file,
    }

    # NOTE(review): no timeout is set — a long training run will block this
    # call until the server finishes; confirm that is acceptable for callers.
    response = requests.post(URI, json=request)
    return response
|
||||
|
||||
if __name__ == '__main__':
    # Demo invocation: train a LoRA named 'lora' from the 'input' text file.
    finetune('lora', 'input')
|
@ -2,11 +2,11 @@ import json
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from extensions.api.util import build_parameters, build_parameters_train, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
from modules.training import do_train
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
@ -78,6 +78,18 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/finetune':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
generate_params = build_parameters_train(body)
|
||||
while True:
|
||||
try:
|
||||
print(next(do_train(**generate_params)))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
|
@ -61,6 +61,35 @@ def build_parameters(body, chat=False):
|
||||
|
||||
return generate_params
|
||||
|
||||
def build_parameters_train(body):
    """Translate a finetune request body into keyword arguments for do_train().

    Each field falls back to its UI default when absent and is coerced to the
    expected type. String fields whose default is None deliberately become the
    literal string 'None' — the sentinel value the training code checks for.
    """
    # (key, caster, default) — one row per do_train() keyword argument.
    spec = [
        ('lora_name', str, 'lora'),
        ('always_override', bool, False),
        ('save_steps', int, 0),
        ('micro_batch_size', int, 4),
        ('batch_size', int, 128),
        ('epochs', int, 3),
        ('learning_rate', float, 3e-4),
        ('lr_scheduler_type', str, 'linear'),
        ('lora_rank', int, 32),
        ('lora_alpha', int, 64),
        ('lora_dropout', float, 0.05),
        ('cutoff_len', int, 256),
        ('dataset', str, None),
        ('eval_dataset', str, None),
        ('format', str, None),
        ('eval_steps', int, 100),
        ('raw_text_file', str, None),
        ('overlap_len', int, 128),
        ('newline_favor_len', int, 128),
        ('higher_rank_limit', bool, False),
        ('warmup_steps', int, 100),
        ('optimizer', str, 'adamw_torch'),
        ('hard_cut_string', str, '\\n\\n\\n'),
        ('train_only_after', str, ''),
    ]

    return {key: cast(body.get(key, default)) for key, cast, default in spec}
|
||||
|
||||
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||
Thread(target=_start_cloudflared, args=[
|
||||
|
Loading…
Reference in New Issue
Block a user