2023-02-23 17:28:30 +01:00
import base64
import copy
2023-05-23 05:50:58 +02:00
import functools
2023-08-16 18:23:29 +02:00
import html
2023-02-23 16:05:25 +01:00
import json
import re
2023-09-21 22:19:32 +02:00
from datetime import datetime
2023-12-12 21:23:14 +01:00
from functools import partial
2023-02-23 16:05:25 +01:00
from pathlib import Path
2023-06-29 19:56:25 +02:00
import gradio as gr
2023-04-03 01:34:25 +02:00
import yaml
2023-12-12 21:23:14 +01:00
from jinja2 . sandbox import ImmutableSandboxedEnvironment
2023-04-05 04:03:58 +02:00
from PIL import Image
2023-02-23 18:41:42 +01:00
2023-02-25 13:23:02 +01:00
import modules . shared as shared
2024-01-09 15:24:27 +01:00
from modules import utils
2023-02-23 16:05:25 +01:00
from modules . extensions import apply_extensions
2023-04-16 23:25:44 +02:00
from modules . html_generator import chat_html_wrapper , make_thumbnail
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-06-25 06:44:36 +02:00
from modules . text_generation import (
generate_reply ,
get_encoded_length ,
get_max_prompt_length
)
2023-12-12 21:23:14 +01:00
from modules . utils import delete_file , get_available_characters , save_file
# Jinja2 environment used by this module to compile the chat and instruction
# template strings (settings copied from the Transformers library).
jinja_env = ImmutableSandboxedEnvironment ( trim_blocks = True , lstrip_blocks = True )
2023-04-26 08:21:53 +02:00
2023-07-30 20:25:38 +02:00
def str_presenter(dumper, data):
    '''
    Copied from https://github.com/yaml/pyyaml/issues/240
    Makes pyyaml output prettier multiline strings.

    Strings containing at least one newline are represented with the
    literal block style ('|'); all other strings use the default style.
    '''
    if '\n' in data:
        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')

    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
# Register str_presenter for str values on PyYAML's default dumper and on SafeRepresenter.
yaml . add_representer ( str , str_presenter )
yaml . representer . SafeRepresenter . add_representer ( str , str_presenter )
2023-12-12 21:23:14 +01:00
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
    '''
    Given a Jinja template, reverse-engineers the prefix and the suffix for
    an assistant message (if impersonate=False) or an user message
    (if impersonate=True).

    renderer is called as renderer(messages=[...]) with two placeholder
    messages of the same role; the text between and after the placeholders
    is then split into (prefix, suffix).

    Returns a (prefix, suffix) tuple. Trailing spaces are removed from the
    prefix unless strip_trailing_spaces is False.
    '''
    role = "user" if impersonate else "assistant"
    first_marker = "<<|user-message-1|>>"
    second_marker = "<<|user-message-2|>>"

    prompt = renderer(messages=[
        {"role": role, "content": first_marker},
        {"role": role, "content": second_marker},
    ])

    # Text between the two placeholders = suffix of message 1 + prefix of message 2
    suffix_plus_prefix = prompt.split(first_marker)[1].split(second_marker)[0]

    # Text after the second placeholder = suffix alone
    suffix = prompt.split(second_marker)[1]
    prefix = suffix_plus_prefix[len(suffix):]

    if strip_trailing_spaces:
        prefix = prefix.rstrip(' ')

    return prefix, suffix
