From a76ce02a6cbbc139215ae3843fca4199ee439adf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20L=C3=BCtke?= Date: Wed, 5 Jul 2023 15:03:01 -0400 Subject: [PATCH] use javascript generators as much cleaner API Also add ways to access completion as promise and EventSource --- examples/server/public/completion.js | 116 +++++++++++++++++++++++---- examples/server/public/index.html | 12 +-- 2 files changed, 106 insertions(+), 22 deletions(-) diff --git a/examples/server/public/completion.js b/examples/server/public/completion.js index 4f5005cfb..11ccc482c 100644 --- a/examples/server/public/completion.js +++ b/examples/server/public/completion.js @@ -5,20 +5,29 @@ const paramDefaults = { stop: [""] }; -/** - * This function completes the input text using a llama dictionary. - * @param {object} params - The parameters for the completion request. - * @param {object} controller - an instance of AbortController if you need one, or null. - * @param {function} callback - The callback function to call when the completion is done. - * @returns {string} the completed text as a string. Ideally ignored, and you get at it via the callback. - */ -export const llamaComplete = async (params, controller, callback) => { +let generation_settings = null; + + +// Completes the prompt as a generator. Recommended for most use cases. +// +// Example: +// +// import { llama } from '/completion.js' +// +// const request = llama("Tell me a joke", {n_predict: 800}) +// for await (const chunk of request) { +// document.write(chunk.data.content) +// } +// +export async function* llama(prompt, params = {}, config = {}) { + let controller = config.controller; + if (!controller) { controller = new AbortController(); } - const completionParams = { ...paramDefaults, ...params }; - // we use fetch directly here becasue the built in fetchEventSource does not support POST + const completionParams = { ...paramDefaults, ...params, prompt }; + const response = await fetch("/completion", { method: 'POST', body: JSON.stringify(completionParams), @@ -36,7 +45,6 @@ export const llamaComplete = async (params, controller, callback) => { let content = ""; try { - let cont = true; while (cont) { @@ -59,10 +67,8 @@ export const llamaComplete = async (params, controller, callback) => { result.data = JSON.parse(result.data); content += result.data.content; - // callack - if (callback) { - cont = callback(result) != false; - } + // yield + yield result; // if we got a stop token from server, we will break here if (result.data.stop) { @@ -70,7 +76,9 @@ export const llamaComplete = async (params, controller, callback) => { } } } catch (e) { - console.error("llama error: ", e); + if (e.name !== 'AbortError') { + console.error("llama error: ", e); + } throw e; } finally { @@ -79,3 +87,79 @@ export const llamaComplete = async (params, controller, callback) => { return content; } + +// Call llama, return an event target that you can subcribe to +// +// Example: +// +// import { llamaEventTarget } from '/completion.js' +// +// const conn = llamaEventTarget(prompt) +// conn.addEventListener("message", (chunk) => { +// document.write(chunk.detail.content) +// }) +// +export const llamaEventTarget = (prompt, params = {}, config = {}) => { + const eventTarget = new EventTarget(); + (async () => { + let content = ""; + for await (const chunk of llama(prompt, params, config)) { + if (chunk.data) { + content += chunk.data.content; + eventTarget.dispatchEvent(new CustomEvent("message", { detail: chunk.data })); + } + if (chunk.data.generation_settings) { + eventTarget.dispatchEvent(new CustomEvent("generation_settings", { detail: chunk.data.generation_settings })); + } + if (chunk.data.timings) { + eventTarget.dispatchEvent(new CustomEvent("timings", { detail: chunk.data.timings })); + } + } + eventTarget.dispatchEvent(new CustomEvent("done", { detail: { content } })); + })(); + return eventTarget; +} + +// Call llama, return a promise that resolves to the completed text. This does not support streaming +// +// Example: +// +// llamaPromise(prompt).then((content) => { +// document.write(content) +// }) +// +// or +// +// const content = await llamaPromise(prompt) +// document.write(content) +// +export const llamaPromise = (prompt, params = {}, config = {}) => { + return new Promise(async (resolve, reject) => { + let content = ""; + try { + for await (const chunk of llama(prompt, params, config)) { + content += chunk.data.content; + } + resolve(content); + } catch (error) { + reject(error); + } + }); +}; + +/** + * (deprecated) + */ +export const llamaComplete = async (params, controller, callback) => { + for await (const chunk of llama(params.prompt, params, { controller })) { + callback(chunk); + } +} + +// Get the model info from the server. This is useful for getting the context window and so on. +export const llamaModelInfo = async () => { + if (!generation_settings) { + generation_settings = await fetch("/model.json").then(r => r.json()); + } + return generation_settings; +} diff --git a/examples/server/public/index.html b/examples/server/public/index.html index 6393e2e75..0b7b5f49b 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -106,7 +106,7 @@ html, h, signal, effect, computed, render, useSignal, useEffect, useRef } from '/index.js'; - import { llamaComplete } from '/completion.js'; + import { llama } from '/completion.js'; const session = signal({ prompt: "This is a conversation between user and llama, a friendly chatbot. respond in markdown.", @@ -158,7 +158,7 @@ transcriptUpdate([...session.value.transcript, ["{{user}}", msg]]) - const payload = template(session.value.template, { + const prompt = template(session.value.template, { message: msg, history: session.value.transcript.flatMap(([name, message]) => template(session.value.historyTemplate, {name, message})).join("\n"), }); @@ -168,13 +168,13 @@ const llamaParams = { ...params.value, - prompt: payload, stop: ["", template("{{char}}:"), template("{{user}}:")], } - await llamaComplete(llamaParams, controller.value, (message) => { - const data = message.data; + for await (const chunk of llama(prompt, llamaParams, { controller: controller.value })) { + const data = chunk.data; currentMessage += data.content; + // remove leading whitespace currentMessage = currentMessage.replace(/^\s+/, "") @@ -183,7 +183,7 @@ if (data.stop) { console.log("-->", data, ' response was:', currentMessage, 'transcript state:', session.value.transcript); } - }) + } controller.value = null; }