use javascript generators as much cleaner API

Also add ways to access completion as promise and EventSource
This commit is contained in:
Tobias Lütke 2023-07-05 15:03:01 -04:00
parent 983b555e9d
commit a76ce02a6c
No known key found for this signature in database
GPG Key ID: 1FC0DBB14164709A
2 changed files with 106 additions and 22 deletions

View File

@ -5,20 +5,29 @@ const paramDefaults = {
stop: ["</s>"] stop: ["</s>"]
}; };
/** let generation_settings = null;
* This function completes the input text using a llama dictionary.
* @param {object} params - The parameters for the completion request.
* @param {object} controller - an instance of AbortController if you need one, or null. // Completes the prompt as a generator. Recommended for most use cases.
* @param {function} callback - The callback function to call when the completion is done. //
* @returns {string} the completed text as a string. Ideally ignored, and you get at it via the callback. // Example:
*/ //
export const llamaComplete = async (params, controller, callback) => { // import { llama } from '/completion.js'
//
// const request = llama("Tell me a joke", {n_predict: 800})
// for await (const chunk of request) {
// document.write(chunk.data.content)
// }
//
export async function* llama(prompt, params = {}, config = {}) {
let controller = config.controller;
if (!controller) { if (!controller) {
controller = new AbortController(); controller = new AbortController();
} }
const completionParams = { ...paramDefaults, ...params };
// we use fetch directly here becasue the built in fetchEventSource does not support POST const completionParams = { ...paramDefaults, ...params, prompt };
const response = await fetch("/completion", { const response = await fetch("/completion", {
method: 'POST', method: 'POST',
body: JSON.stringify(completionParams), body: JSON.stringify(completionParams),
@ -36,7 +45,6 @@ export const llamaComplete = async (params, controller, callback) => {
let content = ""; let content = "";
try { try {
let cont = true; let cont = true;
while (cont) { while (cont) {
@ -59,10 +67,8 @@ export const llamaComplete = async (params, controller, callback) => {
result.data = JSON.parse(result.data); result.data = JSON.parse(result.data);
content += result.data.content; content += result.data.content;
// callack // yield
if (callback) { yield result;
cont = callback(result) != false;
}
// if we got a stop token from server, we will break here // if we got a stop token from server, we will break here
if (result.data.stop) { if (result.data.stop) {
@ -70,7 +76,9 @@ export const llamaComplete = async (params, controller, callback) => {
} }
} }
} catch (e) { } catch (e) {
if (e.name !== 'AbortError') {
console.error("llama error: ", e); console.error("llama error: ", e);
}
throw e; throw e;
} }
finally { finally {
@ -79,3 +87,79 @@ export const llamaComplete = async (params, controller, callback) => {
return content; return content;
} }
// Call llama, return an event target that you can subcribe to
//
// Example:
//
// import { llamaEventTarget } from '/completion.js'
//
// const conn = llamaEventTarget(prompt)
// conn.addEventListener("message", (chunk) => {
// document.write(chunk.detail.content)
// })
//
export const llamaEventTarget = (prompt, params = {}, config = {}) => {
const eventTarget = new EventTarget();
(async () => {
let content = "";
for await (const chunk of llama(prompt, params, config)) {
if (chunk.data) {
content += chunk.data.content;
eventTarget.dispatchEvent(new CustomEvent("message", { detail: chunk.data }));
}
if (chunk.data.generation_settings) {
eventTarget.dispatchEvent(new CustomEvent("generation_settings", { detail: chunk.data.generation_settings }));
}
if (chunk.data.timings) {
eventTarget.dispatchEvent(new CustomEvent("timings", { detail: chunk.data.timings }));
}
}
eventTarget.dispatchEvent(new CustomEvent("done", { detail: { content } }));
})();
return eventTarget;
}
// Call llama, return a promise that resolves to the completed text. This does not support streaming
//
// Example:
//
// llamaPromise(prompt).then((content) => {
// document.write(content)
// })
//
// or
//
// const content = await llamaPromise(prompt)
// document.write(content)
//
export const llamaPromise = (prompt, params = {}, config = {}) => {
return new Promise(async (resolve, reject) => {
let content = "";
try {
for await (const chunk of llama(prompt, params, config)) {
content += chunk.data.content;
}
resolve(content);
} catch (error) {
reject(error);
}
});
};
/**
* (deprecated)
*/
export const llamaComplete = async (params, controller, callback) => {
for await (const chunk of llama(params.prompt, params, { controller })) {
callback(chunk);
}
}
// Get the model info from the server. This is useful for getting the context window and so on.
export const llamaModelInfo = async () => {
if (!generation_settings) {
generation_settings = await fetch("/model.json").then(r => r.json());
}
return generation_settings;
}

View File

@ -106,7 +106,7 @@
html, h, signal, effect, computed, render, useSignal, useEffect, useRef html, h, signal, effect, computed, render, useSignal, useEffect, useRef
} from '/index.js'; } from '/index.js';
import { llamaComplete } from '/completion.js'; import { llama } from '/completion.js';
const session = signal({ const session = signal({
prompt: "This is a conversation between user and llama, a friendly chatbot. respond in markdown.", prompt: "This is a conversation between user and llama, a friendly chatbot. respond in markdown.",
@ -158,7 +158,7 @@
transcriptUpdate([...session.value.transcript, ["{{user}}", msg]]) transcriptUpdate([...session.value.transcript, ["{{user}}", msg]])
const payload = template(session.value.template, { const prompt = template(session.value.template, {
message: msg, message: msg,
history: session.value.transcript.flatMap(([name, message]) => template(session.value.historyTemplate, {name, message})).join("\n"), history: session.value.transcript.flatMap(([name, message]) => template(session.value.historyTemplate, {name, message})).join("\n"),
}); });
@ -168,13 +168,13 @@
const llamaParams = { const llamaParams = {
...params.value, ...params.value,
prompt: payload,
stop: ["</s>", template("{{char}}:"), template("{{user}}:")], stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
} }
await llamaComplete(llamaParams, controller.value, (message) => { for await (const chunk of llama(prompt, llamaParams, { controller: controller.value })) {
const data = message.data; const data = chunk.data;
currentMessage += data.content; currentMessage += data.content;
// remove leading whitespace // remove leading whitespace
currentMessage = currentMessage.replace(/^\s+/, "") currentMessage = currentMessage.replace(/^\s+/, "")
@ -183,7 +183,7 @@
if (data.stop) { if (data.stop) {
console.log("-->", data, ' response was:', currentMessage, 'transcript state:', session.value.transcript); console.log("-->", data, ' response was:', currentMessage, 'transcript state:', session.value.transcript);
} }
}) }
controller.value = null; controller.value = null;
} }