2023-05-14 15:43:55 +02:00
2023-04-11 23:46:06 +02:00
def generate_chat_prompt(user_input, state, **kwargs):
    '''
    Builds the prompt string for the model from the chat history, the
    templates stored in state, and the new user input.

    Recognized kwargs:
        impersonate: build a prompt that ends with the user prefix.
        _continue: build a prompt that ends with the last reply so it can be extended.
        also_return_rows: also return the list of message contents.
        history: history dict to use instead of state['history'].

    Returns the prompt, or (prompt, rows) when also_return_rows is set.
    Raises ValueError when a tokenizer is loaded and the prompt cannot be
    shortened enough to fit the available context length.
    '''
    impersonate = kwargs.get('impersonate', False)
    _continue = kwargs.get('_continue', False)
    also_return_rows = kwargs.get('also_return_rows', False)
    history = kwargs.get('history', state['history'])['internal']

    # Templates
    chat_template_str = state['chat_template_str']
    if state['mode'] != 'instruct':
        chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2'])

    instruction_template = jinja_env.from_string(state['instruction_template_str'])
    instruct_renderer = partial(instruction_template.render, add_generation_prompt=False)
    chat_template = jinja_env.from_string(chat_template_str)
    chat_renderer = partial(
        chat_template.render,
        add_generation_prompt=False,
        name1=state['name1'],
        name2=state['name2'],
        user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']),
    )

    messages = []

    if state['mode'] == 'instruct':
        renderer = instruct_renderer
        if state['custom_system_message'].strip() != '':
            messages.append({"role": "system", "content": state['custom_system_message']})
    else:
        renderer = chat_renderer
        if state['context'].strip() != '' or state['user_bio'].strip() != '':
            context = replace_character_names(state['context'], state['name1'], state['name2'])
            messages.append({"role": "system", "content": context})

    # History is walked newest-to-oldest and every message is inserted at the
    # same position (right after the optional system message), which restores
    # chronological order.
    insert_pos = len(messages)
    for user_msg, assistant_msg in reversed(history):
        user_msg = user_msg.strip()
        assistant_msg = assistant_msg.strip()

        if assistant_msg:
            messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})

        if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
            messages.insert(insert_pos, {"role": "user", "content": user_msg})

    user_input = user_input.strip()
    if user_input and not impersonate and not _continue:
        messages.append({"role": "user", "content": user_input})

    def remove_extra_bos(prompt):
        # Strip any number of leading BOS-like tokens emitted by the template
        for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
            while prompt.startswith(bos_token):
                prompt = prompt[len(bos_token):]

        return prompt

    def make_prompt(messages):
        # In chat-instruct + continue, the last message is moved into the
        # assistant prefix below instead of being rendered by the chat template.
        if state['mode'] == 'chat-instruct' and _continue:
            prompt = renderer(messages=messages[:-1])
        else:
            prompt = renderer(messages=messages)

        if state['mode'] == 'chat-instruct':
            outer_messages = []
            if state['custom_system_message'].strip() != '':
                outer_messages.append({"role": "system", "content": state['custom_system_message']})

            prompt = remove_extra_bos(prompt)
            command = state['chat-instruct_command']
            command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
            command = command.replace('<|prompt|>', prompt)
            command = replace_character_names(command, state['name1'], state['name2'])

            if _continue:
                prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
                prefix += messages[-1]["content"]
            else:
                prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
                if not impersonate:
                    prefix = apply_extensions('bot_prefix', prefix, state)

            outer_messages.append({"role": "user", "content": command})
            outer_messages.append({"role": "assistant", "content": prefix})

            prompt = instruction_template.render(messages=outer_messages)

            # Drop the assistant suffix so that the model continues the message
            suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
            if len(suffix) > 0:
                prompt = prompt[:-len(suffix)]

        else:
            if _continue:
                suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
                if len(suffix) > 0:
                    prompt = prompt[:-len(suffix)]
            else:
                prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
                if state['mode'] == 'chat' and not impersonate:
                    prefix = apply_extensions('bot_prefix', prefix, state)

                prompt += prefix

        prompt = remove_extra_bos(prompt)
        return prompt

    prompt = make_prompt(messages)

    # Handle truncation
    if shared.tokenizer is not None:
        max_length = get_max_prompt_length(state)
        encoded_length = get_encoded_length(prompt)
        while len(messages) > 0 and encoded_length > max_length:

            # Remove old message, save system message
            if len(messages) > 2 and messages[0]['role'] == 'system':
                messages.pop(1)

            # Remove old message when no system message is present
            elif len(messages) > 1 and messages[0]['role'] != 'system':
                messages.pop(0)

            # Resort to truncating the user input
            else:
                user_message = messages[-1]['content']

                # Bisect the truncation point
                left, right = 0, len(user_message) - 1

                while right - left > 1:
                    mid = (left + right) // 2

                    messages[-1]['content'] = user_message[:mid]
                    prompt = make_prompt(messages)
                    encoded_length = get_encoded_length(prompt)

                    if encoded_length <= max_length:
                        left = mid
                    else:
                        right = mid

                messages[-1]['content'] = user_message[:left]
                prompt = make_prompt(messages)
                encoded_length = get_encoded_length(prompt)
                if encoded_length > max_length:
                    logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
                    raise ValueError
                else:
                    logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
                    break

            prompt = make_prompt(messages)
            encoded_length = get_encoded_length(prompt)

    if also_return_rows:
        return prompt, [message['content'] for message in messages]
    else:
        return prompt
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-04-11 17:30:06 +02:00
def get_stopping_strings(state):
    '''
    Collects the strings at which generation should stop for the current mode.

    For each template relevant to state['mode'], every combination of
    message suffix followed by the next message prefix (bot/user) is added.
    If state contains a list under 'stopping_strings', it is appended as
    well. Note that this entry is removed from state (state.pop) as a
    side effect.

    Returns a de-duplicated list; its order is not defined.
    '''
    stopping_strings = []
    renderers = []

    if state['mode'] in ['instruct', 'chat-instruct']:
        template = jinja_env.from_string(state['instruction_template_str'])
        renderers.append(partial(template.render, add_generation_prompt=False))

    if state['mode'] in ['chat', 'chat-instruct']:
        template = jinja_env.from_string(state['chat_template_str'])
        renderers.append(partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2']))

    for renderer in renderers:
        prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
        prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)

        stopping_strings += [
            suffix_user + prefix_bot,
            suffix_user + prefix_user,
            suffix_bot + prefix_bot,
            suffix_bot + prefix_user,
        ]

    if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
        stopping_strings += state.pop('stopping_strings')

    return list(set(stopping_strings))
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-12-12 22:00:38 +01:00
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
    '''
    Generator that produces a chat reply and yields the updated history.

    Works on a deep copy of state['history'] (after the 'history' and
    'state' extension hooks). Each yielded value is a dict with the keys
    'internal' and 'visible'. When loading_message is set, a placeholder
    history is yielded first. Intermediate histories are yielded while
    generating only if state['stream'] is truthy; the final history is
    always yielded.
    '''
    history = state['history']
    output = copy.deepcopy(history)
    output = apply_extensions('history', output)
    state = apply_extensions('state', state)

    visible_text = None
    stopping_strings = get_stopping_strings(state)
    is_stream = state['stream']

    # Prepare the input
    if not (regenerate or _continue):
        visible_text = html.escape(text)

        # Apply extensions
        text, visible_text = apply_extensions('chat_input', text, visible_text, state)
        text = apply_extensions('input', text, state, is_chat=True)

        output['internal'].append([text, ''])
        output['visible'].append([visible_text, ''])

        # *Is typing...*
        if loading_message:
            yield {
                'visible': output['visible'][:-1] + [[output['visible'][-1][0], shared.processing_message]],
                'internal': output['internal']
            }
    else:
        text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
        if regenerate:
            if loading_message:
                yield {
                    'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]],
                    'internal': output['internal'][:-1] + [[text, '']]
                }
        elif _continue:
            last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
            if loading_message:
                yield {
                    'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']],
                    'internal': output['internal']
                }

    # Generate the prompt
    # When not continuing, the last row (the message being answered) is left
    # out of the history because it is passed separately as `text`.
    kwargs = {
        '_continue': _continue,
        'history': output if _continue else {k: v[:-1] for k, v in output.items()}
    }
    prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
    if prompt is None:
        prompt = generate_chat_prompt(text, state, **kwargs)

    # Generate
    reply = None
    for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):

        # Extract the reply
        visible_reply = reply
        if state['mode'] in ['chat', 'chat-instruct']:
            # A callable replacement inserts the name literally; a plain string
            # would have backslashes and group references in it interpreted.
            visible_reply = re.sub("(<USER>|<user>|{{user}})", lambda _: state['name1'], reply)

        visible_reply = html.escape(visible_reply)

        if shared.stop_everything:
            output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
            yield output
            return

        if _continue:
            output['internal'][-1] = [text, last_reply[0] + reply]
            output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
            if is_stream:
                yield output
        elif not (j == 0 and visible_reply.strip() == ''):
            # Skip a blank first chunk so the placeholder is not replaced by nothing
            output['internal'][-1] = [text, reply.lstrip(' ')]
            output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
            if is_stream:
                yield output

    output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
    yield output
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-08-14 16:46:07 +02:00
def impersonate_wrapper(text, state):
    '''
    Generator that writes a message on behalf of the user.

    Yields (text, html) tuples: the first element is the input text followed
    by what has been generated so far (leading spaces removed), and the
    second is the chat HTML rendered once from state['history'] before
    generation starts.
    '''
    static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])

    prompt = generate_chat_prompt('', state, impersonate=True)
    stopping_strings = get_stopping_strings(state)

    # Placeholder shown before the first chunk arrives
    yield text + '...', static_output

    for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True):
        yield (text + reply).lstrip(' '), static_output
        if shared.stop_everything:
            return
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-12-12 22:00:38 +01:00
def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
    '''
    Generator that yields history dicts produced by chatbot_wrapper.

    When regenerating or continuing, the input text is ignored. If there is
    nothing to regenerate/continue (empty internal history, or a single
    visible row with an empty user message), the current history is yielded
    unchanged and generation is skipped.
    '''
    history = state['history']
    if regenerate or _continue:
        text = ''
        nothing_to_redo = len(history['visible']) == 1 and not history['visible'][0][0]
        if nothing_to_redo or len(history['internal']) == 0:
            yield history
            return

    yield from chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui)
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-09-21 22:19:32 +02:00
def character_is_loaded(state, raise_exception=False):
    '''
    Returns True unless the mode is chat or chat-instruct and state['name2']
    is empty. In that case an error is logged and either ValueError is
    raised (raise_exception=True) or False is returned.
    '''
    if state['mode'] in ['chat', 'chat-instruct'] and state['name2'] == '':
        message = 'It looks like no character is loaded. Please load one under Parameters > Character.'
        logger.error(message)
        if raise_exception:
            raise ValueError(message)

        return False

    return True
