mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
Server: add tests for batch size, different seeds (#6950)
This commit is contained in:
parent
1613ef8d8e
commit
3ea0d36000
@ -7,44 +7,16 @@ Feature: Results
|
|||||||
And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
|
And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
|
||||||
And a model file test-model-00001-of-00003.gguf
|
And a model file test-model-00001-of-00003.gguf
|
||||||
And 128 as batch size
|
And 128 as batch size
|
||||||
And 256 KV cache size
|
And 1024 KV cache size
|
||||||
And 128 max tokens to predict
|
And 128 max tokens to predict
|
||||||
|
|
||||||
Scenario Outline: Multi users completion
|
|
||||||
Given <n_slots> slots
|
|
||||||
And continuous batching
|
And continuous batching
|
||||||
|
|
||||||
|
Scenario Outline: consistent results with same seed
|
||||||
|
Given <n_slots> slots
|
||||||
Then the server is starting
|
Then the server is starting
|
||||||
Then the server is healthy
|
Then the server is healthy
|
||||||
|
|
||||||
Given 42 as seed
|
Given 4 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long story about AI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Given 42 as seed
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long story about AI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Given 42 as seed
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long story about AI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Given 42 as seed
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long story about AI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Given 42 as seed
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long story about AI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Given concurrent completion requests
|
Given concurrent completion requests
|
||||||
Then the server is busy
|
Then the server is busy
|
||||||
@ -55,3 +27,55 @@ Feature: Results
|
|||||||
| n_slots |
|
| n_slots |
|
||||||
| 1 |
|
| 1 |
|
||||||
| 2 |
|
| 2 |
|
||||||
|
|
||||||
|
Scenario Outline: different results with different seed
|
||||||
|
Given <n_slots> slots
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
|
||||||
|
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42
|
||||||
|
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 43
|
||||||
|
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 44
|
||||||
|
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 45
|
||||||
|
|
||||||
|
Given concurrent completion requests
|
||||||
|
Then the server is busy
|
||||||
|
Then the server is idle
|
||||||
|
And all slots are idle
|
||||||
|
Then all predictions are different
|
||||||
|
Examples:
|
||||||
|
| n_slots |
|
||||||
|
| 1 |
|
||||||
|
| 2 |
|
||||||
|
|
||||||
|
Scenario Outline: consistent results with same seed and varying batch size
|
||||||
|
Given 4 slots
|
||||||
|
And <temp> temperature
|
||||||
|
# And 0 as draft
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
|
||||||
|
Given 1 prompts "Write a very long story about AI." with seed 42
|
||||||
|
And concurrent completion requests
|
||||||
|
# Then the server is busy # Not all slots will be utilized.
|
||||||
|
Then the server is idle
|
||||||
|
And all slots are idle
|
||||||
|
|
||||||
|
Given <n_parallel> prompts "Write a very long story about AI." with seed 42
|
||||||
|
And concurrent completion requests
|
||||||
|
# Then the server is busy # Not all slots will be utilized.
|
||||||
|
Then the server is idle
|
||||||
|
And all slots are idle
|
||||||
|
|
||||||
|
Then all predictions are equal
|
||||||
|
Examples:
|
||||||
|
| n_parallel | temp |
|
||||||
|
| 1 | 0.0 |
|
||||||
|
| 2 | 0.0 |
|
||||||
|
| 4 | 0.0 |
|
||||||
|
| 1 | 1.0 |
|
||||||
|
# FIXME: These tests fail on master. The problem seems to be the unified KV cache.
|
||||||
|
# See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227
|
||||||
|
# and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 .
|
||||||
|
# | 2 | 1.0 |
|
||||||
|
# | 4 | 1.0 |
|
||||||
|
@ -65,6 +65,7 @@ def step_server_config(context, server_fqdn, server_port):
|
|||||||
context.server_seed = None
|
context.server_seed = None
|
||||||
context.user_api_key = None
|
context.user_api_key = None
|
||||||
context.response_format = None
|
context.response_format = None
|
||||||
|
context.temperature = None
|
||||||
|
|
||||||
context.tasks_result = []
|
context.tasks_result = []
|
||||||
context.concurrent_tasks = []
|
context.concurrent_tasks = []
|
||||||
@ -232,15 +233,17 @@ async def step_all_slots_status(context, expected_slot_status_string):
|
|||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_request_completion(context, api_error):
|
async def step_request_completion(context, api_error):
|
||||||
expect_api_error = api_error == 'raised'
|
expect_api_error = api_error == 'raised'
|
||||||
|
seeds = await completions_seed(context, num_seeds=1)
|
||||||
completion = await request_completion(context.prompts.pop(),
|
completion = await request_completion(context.prompts.pop(),
|
||||||
|
seeds[0] if seeds is not None else seeds,
|
||||||
context.base_url,
|
context.base_url,
|
||||||
debug=context.debug,
|
debug=context.debug,
|
||||||
n_predict=context.n_predict,
|
n_predict=context.n_predict,
|
||||||
cache_prompt=context.cache_prompt,
|
cache_prompt=context.cache_prompt,
|
||||||
id_slot=context.id_slot,
|
id_slot=context.id_slot,
|
||||||
seed=await completions_seed(context),
|
|
||||||
expect_api_error=expect_api_error,
|
expect_api_error=expect_api_error,
|
||||||
user_api_key=context.user_api_key)
|
user_api_key=context.user_api_key,
|
||||||
|
temperature=context.temperature)
|
||||||
context.tasks_result.append(completion)
|
context.tasks_result.append(completion)
|
||||||
if context.debug:
|
if context.debug:
|
||||||
print(f"Completion response: {completion}")
|
print(f"Completion response: {completion}")
|
||||||
@ -269,6 +272,15 @@ async def step_predictions_equal(context):
|
|||||||
context.tasks_result = []
|
context.tasks_result = []
|
||||||
|
|
||||||
|
|
||||||
|
@step('all predictions are different')
|
||||||
|
@async_run_until_complete
|
||||||
|
async def step_predictions_equal(context):
|
||||||
|
n_completions = await gather_tasks_results(context)
|
||||||
|
assert n_completions >= 2, "need at least 2 completions"
|
||||||
|
assert_all_predictions_different(context.tasks_result)
|
||||||
|
context.tasks_result = []
|
||||||
|
|
||||||
|
|
||||||
@step('the completion is truncated')
|
@step('the completion is truncated')
|
||||||
def step_assert_completion_truncated(context):
|
def step_assert_completion_truncated(context):
|
||||||
step_assert_completion_truncated(context, '')
|
step_assert_completion_truncated(context, '')
|
||||||
@ -311,6 +323,11 @@ def step_response_format(context, response_format):
|
|||||||
context.response_format = json.loads(response_format)
|
context.response_format = json.loads(response_format)
|
||||||
|
|
||||||
|
|
||||||
|
@step('{temperature:f} temperature')
|
||||||
|
def step_temperature(context, temperature):
|
||||||
|
context.temperature = temperature
|
||||||
|
|
||||||
|
|
||||||
@step('streaming is {enable_streaming}')
|
@step('streaming is {enable_streaming}')
|
||||||
def step_streaming(context, enable_streaming):
|
def step_streaming(context, enable_streaming):
|
||||||
context.enable_streaming = enable_streaming == 'enabled'
|
context.enable_streaming = enable_streaming == 'enabled'
|
||||||
@ -353,7 +370,10 @@ def step_n_ubatch(context, n_ubatch):
|
|||||||
|
|
||||||
@step('{seed:d} as seed')
|
@step('{seed:d} as seed')
|
||||||
def step_seed(context, seed):
|
def step_seed(context, seed):
|
||||||
context.seed = seed
|
if context.seed is None:
|
||||||
|
context.seed = [seed]
|
||||||
|
else:
|
||||||
|
context.seed.append(seed)
|
||||||
|
|
||||||
|
|
||||||
@step('a prefix prompt')
|
@step('a prefix prompt')
|
||||||
@ -413,7 +433,9 @@ async def step_oai_chat_completions(context, api_error):
|
|||||||
if context.debug:
|
if context.debug:
|
||||||
print(f"Submitting OAI compatible completions request...")
|
print(f"Submitting OAI compatible completions request...")
|
||||||
expect_api_error = api_error == 'raised'
|
expect_api_error = api_error == 'raised'
|
||||||
|
seeds = await completions_seed(context, num_seeds=1),
|
||||||
completion = await oai_chat_completions(context.prompts.pop(),
|
completion = await oai_chat_completions(context.prompts.pop(),
|
||||||
|
seeds[0] if seeds is not None else seeds,
|
||||||
context.system_prompt,
|
context.system_prompt,
|
||||||
context.base_url,
|
context.base_url,
|
||||||
'/v1/chat',
|
'/v1/chat',
|
||||||
@ -429,8 +451,6 @@ async def step_oai_chat_completions(context, api_error):
|
|||||||
response_format=context.response_format
|
response_format=context.response_format
|
||||||
if hasattr(context, 'response_format') else None,
|
if hasattr(context, 'response_format') else None,
|
||||||
|
|
||||||
seed=await completions_seed(context),
|
|
||||||
|
|
||||||
user_api_key=context.user_api_key
|
user_api_key=context.user_api_key
|
||||||
if hasattr(context, 'user_api_key') else None,
|
if hasattr(context, 'user_api_key') else None,
|
||||||
|
|
||||||
@ -457,10 +477,21 @@ def step_a_prompt_prompt(context, prompt):
|
|||||||
context.n_prompts = len(context.prompts)
|
context.n_prompts = len(context.prompts)
|
||||||
|
|
||||||
|
|
||||||
|
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
|
||||||
|
def step_many_prompts(context, num_prompts, prompt, seed):
|
||||||
|
if context.seed is None:
|
||||||
|
context.seed = []
|
||||||
|
for _ in range(num_prompts):
|
||||||
|
context.seed.append(seed)
|
||||||
|
context.prompts.append(prompt)
|
||||||
|
context.n_prompts = len(context.prompts)
|
||||||
|
|
||||||
|
|
||||||
@step('concurrent completion requests')
|
@step('concurrent completion requests')
|
||||||
@async_run_until_complete()
|
@async_run_until_complete()
|
||||||
async def step_concurrent_completion_requests(context):
|
async def step_concurrent_completion_requests(context):
|
||||||
await concurrent_requests(context,
|
await concurrent_requests(
|
||||||
|
context,
|
||||||
request_completion,
|
request_completion,
|
||||||
# prompt is inserted automatically
|
# prompt is inserted automatically
|
||||||
context.base_url,
|
context.base_url,
|
||||||
@ -468,9 +499,9 @@ async def step_concurrent_completion_requests(context):
|
|||||||
prompt_prefix=context.prompt_prefix,
|
prompt_prefix=context.prompt_prefix,
|
||||||
prompt_suffix=context.prompt_suffix,
|
prompt_suffix=context.prompt_suffix,
|
||||||
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
||||||
seed=await completions_seed(context),
|
user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None,
|
||||||
user_api_key=context.user_api_key if hasattr(context,
|
temperature=context.temperature,
|
||||||
'user_api_key') else None)
|
)
|
||||||
|
|
||||||
|
|
||||||
@step('concurrent OAI completions requests')
|
@step('concurrent OAI completions requests')
|
||||||
@ -490,7 +521,6 @@ async def step_oai_chat_completions(context):
|
|||||||
if hasattr(context, 'enable_streaming') else None,
|
if hasattr(context, 'enable_streaming') else None,
|
||||||
response_format=context.response_format
|
response_format=context.response_format
|
||||||
if hasattr(context, 'response_format') else None,
|
if hasattr(context, 'response_format') else None,
|
||||||
seed=await completions_seed(context),
|
|
||||||
user_api_key=context.user_api_key
|
user_api_key=context.user_api_key
|
||||||
if hasattr(context, 'user_api_key') else None)
|
if hasattr(context, 'user_api_key') else None)
|
||||||
|
|
||||||
@ -512,10 +542,6 @@ async def step_oai_chat_completions(context):
|
|||||||
if hasattr(context, 'enable_streaming') else None,
|
if hasattr(context, 'enable_streaming') else None,
|
||||||
response_format=context.response_format
|
response_format=context.response_format
|
||||||
if hasattr(context, 'response_format') else None,
|
if hasattr(context, 'response_format') else None,
|
||||||
seed=context.seed
|
|
||||||
if hasattr(context, 'seed') else
|
|
||||||
context.server_seed
|
|
||||||
if hasattr(context, 'server_seed') else None,
|
|
||||||
user_api_key=context.user_api_key
|
user_api_key=context.user_api_key
|
||||||
if hasattr(context, 'user_api_key') else None)
|
if hasattr(context, 'user_api_key') else None)
|
||||||
|
|
||||||
@ -544,7 +570,7 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
|
|||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_compute_embedding(context):
|
async def step_compute_embedding(context):
|
||||||
context.n_prompts = 1
|
context.n_prompts = 1
|
||||||
context.embeddings = await request_embedding(context_text(context), base_url=context.base_url)
|
context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
|
||||||
|
|
||||||
|
|
||||||
@step('all embeddings are the same')
|
@step('all embeddings are the same')
|
||||||
@ -585,7 +611,7 @@ def step_assert_embeddings(context):
|
|||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_oai_compute_embeddings(context):
|
async def step_oai_compute_embeddings(context):
|
||||||
context.n_prompts = 1
|
context.n_prompts = 1
|
||||||
context.embeddings = await request_oai_embeddings(context_text(context),
|
context.embeddings = await request_oai_embeddings(context_text(context), None,
|
||||||
base_url=context.base_url,
|
base_url=context.base_url,
|
||||||
user_api_key=context.user_api_key,
|
user_api_key=context.user_api_key,
|
||||||
model=context.model)
|
model=context.model)
|
||||||
@ -594,7 +620,7 @@ async def step_oai_compute_embeddings(context):
|
|||||||
@step('an OAI compatible embeddings computation request for multiple inputs')
|
@step('an OAI compatible embeddings computation request for multiple inputs')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_oai_compute_embeddings_multiple_inputs(context):
|
async def step_oai_compute_embeddings_multiple_inputs(context):
|
||||||
context.embeddings = await request_oai_embeddings(context.prompts,
|
context.embeddings = await request_oai_embeddings(context.prompts, None,
|
||||||
base_url=context.base_url,
|
base_url=context.base_url,
|
||||||
user_api_key=context.user_api_key,
|
user_api_key=context.user_api_key,
|
||||||
model=context.model)
|
model=context.model)
|
||||||
@ -740,8 +766,9 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
|
|||||||
if context.debug:
|
if context.debug:
|
||||||
print(f"starting {context.n_prompts} concurrent completion requests...")
|
print(f"starting {context.n_prompts} concurrent completion requests...")
|
||||||
assert context.n_prompts > 0
|
assert context.n_prompts > 0
|
||||||
|
seeds = await completions_seed(context)
|
||||||
for prompt_no in range(context.n_prompts):
|
for prompt_no in range(context.n_prompts):
|
||||||
shifted_args = [context.prompts.pop(), *args]
|
shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
|
||||||
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
|
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
@ -781,6 +808,7 @@ def step_server_responds_with_status_code(context, status_code):
|
|||||||
|
|
||||||
|
|
||||||
async def request_completion(prompt,
|
async def request_completion(prompt,
|
||||||
|
seed,
|
||||||
base_url,
|
base_url,
|
||||||
debug=False,
|
debug=False,
|
||||||
prompt_prefix=None,
|
prompt_prefix=None,
|
||||||
@ -788,9 +816,9 @@ async def request_completion(prompt,
|
|||||||
n_predict=None,
|
n_predict=None,
|
||||||
cache_prompt=False,
|
cache_prompt=False,
|
||||||
id_slot=None,
|
id_slot=None,
|
||||||
seed=None,
|
|
||||||
expect_api_error=None,
|
expect_api_error=None,
|
||||||
user_api_key=None):
|
user_api_key=None,
|
||||||
|
temperature=None):
|
||||||
if debug:
|
if debug:
|
||||||
print(f"Sending completion request: {prompt}")
|
print(f"Sending completion request: {prompt}")
|
||||||
origin = "my.super.domain"
|
origin = "my.super.domain"
|
||||||
@ -811,7 +839,8 @@ async def request_completion(prompt,
|
|||||||
"n_predict": n_predict if n_predict is not None else -1,
|
"n_predict": n_predict if n_predict is not None else -1,
|
||||||
"cache_prompt": cache_prompt,
|
"cache_prompt": cache_prompt,
|
||||||
"id_slot": id_slot,
|
"id_slot": id_slot,
|
||||||
"seed": seed if seed is not None else 42
|
"seed": seed if seed is not None else 42,
|
||||||
|
"temperature": temperature if temperature is not None else "0.8f",
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=3600) as response:
|
timeout=3600) as response:
|
||||||
@ -824,6 +853,7 @@ async def request_completion(prompt,
|
|||||||
|
|
||||||
|
|
||||||
async def oai_chat_completions(user_prompt,
|
async def oai_chat_completions(user_prompt,
|
||||||
|
seed,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
base_url,
|
base_url,
|
||||||
base_path,
|
base_path,
|
||||||
@ -833,7 +863,6 @@ async def oai_chat_completions(user_prompt,
|
|||||||
n_predict=None,
|
n_predict=None,
|
||||||
enable_streaming=None,
|
enable_streaming=None,
|
||||||
response_format=None,
|
response_format=None,
|
||||||
seed=None,
|
|
||||||
user_api_key=None,
|
user_api_key=None,
|
||||||
expect_api_error=None):
|
expect_api_error=None):
|
||||||
if debug:
|
if debug:
|
||||||
@ -952,7 +981,7 @@ async def oai_chat_completions(user_prompt,
|
|||||||
return completion_response
|
return completion_response
|
||||||
|
|
||||||
|
|
||||||
async def request_embedding(content, base_url=None):
|
async def request_embedding(content, seed, base_url=None):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f'{base_url}/embedding',
|
async with session.post(f'{base_url}/embedding',
|
||||||
json={
|
json={
|
||||||
@ -963,7 +992,7 @@ async def request_embedding(content, base_url=None):
|
|||||||
return [response_json['embedding']]
|
return [response_json['embedding']]
|
||||||
|
|
||||||
|
|
||||||
async def request_oai_embeddings(input,
|
async def request_oai_embeddings(input, seed,
|
||||||
base_url=None, user_api_key=None,
|
base_url=None, user_api_key=None,
|
||||||
model=None, async_client=False):
|
model=None, async_client=False):
|
||||||
# openai client always expects an api_key
|
# openai client always expects an api_key
|
||||||
@ -1036,21 +1065,31 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
|
|||||||
f' {n_predicted} <> {expected_predicted_n}')
|
f' {n_predicted} <> {expected_predicted_n}')
|
||||||
|
|
||||||
def assert_all_predictions_equal(completion_responses):
|
def assert_all_predictions_equal(completion_responses):
|
||||||
content_0 = completion_responses[0]['content']
|
|
||||||
|
|
||||||
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
|
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
|
||||||
print(f"content 0: {content_0}")
|
for i, response_i in enumerate(completion_responses):
|
||||||
|
content_i = response_i['content']
|
||||||
|
print(f"content {i}: {content_i}")
|
||||||
|
for i, response_i in enumerate(completion_responses):
|
||||||
|
content_i = response_i['content']
|
||||||
|
for j, response_j in enumerate(completion_responses):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
content_j = response_j['content']
|
||||||
|
assert content_i == content_j, "contents not equal"
|
||||||
|
|
||||||
i = 1
|
|
||||||
for response in completion_responses[1:]:
|
|
||||||
content = response['content']
|
|
||||||
|
|
||||||
|
def assert_all_predictions_different(completion_responses):
|
||||||
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
|
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
|
||||||
print(f"content {i}: {content}")
|
for i, response_i in enumerate(completion_responses):
|
||||||
|
content_i = response_i['content']
|
||||||
assert content == content_0, "contents not equal"
|
print(f"content {i}: {content_i}")
|
||||||
|
for i, response_i in enumerate(completion_responses):
|
||||||
i += 1
|
content_i = response_i['content']
|
||||||
|
for j, response_j in enumerate(completion_responses):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
content_j = response_j['content']
|
||||||
|
assert content_i != content_j, "contents not different"
|
||||||
|
|
||||||
|
|
||||||
async def gather_tasks_results(context):
|
async def gather_tasks_results(context):
|
||||||
@ -1145,9 +1184,22 @@ def assert_slots_status(slots, expected_slots):
|
|||||||
f" = {expected[key]} != {slot[key]}")
|
f" = {expected[key]} != {slot[key]}")
|
||||||
|
|
||||||
|
|
||||||
async def completions_seed(context):
|
async def completions_seed(context, num_seeds=None):
|
||||||
return context.seed if hasattr(context, 'seed') and context.seed is not None \
|
if hasattr(context, "seed") and context.seed is not None:
|
||||||
else context.server_seed if hasattr(context, 'server_seed') else None
|
assert len(context.seed) == context.n_prompts
|
||||||
|
if num_seeds is None:
|
||||||
|
num_seeds = context.n_prompts
|
||||||
|
assert num_seeds <= context.n_prompts
|
||||||
|
seeds = context.seed[:num_seeds]
|
||||||
|
context.seed = context.seed[num_seeds:] if num_seeds < context.n_prompts else None
|
||||||
|
return seeds
|
||||||
|
|
||||||
|
if hasattr(context, "server_seed") and context.server_seed is not None:
|
||||||
|
if num_seeds is None:
|
||||||
|
return [context.server_seed] * context.n_prompts
|
||||||
|
else:
|
||||||
|
return [context.server_seed] * num_seeds
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def context_text(context):
|
def context_text(context):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user