mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
add api endpoint for finetuning
This commit is contained in:
parent
acfd876f29
commit
0a17565e53
18
api-example-finetune.py
Normal file
18
api-example-finetune.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
HOST = 'localhost:5000'
|
||||||
|
URI = f'http://{HOST}/api/v1/finetune'
|
||||||
|
|
||||||
|
def finetune(lora_name, raw_text_file):
|
||||||
|
request = {
|
||||||
|
'lora_name': lora_name,
|
||||||
|
'raw_text_file': raw_text_file,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(URI, json=request)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
lora_name = 'lora'
|
||||||
|
raw_text_file = 'input'
|
||||||
|
finetune(lora_name, raw_text_file)
|
@ -2,11 +2,11 @@ import json
|
|||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
from extensions.api.util import build_parameters, build_parameters_train, try_start_cloudflared
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.chat import generate_chat_reply
|
from modules.chat import generate_chat_reply
|
||||||
from modules.text_generation import encode, generate_reply
|
from modules.text_generation import encode, generate_reply
|
||||||
|
from modules.training import do_train
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
@ -78,6 +78,18 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
elif self.path == '/api/v1/finetune':
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
generate_params = build_parameters_train(body)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print(next(do_train(**generate_params)))
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
elif self.path == '/api/v1/token-count':
|
elif self.path == '/api/v1/token-count':
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
@ -61,6 +61,35 @@ def build_parameters(body, chat=False):
|
|||||||
|
|
||||||
return generate_params
|
return generate_params
|
||||||
|
|
||||||
|
def build_parameters_train(body):
|
||||||
|
generate_params = {
|
||||||
|
'lora_name': str(body.get('lora_name', 'lora')),
|
||||||
|
'always_override': bool(body.get('always_override', False)),
|
||||||
|
'save_steps': int(body.get('save_steps', 0)),
|
||||||
|
'micro_batch_size': int(body.get('micro_batch_size', 4)),
|
||||||
|
'batch_size': int(body.get('batch_size', 128)),
|
||||||
|
'epochs': int(body.get('epochs', 3)),
|
||||||
|
'learning_rate': float(body.get('learning_rate', 3e-4)),
|
||||||
|
'lr_scheduler_type': str(body.get('lr_scheduler_type', 'linear')),
|
||||||
|
'lora_rank': int(body.get('lora_rank', 32)),
|
||||||
|
'lora_alpha': int(body.get('lora_alpha', 64)),
|
||||||
|
'lora_dropout': float(body.get('lora_dropout', 0.05)),
|
||||||
|
'cutoff_len': int(body.get('cutoff_len', 256)),
|
||||||
|
'dataset': str(body.get('dataset', None)),
|
||||||
|
'eval_dataset': str(body.get('eval_dataset', None)),
|
||||||
|
'format': str(body.get('format', None)),
|
||||||
|
'eval_steps': int(body.get('eval_steps', 100)),
|
||||||
|
'raw_text_file': str(body.get('raw_text_file', None)),
|
||||||
|
'overlap_len': int(body.get('overlap_len', 128)),
|
||||||
|
'newline_favor_len': int(body.get('newline_favor_len', 128)),
|
||||||
|
'higher_rank_limit': bool(body.get('higher_rank_limit', False)),
|
||||||
|
'warmup_steps': int(body.get('warmup_steps', 100)),
|
||||||
|
'optimizer': str(body.get('optimizer', 'adamw_torch')),
|
||||||
|
'hard_cut_string': str(body.get('hard_cut_string', '\\n\\n\\n')),
|
||||||
|
'train_only_after': str(body.get('train_only_after', ''))
|
||||||
|
}
|
||||||
|
|
||||||
|
return generate_params
|
||||||
|
|
||||||
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||||
Thread(target=_start_cloudflared, args=[
|
Thread(target=_start_cloudflared, args=[
|
||||||
|
Loading…
Reference in New Issue
Block a user