2023-08-14 16:46:07 +02:00
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
    '''
    Same as generate_chat_reply, but yields (html, history) tuples for the UI.

    Yields nothing when no character is loaded. When state['start_with'] is
    non-empty and this is not a continuation, the user message and the
    start_with text are written into the history first and the reply is
    generated as a continuation of that text.
    '''
    if not character_is_loaded(state):
        return

    if state['start_with'] != '' and not _continue:
        if regenerate:
            # Take the last exchange back so that it can be re-sent below
            text, state['history'] = remove_last_message(state['history'])
            regenerate = False

        _continue = True
        send_dummy_message(text, state)
        send_dummy_reply(state['start_with'], state)

    for history in generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True):
        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
2023-04-10 01:04:16 +02:00
2023-07-04 05:03:30 +02:00
def remove_last_message(history):
    '''
    Removes the last exchange from both 'visible' and 'internal' (in place),
    unless the history is empty or the last internal row is the greeting
    row marked with '<|BEGIN-VISIBLE-CHAT|>'.

    Returns (text, history), where text is the HTML-unescaped visible user
    message of the removed row, or '' when nothing was removed.
    '''
    if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
        last = history['visible'].pop()
        history['internal'].pop()
    else:
        last = ['', '']

    return html.unescape(last[0]), history
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-07-04 05:03:30 +02:00
def send_last_reply_to_input(history):
    '''
    Returns the HTML-unescaped last visible reply, or '' if the history is empty.
    '''
    if len(history['visible']) > 0:
        return html.unescape(history['visible'][-1][1])

    return ''
