mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-29 07:34:18 +01:00)
server-parallel : add "--reverse-prompt" + compiler warning fixes
commit 5ab6c2132a (parent afc09db51c)
@@ -1,13 +1,15 @@
-#include <chrono>
+#include "frontend.h"
+#include "common.h"
+#include "llama.h"
+
 #include "../server/httplib.h"
 #include "../server/json.hpp"
+
 #include <iostream>
 #include <sstream>
 #include <thread>
 #include <vector>
-#include "frontend.h"
-#include "common.h"
-#include "llama.h"
+#include <chrono>
 
 using namespace httplib;
 using namespace std;
@@ -241,9 +243,7 @@ struct server_parallel_context {
         string prompt = data.value("prompt", "");
         for (llama_client_slot & slot : slots)
         {
-            if (
-                slot_id == -1 && slot.available() ||
-                slot.id == slot_id)
+            if ((slot_id == -1 && slot.available()) || slot.id == slot_id)
             {
                 slot.start(prompt, temperature);
                 LOG_TEE("slot %i is processing\n", slot.id);
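The rewritten condition is behavior-preserving: && binds tighter than ||, so the old three-line form already parsed as the new parenthesized one. The explicit parentheses mainly silence the compiler warning about mixing && and || without parentheses (gcc/clang -Wparentheses / -Wlogical-op-parentheses), which fits the "compiler warning fixes" part of the commit message. A standalone illustration, not part of the commit, that the two spellings agree for every boolean combination:

    // standalone illustration (not from the commit): "a && b || c" and
    // "(a && b) || c" evaluate identically, since && binds tighter than ||.
    // Compilers warn on the un-parenthesized spelling, which is presumably
    // why the commit adds the parentheses.
    #include <cassert>

    int main() {
        for (int a = 0; a <= 1; ++a) {
            for (int b = 0; b <= 1; ++b) {
                for (int c = 0; c <= 1; ++c) {
                    assert((a && b || c) == ((a && b) || c));
                }
            }
        }
        return 0;
    }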
@@ -429,8 +429,6 @@ struct server_parallel_context {
             slot.generated_text += token_str;
             slot.sampled = id;
 
-            size_t pos = 0;
-
             size_t stop_pos =
                 findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
 
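The deleted "size_t pos = 0;" is apparently unused in the surrounding code (the match position comes back from findStoppingStrings into stop_pos), so removing it silences an unused-variable warning. The implementation of findStoppingStrings is not part of this diff; the sketch below is only an assumption about how such a helper typically works, scanning just the freshly appended tail of the generated text for any configured stop / reverse-prompt string, which would explain why token_str.size() is passed along with the text. The name find_stop_string and its signature are hypothetical.

    // hypothetical sketch, NOT the example's actual findStoppingStrings:
    // after appending the latest token, only the tail of the text that could
    // contain a newly completed stop word needs to be re-scanned.
    #include <cstddef>
    #include <cstdio>
    #include <string>
    #include <vector>

    static size_t find_stop_string(const std::string & text,
                                   size_t last_token_size,
                                   const std::vector<std::string> & stop_words) {
        size_t stop_pos = std::string::npos;
        for (const std::string & word : stop_words) {
            // a stop word may straddle the boundary between the old text and
            // the newly appended token, so start a little before that boundary
            const size_t from = text.size() > last_token_size + word.size()
                                    ? text.size() - last_token_size - word.size()
                                    : 0;
            const size_t pos = text.find(word, from);
            if (pos != std::string::npos &&
                (stop_pos == std::string::npos || pos < stop_pos)) {
                stop_pos = pos; // keep the earliest match
            }
        }
        return stop_pos;
    }

    int main() {
        const std::vector<std::string> stops = { "User:" };
        const std::string text = "Hello!\nUser:";
        // pretend the last decoded token was ":" (1 byte)
        std::printf("stop at %zu\n", find_stop_string(text, 1, stops)); // prints 7
        return 0;
    }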
@@ -740,20 +738,34 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         else if (arg == "--numa")
         {
             params.numa = true;
-        } else if (arg == "-cb" || arg == "--cont-batching") {
+        } else if (arg == "-cb" || arg == "--cont-batching")
+        {
             params.cont_batching = true;
-        } else if (arg == "-np" || arg == "--parallel") {
-            if (++i >= argc) {
+        }
+        else if (arg == "-np" || arg == "--parallel")
+        {
+            if (++i >= argc)
+            {
                 invalid_param = true;
                 break;
             }
             params.n_parallel = std::stoi(argv[i]);
-        } else if (arg == "-n" || arg == "--n-predict") {
-            if (++i >= argc) {
+        } else if (arg == "-n" || arg == "--n-predict")
+        {
+            if (++i >= argc)
+            {
                 invalid_param = true;
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
+        } else if (arg == "-r" || arg == "--reverse-prompt")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.antiprompt.push_back(argv[i]);
         }
         else
         {
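With this last hunk the parallel server example accepts the same -r / --reverse-prompt option as the main llama.cpp example: each occurrence pushes one string onto params.antiprompt, so the flag can be given several times to register multiple stop strings. A hedged usage sketch (the binary name and the -m model flag are assumed from the rest of llama.cpp and are not shown in this diff):

    ./server-parallel -m models/7B/ggml-model.gguf -np 4 -cb -r "User:" -r "###"

Here -np 4 requests four parallel client slots, -cb enables continuous batching, and each -r adds a reverse prompt at which a slot's generation is stopped (via the stopping-string check shown in the earlier hunk).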