Mirror of https://github.com/oobabooga/text-generation-webui.git
Fix softprompts when deepspeed is active (#112)
commit 9ae063e42b
parent dac6fe0ff4
@@ -37,7 +37,7 @@ def generate_softprompt_input_tensors(input_ids):
     inputs_embeds = shared.model.transformer.wte(input_ids)
     inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
     filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
-    filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+    #filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
     return inputs_embeds, filler_input_ids
 
 # Removes empty replies from gpt4chan outputs
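For context, the patched function prepends the soft prompt to the token embeddings and returns a same-length tensor of dummy ids; the change leaves those dummy ids as zeros instead of filling them with bos_token_id. Below is a minimal, self-contained sketch of that concatenation with toy sizes; the Embedding layer and the random soft prompt stand in for shared.model.transformer.wte and shared.soft_prompt_tensor, and nothing here is the webui's actual call site.

import torch
import torch.nn as nn

vocab_size, hidden_size = 100, 16
wte = nn.Embedding(vocab_size, hidden_size)           # stands in for shared.model.transformer.wte

input_ids = torch.tensor([[5, 7, 42]])                # (batch=1, seq_len=3)
soft_prompt = torch.randn(1, 4, hidden_size)          # stands in for shared.soft_prompt_tensor

inputs_embeds = wte(input_ids)                                   # (1, 3, 16)
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)   # (1, 7, 16)

# Dummy token ids matching the soft-prompt-extended length; after this
# commit they are left as zeros rather than being filled with bos_token_id.
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype)

print(inputs_embeds.shape, filler_input_ids.shape)    # torch.Size([1, 7, 16]) torch.Size([1, 7])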