2023-04-07 05:15:45 +02:00
2023-07-04 05:03:30 +02:00
def replace_last_reply(text, state):
    '''
    Overwrites the last reply in state['history'] (in place) with text.

    The visible copy is HTML-escaped; the internal copy goes through the
    'input' extension hook. Nothing changes when text is blank or the
    history is empty. Returns the history dict.
    '''
    history = state['history']

    if len(text.strip()) == 0:
        return history

    if len(history['visible']) > 0:
        history['visible'][-1][1] = html.escape(text)
        history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True)

    return history
2023-02-23 16:05:25 +01:00
2023-07-04 05:03:30 +02:00
def send_dummy_message(text, state):
    '''
    Appends a new row with text as the user message and an empty reply to
    state['history'] (in place) without generating anything.

    The visible copy is HTML-escaped; the internal copy goes through the
    'input' extension hook. Returns the history dict.
    '''
    history = state['history']
    history['visible'].append([html.escape(text), ''])
    history['internal'].append([apply_extensions('input', text, state, is_chat=True), ''])
    return history
2023-04-12 03:21:41 +02:00
2023-07-04 05:03:30 +02:00
def send_dummy_reply(text, state):
    '''
    Writes text as the reply of the last row of state['history'] (in place)
    without generating anything.

    A new empty row is appended first when the history is empty or the last
    row already has a reply. The visible copy is HTML-escaped; the internal
    copy goes through the 'input' extension hook. Returns the history dict.
    '''
    history = state['history']

    # Also covers the empty history, where indexing [-1] below would fail
    if len(history['visible']) == 0 or history['visible'][-1][1] != '':
        history['visible'].append(['', ''])
        history['internal'].append(['', ''])

    history['visible'][-1][1] = html.escape(text)
    history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True)
    return history
2023-02-23 19:11:18 +01:00
2023-04-07 05:15:45 +02:00
2023-12-04 02:45:50 +01:00
def redraw_html(history, name1, name2, mode, style, character, reset_cache=False):
    '''
    Thin wrapper: forwards all arguments unchanged to chat_html_wrapper and
    returns its result.
    '''
    rendered = chat_html_wrapper(
        history,
        name1,
        name2,
        mode,
        style,
        character,
        reset_cache=reset_cache,
    )
    return rendered
2023-09-21 22:19:32 +02:00
def start_new_chat(state):
    '''
    Creates a new history, saves it under a timestamp-based id, and returns it.

    Outside instruct mode, a non-empty greeting (with character names
    substituted) becomes the first row; its internal user message is the
    '<|BEGIN-VISIBLE-CHAT|>' marker and its visible reply goes through the
    'output' extension hook.
    '''
    mode = state['mode']
    history = {'internal': [], 'visible': []}

    if mode != 'instruct':
        greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
        if greeting != '':
            history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
            history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]

    unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
    save_history(history, unique_id, state['character_menu'], state['mode'])

    return history
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-09-21 22:19:32 +02:00
def get_history_file_path(unique_id, character, mode):
    '''
    Returns the Path of the JSON file for a history:
    logs/instruct/<unique_id>.json in instruct mode, and
    logs/chat/<character>/<unique_id>.json in every other mode.
    '''
    if mode == 'instruct':
        return Path(f'logs/instruct/{unique_id}.json')

    return Path(f'logs/chat/{character}/{unique_id}.json')
def save_history(history, unique_id, character, mode):
    '''
    Writes history as indented JSON to the file returned by
    get_history_file_path, creating the parent folder if needed.
    Does nothing when shared.args.multi_user is set.
    '''
    if shared.args.multi_user:
        return

    p = get_history_file_path(unique_id, character, mode)

    # exist_ok avoids failing if the folder appears between a check and the mkdir
    p.parent.mkdir(parents=True, exist_ok=True)

    with open(p, 'w', encoding='utf-8') as f:
        f.write(json.dumps(history, indent=4))
2023-02-23 16:05:25 +01:00
2023-04-24 05:47:40 +02:00
2023-09-22 03:53:03 +02:00
def rename_history(old_id, new_id, character, mode):
    '''
    Renames the history file of old_id to new_id.

    The rename is refused (and logged) when the new path would end up in a
    different folder than the old one, and skipped when both paths are
    equal. Does nothing when shared.args.multi_user is set.
    '''
    if shared.args.multi_user:
        return

    old_p = get_history_file_path(old_id, character, mode)
    new_p = get_history_file_path(new_id, character, mode)

    # A new_id containing path separators could point outside the history folder
    if new_p.parent != old_p.parent:
        logger.error(f"The following path is not allowed: \"{new_p}\".")
    elif new_p == old_p:
        logger.info("The provided path is identical to the old one.")
    else:
        logger.info(f"Renaming \"{old_p}\" to \"{new_p}\"")
        old_p.rename(new_p)
