mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-05 16:10:42 +01:00
swift : fix llama-vocab api usage (#11645)
* swiftui : fix vocab api usage * batched.swift : fix vocab api usage
This commit is contained in:
parent
534c46b53c
commit
f117d84b48
@ -31,6 +31,11 @@ defer {
|
|||||||
llama_model_free(model)
|
llama_model_free(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
guard let vocab = llama_model_get_vocab(model) else {
|
||||||
|
print("Failed to get vocab")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
var tokens = tokenize(text: prompt, add_bos: true)
|
var tokens = tokenize(text: prompt, add_bos: true)
|
||||||
|
|
||||||
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
|
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
|
||||||
@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel))
|
|||||||
context_params.n_threads = 8
|
context_params.n_threads = 8
|
||||||
context_params.n_threads_batch = 8
|
context_params.n_threads_batch = 8
|
||||||
|
|
||||||
let context = llama_new_context_with_model(model, context_params)
|
let context = llama_init_from_model(model, context_params)
|
||||||
guard context != nil else {
|
guard context != nil else {
|
||||||
print("Failed to initialize context")
|
print("Failed to initialize context")
|
||||||
exit(1)
|
exit(1)
|
||||||
@ -141,7 +146,7 @@ while n_cur <= n_len {
|
|||||||
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
||||||
|
|
||||||
// is it an end of stream? -> mark the stream as finished
|
// is it an end of stream? -> mark the stream as finished
|
||||||
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
|
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
|
||||||
i_batch[i] = -1
|
i_batch[i] = -1
|
||||||
// print("")
|
// print("")
|
||||||
if n_parallel > 1 {
|
if n_parallel > 1 {
|
||||||
@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
|||||||
let utf8Count = text.utf8.count
|
let utf8Count = text.utf8.count
|
||||||
let n_tokens = utf8Count + (add_bos ? 1 : 0)
|
let n_tokens = utf8Count + (add_bos ? 1 : 0)
|
||||||
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
||||||
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
|
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
|
||||||
var swiftTokens: [llama_token] = []
|
var swiftTokens: [llama_token] = []
|
||||||
for i in 0 ..< tokenCount {
|
for i in 0 ..< tokenCount {
|
||||||
swiftTokens.append(tokens[Int(i)])
|
swiftTokens.append(tokens[Int(i)])
|
||||||
@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
|||||||
|
|
||||||
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
|
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
|
||||||
var result = [CChar](repeating: 0, count: 8)
|
var result = [CChar](repeating: 0, count: 8)
|
||||||
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
|
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
|
||||||
if nTokens < 0 {
|
if nTokens < 0 {
|
||||||
let actualTokensCount = -Int(nTokens)
|
let actualTokensCount = -Int(nTokens)
|
||||||
result = .init(repeating: 0, count: actualTokensCount)
|
result = .init(repeating: 0, count: actualTokensCount)
|
||||||
let check = llama_token_to_piece(
|
let check = llama_token_to_piece(
|
||||||
model,
|
vocab,
|
||||||
token,
|
token,
|
||||||
&result,
|
&result,
|
||||||
Int32(result.count),
|
Int32(result.count),
|
||||||
|
@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
|
|||||||
actor LlamaContext {
|
actor LlamaContext {
|
||||||
private var model: OpaquePointer
|
private var model: OpaquePointer
|
||||||
private var context: OpaquePointer
|
private var context: OpaquePointer
|
||||||
|
private var vocab: OpaquePointer
|
||||||
private var sampling: UnsafeMutablePointer<llama_sampler>
|
private var sampling: UnsafeMutablePointer<llama_sampler>
|
||||||
private var batch: llama_batch
|
private var batch: llama_batch
|
||||||
private var tokens_list: [llama_token]
|
private var tokens_list: [llama_token]
|
||||||
@ -47,6 +48,7 @@ actor LlamaContext {
|
|||||||
self.sampling = llama_sampler_chain_init(sparams)
|
self.sampling = llama_sampler_chain_init(sparams)
|
||||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
||||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
||||||
|
vocab = llama_model_get_vocab(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
deinit {
|
deinit {
|
||||||
@ -79,7 +81,7 @@ actor LlamaContext {
|
|||||||
ctx_params.n_threads = Int32(n_threads)
|
ctx_params.n_threads = Int32(n_threads)
|
||||||
ctx_params.n_threads_batch = Int32(n_threads)
|
ctx_params.n_threads_batch = Int32(n_threads)
|
||||||
|
|
||||||
let context = llama_new_context_with_model(model, ctx_params)
|
let context = llama_init_from_model(model, ctx_params)
|
||||||
guard let context else {
|
guard let context else {
|
||||||
print("Could not load context!")
|
print("Could not load context!")
|
||||||
throw LlamaError.couldNotInitializeContext
|
throw LlamaError.couldNotInitializeContext
|
||||||
@ -151,7 +153,7 @@ actor LlamaContext {
|
|||||||
|
|
||||||
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
||||||
|
|
||||||
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
|
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
|
||||||
print("\n")
|
print("\n")
|
||||||
is_done = true
|
is_done = true
|
||||||
let new_token_str = String(cString: temporary_invalid_cchars + [0])
|
let new_token_str = String(cString: temporary_invalid_cchars + [0])
|
||||||
@ -297,7 +299,7 @@ actor LlamaContext {
|
|||||||
let utf8Count = text.utf8.count
|
let utf8Count = text.utf8.count
|
||||||
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
|
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
|
||||||
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
||||||
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
|
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
|
||||||
|
|
||||||
var swiftTokens: [llama_token] = []
|
var swiftTokens: [llama_token] = []
|
||||||
for i in 0..<tokenCount {
|
for i in 0..<tokenCount {
|
||||||
@ -316,7 +318,7 @@ actor LlamaContext {
|
|||||||
defer {
|
defer {
|
||||||
result.deallocate()
|
result.deallocate()
|
||||||
}
|
}
|
||||||
let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
|
let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false)
|
||||||
|
|
||||||
if nTokens < 0 {
|
if nTokens < 0 {
|
||||||
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
|
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
|
||||||
@ -324,7 +326,7 @@ actor LlamaContext {
|
|||||||
defer {
|
defer {
|
||||||
newResult.deallocate()
|
newResult.deallocate()
|
||||||
}
|
}
|
||||||
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
|
let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false)
|
||||||
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
|
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
|
||||||
return Array(bufferPointer)
|
return Array(bufferPointer)
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user