llama.swiftui : fix infinite loop, output timings, buff UI (#4674)

* fix infinite loop

* slight UI simplification, clearer UX

* clearer UI text, add timings to completion log
Peter Sugihara 2023-12-29 05:58:56 -08:00 committed by GitHub
parent c8255f8a6b
commit afd997ab60
4 changed files with 29 additions and 37 deletions

LibLlama.swift

@@ -1,5 +1,7 @@
 import Foundation
 
+// To use this in your own project, add llama.cpp as a swift package dependency
+// and uncomment this import line.
 // import llama
 
 enum LlamaError: Error {
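For reference, the new comment points at consuming llama.cpp through Swift Package Manager. A minimal sketch of what a consumer's Package.swift could look like (the app name is hypothetical, and it assumes the llama.cpp manifest exposes a `llama` library product):

    // swift-tools-version: 5.5
    import PackageDescription

    let package = Package(
        name: "MyLlamaApp",                  // hypothetical consumer project
        platforms: [.iOS(.v15), .macOS(.v12)],
        dependencies: [
            // assumption: llama.cpp ships a Package.swift with a `llama` product
            .package(url: "https://github.com/ggerganov/llama.cpp", branch: "master")
        ],
        targets: [
            .executableTarget(
                name: "MyLlamaApp",
                dependencies: [.product(name: "llama", package: "llama.cpp")]
            )
        ]
    )

With the dependency declared, the `// import llama` line above can be uncommented.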

LlamaState.swift

@@ -4,6 +4,7 @@ import Foundation
 class LlamaState: ObservableObject {
     @Published var messageLog = ""
     @Published var cacheCleared = false
+    let NS_PER_S = 1_000_000_000.0
 
     private var llamaContext: LlamaContext?
     private var defaultModelUrl: URL? {
@@ -20,12 +21,12 @@ class LlamaState: ObservableObject {
     }
 
     func loadModel(modelUrl: URL?) throws {
-        messageLog += "Loading model...\n"
         if let modelUrl {
+            messageLog += "Loading model...\n"
             llamaContext = try LlamaContext.create_context(path: modelUrl.path())
             messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
         } else {
-            messageLog += "Could not locate model\n"
+            messageLog += "Load a model from the list below\n"
        }
    }
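Since loadModel(modelUrl:) throws rather than swallowing failures, a call site is expected to wrap it in do/catch. A minimal usage sketch (the surrounding call site is illustrative, not part of this change):

    do {
        try llamaState.loadModel(modelUrl: modelUrl)   // modelUrl: URL?
    } catch {
        llamaState.messageLog += "Error: \(error)\n"
    }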
@@ -34,15 +35,29 @@ class LlamaState: ObservableObject {
             return
         }
 
+        let t_start = DispatchTime.now().uptimeNanoseconds
         await llamaContext.completion_init(text: text)
+        let t_heat_end = DispatchTime.now().uptimeNanoseconds
+        let t_heat = Double(t_heat_end - t_start) / NS_PER_S
+
         messageLog += "\(text)"
-        while await llamaContext.n_cur <= llamaContext.n_len {
+
+        while await llamaContext.n_cur < llamaContext.n_len {
             let result = await llamaContext.completion_loop()
             messageLog += "\(result)"
         }
+
+        let t_end = DispatchTime.now().uptimeNanoseconds
+        let t_generation = Double(t_end - t_heat_end) / NS_PER_S
+        let tokens_per_second = Double(await llamaContext.n_len) / t_generation
 
         await llamaContext.clear()
-        messageLog += "\n\ndone\n"
+        messageLog += """
+            \n
+            Done
+            Heat up took \(t_heat)s
+            Generated \(tokens_per_second) t/s\n
+            """
     }
 
     func bench() async {
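The timings added above use the standard DispatchTime pattern: sample uptimeNanoseconds before and after the work, subtract, and divide by NS_PER_S to get seconds. A self-contained sketch of the same technique:

    import Foundation

    let NS_PER_S = 1_000_000_000.0

    // Measure the wall-clock duration of a block, in seconds.
    func measureSeconds(_ work: () -> Void) -> Double {
        let start = DispatchTime.now().uptimeNanoseconds
        work()
        let end = DispatchTime.now().uptimeNanoseconds
        return Double(end - start) / NS_PER_S
    }

    let t = measureSeconds { Thread.sleep(forTimeInterval: 0.1) }
    print("took \(t)s")   // prints roughly 0.1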
@@ -56,10 +71,10 @@ class LlamaState: ObservableObject {
         messageLog += await llamaContext.model_info() + "\n"
 
         let t_start = DispatchTime.now().uptimeNanoseconds
-        await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up
+        let _ = await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up
         let t_end = DispatchTime.now().uptimeNanoseconds
 
-        let t_heat = Double(t_end - t_start) / 1_000_000_000.0
+        let t_heat = Double(t_end - t_start) / NS_PER_S
         messageLog += "Heat up time: \(t_heat) seconds, please wait...\n"
 
         // if more than 5 seconds, then we're probably running on a slow device
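The `let _ =` on the heat-up call explicitly discards bench's return value, which silences Swift's unused-result warning; the alternative is marking the function @discardableResult. A minimal sketch of both options (runBench is a hypothetical stand-in, not the example's API):

    func runBench() -> String { "pp 8 | tg 4 | pl 1" }

    let _ = runBench()       // discard the result explicitly

    @discardableResult
    func quietBench() -> String { "pp 8 | tg 4 | pl 1" }

    quietBench()             // no warning, even without `let _ =`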

ContentView.swift

@@ -42,46 +42,27 @@ struct ContentView: View {
                 Button("Send") {
                     sendText()
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
 
                 Button("Bench") {
                     bench()
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
 
                 Button("Clear") {
                     clear()
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
 
                 Button("Copy") {
                     UIPasteboard.general.string = llamaState.messageLog
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
-            }
+            }.buttonStyle(.bordered)
 
-            VStack {
+            VStack(alignment: .leading) {
                 DownloadButton(
                     llamaState: llamaState,
                     modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
                     filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
                 )
-                .font(.system(size: 12))
-                .padding(.top, 4)
-                .frame(maxWidth: .infinity, alignment: .leading)
 
                 DownloadButton(
                     llamaState: llamaState,
@@ -89,7 +70,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
                     filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
                 )
-                .font(.system(size: 12))
 
                 DownloadButton(
                     llamaState: llamaState,
@@ -97,8 +77,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
                     filename: "tinyllama-1.1b-f16.gguf"
                 )
-                .font(.system(size: 12))
-                .frame(maxWidth: .infinity, alignment: .leading)
 
                 DownloadButton(
                     llamaState: llamaState,
@@ -106,7 +84,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
                     filename: "phi-2-q4_0.gguf"
                 )
-                .font(.system(size: 12))
 
                 DownloadButton(
                     llamaState: llamaState,
@@ -114,8 +91,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
                     filename: "phi-2-q8_0.gguf"
                 )
-                .font(.system(size: 12))
-                .frame(maxWidth: .infinity, alignment: .leading)
 
                 DownloadButton(
                     llamaState: llamaState,
@@ -123,15 +98,15 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
                     filename: "mistral-7b-v0.1.Q4_0.gguf"
                 )
-                .font(.system(size: 12))
 
                 Button("Clear downloaded models") {
                     ContentView.cleanupModelCaches()
                     llamaState.cacheCleared = true
                 }
-                .padding(8)
-                .font(.system(size: 12))
             }
+            .padding(.top, 4)
+            .font(.system(size: 12))
+            .frame(maxWidth: .infinity, alignment: .leading)
         }
         .padding()
     }
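Replacing the per-button .padding/.background/.cornerRadius chains with a single .buttonStyle(.bordered) on the enclosing HStack works because button styles propagate through the SwiftUI environment to every button in the subtree. A minimal sketch of the effect (.bordered requires iOS 15+):

    import SwiftUI

    struct ToolbarRow: View {
        var body: some View {
            HStack {
                Button("Send")  {}
                Button("Bench") {}
                Button("Clear") {}
            }
            .buttonStyle(.bordered)   // styles all three buttons at once
        }
    }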

DownloadButton.swift

@@ -93,7 +93,7 @@ struct DownloadButton: View {
                 print("Error: \(err.localizedDescription)")
             }
         }) {
-            Text("\(modelName) (Downloaded)")
+            Text("Load \(modelName)")
         }
     } else {
         Text("Unknown status")