2023-09-21 22:19:32 +02:00
def find_all_histories(state):
    '''
    Returns the ids (file stems) of the saved histories for the current
    mode/character, most recently modified first.

    Outside instruct mode, a legacy persistent log of the character is
    first moved into the per-character history folder. Returns ['']
    when shared.args.multi_user is set.
    '''
    if shared.args.multi_user:
        return ['']

    if state['mode'] == 'instruct':
        paths = Path('logs/instruct').glob('*.json')
    else:
        character = state['character_menu']

        # Handle obsolete filenames and paths
        old_p = Path(f'logs/{character}_persistent.json')
        new_p = Path(f'logs/persistent_{character}.json')
        if old_p.exists():
            logger.warning(f"Renaming \"{old_p}\" to \"{new_p}\"")
            old_p.rename(new_p)

        if new_p.exists():
            unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
            p = get_history_file_path(unique_id, character, state['mode'])
            logger.warning(f"Moving \"{new_p}\" to \"{p}\"")

            # parents=True: the intermediate logs/chat folder may not exist yet
            p.parent.mkdir(parents=True, exist_ok=True)
            new_p.rename(p)

        paths = Path(f'logs/chat/{character}').glob('*.json')

    newest_first = sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True)
    return [path.stem for path in newest_first]
2023-08-03 06:13:16 +02:00
2023-04-24 05:47:40 +02:00
2023-09-21 22:19:32 +02:00
def load_latest_history(state):
    '''
    Loads the latest history for the given character in chat or chat-instruct
    mode, or the latest instruct history for instruct mode.

    A new chat is started when no history exists or when
    shared.args.multi_user is set.
    '''
    if shared.args.multi_user:
        return start_new_chat(state)

    histories = find_all_histories(state)
    if not histories:
        return start_new_chat(state)

    return load_history(histories[0], state['character_menu'], state['mode'])
2024-01-09 06:22:37 +01:00
def load_history_after_deletion(state, idx):
    '''
    Loads the history at position idx of the remaining histories (clamped
    to the valid range), or starts a new chat when none remain.

    Returns (history, gr.update(...)), where the update carries the
    refreshed list of history ids and the selected id.
    '''
    # NOTE(review): this branch returns only the history, not a
    # (history, update) tuple like the path below - confirm against callers.
    if shared.args.multi_user:
        return start_new_chat(state)

    histories = find_all_histories(state)
    idx = min(int(idx), len(histories) - 1)
    idx = max(0, idx)

    if len(histories) > 0:
        history = load_history(histories[idx], state['character_menu'], state['mode'])
    else:
        history = start_new_chat(state)
        histories = find_all_histories(state)

    # Guard against an empty list so that the index cannot raise
    selected = histories[idx] if histories else None
    return history, gr.update(choices=histories, value=selected)
def update_character_menu_after_deletion(idx):
    '''
    Returns a gr.update with the refreshed character list and the character
    at position idx (clamped to the valid range) selected. The selected
    value is None when no characters are available.
    '''
    characters = utils.get_available_characters()
    idx = min(int(idx), len(characters) - 1)
    idx = max(0, idx)

    # Guard against an empty list so that the index cannot raise
    selected = characters[idx] if characters else None
    return gr.update(choices=characters, value=selected)
2024-01-09 06:22:37 +01:00
2023-09-21 22:19:32 +02:00
def load_history(unique_id, character, mode):
    '''
    Reads a history JSON file and returns a dict with the keys 'internal'
    and 'visible'. Files that instead use the keys 'data' and
    'data_visible' are converted to that layout.
    '''
    p = get_history_file_path(unique_id, character, mode)

    # Context manager so that the file handle is closed deterministically
    with open(p, 'rb') as file:
        f = json.loads(file.read())

    if 'internal' in f and 'visible' in f:
        history = f
    else:
        history = {
            'internal': f['data'],
            'visible': f['data_visible']
        }

    return history
def load_history_json(file, history):
    '''
    Parses file (UTF-8 encoded JSON bytes) into a history dict with the
    keys 'internal' and 'visible'. Input that uses the keys 'data' and
    'data_visible' is converted to that layout.

    Best effort: if the input cannot be decoded, parsed or converted, the
    history passed in is returned unchanged.
    '''
    try:
        f = json.loads(file.decode('utf-8'))
        if 'internal' in f and 'visible' in f:
            return f

        return {
            'internal': f['data'],
            'visible': f['data_visible']
        }
    except Exception:
        # Not a bare except: KeyboardInterrupt/SystemExit must still propagate
        return history
def delete_history(unique_id, character, mode):
    '''
    Resolves the history file path for (unique_id, character, mode) and
    passes it to delete_file.
    '''
    delete_file(get_history_file_path(unique_id, character, mode))
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-04-03 02:54:29 +02:00
def replace_character_names(text, name1, name2):
    '''
    Replaces the placeholders {{user}} and <USER> with name1, and
    {{char}} and <BOT> with name2. Returns the resulting string.
    '''
    text = text.replace('{{user}}', name1).replace('{{char}}', name2)
    return text.replace('<USER>', name1).replace('<BOT>', name2)
2023-04-07 05:15:45 +02:00
2023-04-05 03:28:49 +02:00
def generate_pfp_cache(character):
    '''
    Caches the profile picture of a character.

    Looks for characters/<character>.png, .jpg and .jpeg (in that order).
    The first one found is saved as pfp_character.png, and its thumbnail as
    pfp_character_thumb.png, inside shared.args.disk_cache_dir.

    Returns the thumbnail, or None when the character has no picture.
    '''
    cache_folder = Path(shared.args.disk_cache_dir)

    # parents/exist_ok: tolerate nested cache paths and concurrent creation
    cache_folder.mkdir(parents=True, exist_ok=True)

    for extension in ['png', 'jpg', 'jpeg']:
        path = Path(f"characters/{character}.{extension}")
        if path.exists():
            original_img = Image.open(path)
            original_img.save(Path(f'{cache_folder}/pfp_character.png'), format='PNG')

            thumb = make_thumbnail(original_img)
            thumb.save(Path(f'{cache_folder}/pfp_character_thumb.png'), format='PNG')

            return thumb

    return None
