llama.swiftui: fix end of generation bug (#8268)

* fix generation continuing to emit blank lines after the LLM returns an EOT or EOS token

* change variable name to is_done (variable name suggested by ggerganov)

* minor : fix trailing whitespace

* minor : add space

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Huifeng Ou 2024-07-20 09:09:37 -04:00 committed by GitHub
parent c3776cacab
commit 69b9945b44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@@ -26,11 +26,12 @@ actor LlamaContext {
private var context: OpaquePointer private var context: OpaquePointer
private var batch: llama_batch private var batch: llama_batch
private var tokens_list: [llama_token] private var tokens_list: [llama_token]
var is_done: Bool = false
/// This variable is used to store temporarily invalid cchars /// This variable is used to store temporarily invalid cchars
private var temporary_invalid_cchars: [CChar] private var temporary_invalid_cchars: [CChar]
var n_len: Int32 = 64 var n_len: Int32 = 1024
var n_cur: Int32 = 0 var n_cur: Int32 = 0
var n_decode: Int32 = 0 var n_decode: Int32 = 0
@@ -160,6 +161,7 @@ actor LlamaContext {
if llama_token_is_eog(model, new_token_id) || n_cur == n_len { if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n") print("\n")
is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0]) let new_token_str = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll() temporary_invalid_cchars.removeAll()
return new_token_str return new_token_str

View File

@@ -132,7 +132,7 @@ class LlamaState: ObservableObject {
messageLog += "\(text)" messageLog += "\(text)"
Task.detached { Task.detached {
while await llamaContext.n_cur < llamaContext.n_len { while await !llamaContext.is_done {
let result = await llamaContext.completion_loop() let result = await llamaContext.completion_loop()
await MainActor.run { await MainActor.run {
self.messageLog += "\(result)" self.messageLog += "\(result)"