2023-04-07 05:15:45 +02:00
2023-12-12 21:23:14 +01:00
def load_character(character, name1, name2):
    '''
    Load a character definition from the characters/ folder.

    Parameters:
        character: file stem; characters/<character>.yml, .yaml and .json are
            tried in that order.
        name1: default user name, kept unless the file provides one.
        name2: default bot name, kept unless the file provides one.

    Returns:
        A tuple (name1, name2, picture, greeting, context), where `picture`
        is whatever generate_pfp_cache returns for this character.

    Raises:
        ValueError: if none of the candidate files exists.

    Side effects:
        Removes the cached pfp_character.png / pfp_character_thumb.png from
        shared.args.disk_cache_dir and then regenerates them through
        generate_pfp_cache.
    '''
    context = greeting = ""
    greeting_field = 'greeting'

    filepath = None
    extension = None
    for extension in ["yml", "yaml", "json"]:
        candidate = Path(f'characters/{character}.{extension}')
        if candidate.exists():
            filepath = candidate
            break

    if filepath is None:
        logger.error(f"Could not find the character \"{character}\" inside characters/. No character has been loaded.")
        raise ValueError(f'Could not find the character "{character}" inside characters/.')

    # read_text closes the file; the previous open(...).read() left that to
    # the garbage collector.
    file_contents = filepath.read_text(encoding='utf-8')
    data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
    cache_folder = Path(shared.args.disk_cache_dir)

    # Drop the previous character's cached pictures before regenerating.
    for path in [Path(f"{cache_folder}/pfp_character.png"), Path(f"{cache_folder}/pfp_character_thumb.png")]:
        if path.exists():
            path.unlink()

    picture = generate_pfp_cache(character)

    # Finding the bot's name
    for k in ['name', 'bot', '<|bot|>', 'char_name']:
        if k in data and data[k] != '':
            name2 = data[k]
            break

    # Find the user name (if any)
    for k in ['your_name', 'user', '<|user|>']:
        if k in data and data[k] != '':
            name1 = data[k]
            break

    if 'context' in data:
        context = data['context'].strip()
    elif "char_persona" in data:
        # Pygmalion-style file: the context is assembled from several fields
        # and the greeting is stored under a different key.
        context = build_pygmalion_style_context(data)
        greeting_field = 'char_greeting'

    greeting = data.get(greeting_field, greeting)
    return name1, name2, picture, greeting, context
2023-02-23 16:05:25 +01:00
2023-12-12 21:23:14 +01:00
def load_instruction_template(template):
    '''
    Return the instruction template stored under instruction-templates/.

    Parameters:
        template: file stem of the template, or the string 'None'.

    Returns:
        '' when `template` is 'None' or when neither
        instruction-templates/<template>.yaml nor the Alpaca.yaml fallback
        exists. Otherwise the file's 'instruction_template' entry, or, when
        that key is absent, the result of jinja_template_from_old_format on
        the parsed file.
    '''
    if template == 'None':
        return ''

    candidates = [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]
    filepath = next((path for path in candidates if path.exists()), None)
    if filepath is None:
        return ''

    # read_text closes the file; the previous open(...).read() left that to
    # the garbage collector.
    data = yaml.safe_load(filepath.read_text(encoding='utf-8'))
    if 'instruction_template' in data:
        return data['instruction_template']

    return jinja_template_from_old_format(data)
2023-04-03 17:16:15 +02:00
2023-04-07 05:15:45 +02:00
2023-05-23 05:50:58 +02:00
@functools.cache
def load_character_memoized(character, name1, name2):
    '''
    Cached front end for load_character.

    functools.cache keeps one result per distinct (character, name1, name2)
    argument combination and never evicts it, so repeated calls with the same
    arguments do not call load_character again.
    '''
    loaded = load_character(character, name1, name2)
    return loaded
@functools.cache
def load_instruction_template_memoized(template):
    '''
    Cached front end for load_instruction_template.

    functools.cache keeps one result per distinct `template` value and never
    evicts it, so repeated calls with the same value do not call
    load_instruction_template again.
    '''
    loaded = load_instruction_template(template)
    return loaded
2023-05-23 05:50:58 +02:00
2023-07-30 20:25:38 +02:00
def upload_character(file, img, tavern=False):
    '''
    Save an uploaded character as characters/<name>.yaml (plus an optional
    picture) and return a gradio update selecting it.

    Parameters:
        file: the character definition as a str or as UTF-8 encoded bytes.
            It is parsed as JSON first and as YAML if that fails.
        img: optional image; when not None it is saved next to the YAML file
            as <name>.png.
        tavern: accepted for callers but not used in this function.

    Returns:
        gr.update with the saved file stem as value and the refreshed list of
        available characters as choices.

    Notes:
        If characters/<name>.yaml already exists, a numeric suffix
        (_001, _002, ...) is appended until an unused name is found.
    '''
    decoded_file = file if isinstance(file, str) else file.decode('utf-8')
    try:
        data = json.loads(decoded_file)
    except ValueError:
        # json.JSONDecodeError is a ValueError: the payload is not JSON, so
        # treat it as YAML. Other errors are no longer swallowed.
        data = yaml.safe_load(decoded_file)

    if 'char_name' in data:
        # Pygmalion-style fields
        name = data['char_name']
        greeting = data['char_greeting']
        context = build_pygmalion_style_context(data)
        yaml_data = generate_character_yaml(name, greeting, context)
    else:
        name = data['name']
        yaml_data = generate_character_yaml(data['name'], data['greeting'], data['context'])

    # NOTE(review): `name` comes from the uploaded file and is used verbatim
    # as a file name below; confirm whether it needs sanitising.
    outfile_name = name
    i = 1
    while Path(f'characters/{outfile_name}.yaml').exists():
        outfile_name = f'{name}_{i:03d}'
        i += 1

    with open(Path(f'characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f:
        f.write(yaml_data)

    if img is not None:
        img.save(Path(f'characters/{outfile_name}.png'))

    logger.info(f'New character saved to "characters/{outfile_name}.yaml".')
    return gr.update(value=outfile_name, choices=get_available_characters())
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-07-30 20:25:38 +02:00
def build_pygmalion_style_context(data):
    '''
    Assemble a context string from Pygmalion-style character fields.

    The non-empty fields among 'char_persona', 'world_scenario' and
    'example_dialogue' each contribute one section, in that order. The
    combined text is stripped of surrounding whitespace and terminated with a
    single newline (so an input without any of these fields yields "\\n").
    '''
    def is_filled(key):
        return key in data and data[key] != ''

    sections = []
    if is_filled('char_persona'):
        sections.append(f"{data['char_name']}'s Persona: {data['char_persona']}\n")

    if is_filled('world_scenario'):
        sections.append(f"Scenario: {data['world_scenario']}\n")

    if is_filled('example_dialogue'):
        sections.append(f"{data['example_dialogue'].strip()}\n")

    combined = "".join(sections)
    return f"{combined.strip()}\n"
2023-06-29 19:56:25 +02:00
def upload_tavern_character(img, _json):
    '''
    Convert a TavernAI card dict to the Pygmalion-style field names and pass
    it, serialised as JSON, to upload_character together with `img`.
    '''
    # (target key, key in the TavernAI card)
    field_map = (
        ('char_name', 'name'),
        ('char_persona', 'description'),
        ('char_greeting', 'first_mes'),
        ('example_dialogue', 'mes_example'),
        ('world_scenario', 'scenario'),
    )
    converted = {target: _json[source] for target, source in field_map}
    return upload_character(json.dumps(converted), img, tavern=True)
2023-06-29 19:56:25 +02:00
def check_tavern_character(img):
    '''
    Extract the character card embedded in an image's 'chara' metadata.

    Returns a 4-tuple (name, description, card, button update). When the
    image has no 'chara' entry in img.info, the tuple is
    ("Not a TavernAI card", None, None, <update disabling interaction>).
    '''
    if "chara" not in img.info:
        return "Not a TavernAI card", None, None, gr.update(interactive=False)

    raw_card = base64.b64decode(img.info['chara'])
    # Rewrite the escaped sequence \r\n in the JSON text to \n before parsing.
    card = json.loads(raw_card.replace(b'\\r\\n', b'\\n'))
    # Some cards nest the fields under a "data" key.
    if "data" in card:
        card = card["data"]

    return card['name'], card['description'], card, gr.update(interactive=True)
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-05-11 20:37:04 +02:00
def upload_your_profile_picture(img):
    '''
    Store or remove the user's profile picture in the disk cache.

    Parameters:
        img: the uploaded image, or None to remove the cached picture.

    When `img` is None, <cache>/pfp_me.png is deleted if present. Otherwise
    the result of make_thumbnail(img) is saved to that path, where <cache> is
    shared.args.disk_cache_dir.
    '''
    cache_folder = Path(shared.args.disk_cache_dir)
    # parents/exist_ok: works for nested cache paths and avoids failing when
    # the folder appears between an existence check and the mkdir call.
    cache_folder.mkdir(parents=True, exist_ok=True)

    if img is None:
        # missing_ok avoids the exists()/unlink() race of a separate check.
        Path(f"{cache_folder}/pfp_me.png").unlink(missing_ok=True)
        return

    img = make_thumbnail(img)
    img.save(Path(f'{cache_folder}/pfp_me.png'))
    logger.info(f'Profile picture saved to "{cache_folder}/pfp_me.png"')
2023-05-21 02:48:45 +02:00
2023-06-11 17:19:18 +02:00
def generate_character_yaml(name, greeting, context):
    '''
    Serialise a character to YAML.

    Only the fields with a truthy value are written, in the order name,
    greeting, context. Lines are never wrapped (infinite width).
    '''
    fields = (
        ('name', name),
        ('greeting', greeting),
        ('context', context),
    )
    # Strip falsy
    data = {key: value for key, value in fields if value}
    return yaml.dump(data, sort_keys=False, width=float("inf"))
2023-05-21 02:48:45 +02:00
2023-12-12 21:23:14 +01:00
def generate_instruction_template_yaml(instruction_template):
    '''
    Return the YAML text for an instruction template, produced by
    my_yaml_output with the template stored under 'instruction_template'.
    '''
    return my_yaml_output({'instruction_template': instruction_template})
2023-06-11 17:19:18 +02:00
def save_character(name, greeting, context, picture, filename):
    '''
    Write a character to characters/<filename>.yaml and, when `picture` is
    not None, its picture to characters/<filename>.png.

    Nothing is saved (and an error is logged) when `filename` is empty.
    '''
    if filename == "":
        logger.error("The filename is empty, so the character will not be saved.")
        return

    yaml_text = generate_character_yaml(name, greeting, context)
    save_file(Path(f'characters/{filename}.yaml'), yaml_text)

    if picture is None:
        return

    path_to_img = Path(f'characters/{filename}.png')
    picture.save(path_to_img)
    logger.info(f'Saved {path_to_img}.')
2023-05-21 02:48:45 +02:00
def delete_character(name, instruct=False):
    '''
    Delete a character's definition files and picture from characters/.

    delete_file is called for <name>.yml, <name>.yaml, <name>.json and
    <name>.png, in that order. `instruct` is not used in this function.
    '''
    for extension in ("yml", "yaml", "json", "png"):
        delete_file(Path(f'characters/{name}.{extension}'))
2023-12-12 21:23:14 +01:00
def jinja_template_from_old_format(params, verbose=False):
    '''
    Build a Jinja chat template from an old-format instruction template.

    Parameters:
        params: dict with the required keys 'turn_template', 'user' and 'bot'
            and the optional keys 'context' and 'system_message'.
            'turn_template' is expected to contain the '<|user-message|>',
            '<|bot|>' and '<|bot-message|>' markers; 'context' is only used
            when it contains '<|system-message|>'.
        verbose: when True, print the extracted prefixes/suffixes.

    Returns:
        The template text (stripped) with every placeholder of
        MASTER_TEMPLATE replaced by the corresponding piece of `params`.

    Raises:
        KeyError / IndexError: if a required key or marker is missing.
    '''
    MASTER_TEMPLATE = """
{%- set ns = namespace(found=false) -%}
{%- for message in messages -%}
    {%- if message['role'] == 'system' -%}
        {%- set ns.found = true -%}
    {%- endif -%}
{%- endfor -%}
{%- if not ns.found -%}
    {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
    {%- if message['role'] == 'system' -%}
        {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
    {%- else -%}
        {%- if message['role'] == 'user' -%}
            {{- '<|PRE-USER|>' + message['content'] + '<|POST-USER|>' -}}
        {%- else -%}
            {{- '<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
        {%- endif -%}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{- '<|PRE-ASSISTANT-GENERATE|>' -}}
{%- endif -%}
"""

    if 'context' in params and '<|system-message|>' in params['context']:
        context_parts = params['context'].split('<|system-message|>')
        pre_system = context_parts[0]
        post_system = context_parts[1]
    else:
        pre_system = ''
        post_system = ''

    turn_template = params['turn_template']
    around_user_message = turn_template.split('<|user-message|>')
    around_bot_message = turn_template.split('<|bot-message|>')

    pre_user = around_user_message[0].replace('<|user|>', params['user'])
    post_user = around_user_message[1].split('<|bot|>')[0]

    pre_assistant = '<|bot|>' + around_bot_message[0].split('<|bot|>')[1]
    pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
    post_assistant = around_bot_message[1]

    def preprocess(string):
        # Escape text for use inside a single-quoted Jinja string literal.
        # Backslashes go first so the escapes added for newlines and quotes
        # are not doubled, and so a literal backslash in the input is not
        # read as the start of an escape sequence.
        return string.replace('\\', '\\\\').replace('\n', '\\n').replace('\'', '\\\'')

    pre_system = preprocess(pre_system)
    post_system = preprocess(post_system)
    pre_user = preprocess(pre_user)
    post_user = preprocess(post_user)
    pre_assistant = preprocess(pre_assistant)
    post_assistant = preprocess(post_assistant)

    if verbose:
        print(
            '\n',
            repr(pre_system) + '\n',
            repr(post_system) + '\n',
            repr(pre_user) + '\n',
            repr(post_user) + '\n',
            repr(pre_assistant) + '\n',
            repr(post_assistant) + '\n',
        )

    result = MASTER_TEMPLATE
    if 'system_message' in params:
        result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
    else:
        result = result.replace('<|SYSTEM-MESSAGE|>', '')

    result = result.replace('<|PRE-SYSTEM|>', pre_system)
    result = result.replace('<|POST-SYSTEM|>', post_system)
    result = result.replace('<|PRE-USER|>', pre_user)
    result = result.replace('<|POST-USER|>', post_user)
    result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
    # The generation prompt is the assistant prefix without trailing spaces.
    result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
    result = result.replace('<|POST-ASSISTANT|>', post_assistant)

    result = result.strip()

    return result
def my_yaml_output(data):
    '''
    Minimal YAML writer for dicts of (possibly multiline) strings.

    pyyaml is very inconsistent with multiline strings; for simple
    instruction template outputs, this is enough. Every key is written as a
    "|-" block scalar whose lines are indented and stripped of trailing
    spaces.
    '''
    pieces = []
    for key in data:
        pieces.append(key + ": |-\n")
        pieces.extend("  " + line.rstrip(' ') + "\n" for line in data[key].splitlines())

    return "".join(pieces)