mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-10 04:20:33 +01:00
132 lines
4.5 KiB
Python
132 lines
4.5 KiB
Python
|
from pathlib import Path
|
||
|
|
||
|
import tensorrt_llm
|
||
|
import torch
|
||
|
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
|
||
|
|
||
|
from modules import shared
|
||
|
from modules.logging_colors import logger
|
||
|
from modules.text_generation import (
|
||
|
get_max_prompt_length,
|
||
|
get_reply_from_output_ids
|
||
|
)
|
||
|
|
||
|
|
||
|
class TensorRTLLMModel:
    """Thin wrapper around a TensorRT-LLM runner used for text generation.

    Instances are created with :meth:`from_pretrained`, which stores the
    loaded runner on ``self.model`` and the MPI rank on ``self.runtime_rank``.
    """

    # Output-token budget used with ``ModelRunnerCpp``: passed as
    # ``max_output_len`` at load time, subtracted from the sequence length to
    # get ``max_input_len``, and used to cap ``max_new_tokens`` per request.
    _CPP_RUNNER_MAX_OUTPUT_LEN = 512

    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        """Load a TensorRT-LLM engine and return a ready-to-use wrapper.

        Args:
            path_to_model: Engine directory, relative to
                ``shared.args.model_dir``.

        Returns:
            A new instance with ``model`` set to the loaded runner and
            ``runtime_rank`` set to the value of ``tensorrt_llm.mpi_rank()``.
        """
        path_to_model = Path(str(shared.args.model_dir)) / Path(path_to_model)
        runtime_rank = tensorrt_llm.mpi_rank()

        # Settings shared by both runner implementations.
        runner_kwargs = {
            'engine_dir': str(path_to_model),
            'lora_dir': None,
            'rank': runtime_rank,
            'debug_mode': False,
            'lora_ckpt_source': "hf",
        }

        if shared.args.cpp_runner:
            logger.info('TensorRT-LLM: Using "ModelRunnerCpp"')
            # The C++ runner is given explicit size limits up front; the
            # input budget is whatever remains after reserving the output
            # budget.
            runner_kwargs.update(
                max_batch_size=1,
                max_input_len=shared.args.max_seq_len - cls._CPP_RUNNER_MAX_OUTPUT_LEN,
                max_output_len=cls._CPP_RUNNER_MAX_OUTPUT_LEN,
                max_beam_width=1,
                max_attention_window_size=None,
                sink_token_length=None,
            )
            runner_cls = ModelRunnerCpp
        else:
            logger.info('TensorRT-LLM: Using "ModelRunner"')
            runner_cls = ModelRunner

        # Load the model
        runner = runner_cls.from_dir(**runner_kwargs)

        result = cls()
        result.model = runner
        result.runtime_rank = runtime_rank
        return result

    @staticmethod
    def _resolve_pad_id(tokenizer):
        """Return the tokenizer's pad token id, or its EOS id if none is set.

        The check is ``is None`` rather than truthiness so that a legitimate
        pad token id of ``0`` is not mistaken for "missing".
        """
        pad_id = tokenizer.pad_token_id
        return tokenizer.eos_token_id if pad_id is None else pad_id

    @staticmethod
    def _decode_increment(outputs, state, starting_from):
        """Decode one runner output into text.

        Reads the first entry of ``outputs['sequence_lengths']`` and the
        matching slice of ``outputs['output_ids'][0][0]``, then hands the ids
        to ``get_reply_from_output_ids`` together with ``starting_from``.

        Returns:
            ``(reply, sequence_length)`` where ``reply`` is whatever
            ``get_reply_from_output_ids`` returned and ``sequence_length`` is
            the length that was read (the caller uses it as the next
            ``starting_from``).
        """
        sequence_length = outputs['sequence_lengths'][0].item()
        output_ids = outputs['output_ids'][0][0][:sequence_length].tolist()
        reply = get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
        return reply, sequence_length

    def generate_with_streaming(self, prompt, state):
        """Generate text for ``prompt``, yielding the cumulative reply.

        With the Python runner the reply is yielded after every streamed
        step and generation stops early when ``shared.stop_everything`` is
        set. With the C++ runner streaming is disabled, so the full reply is
        yielded exactly once.

        Args:
            prompt: Text to encode with ``shared.tokenizer``.
            state: Generation settings; the keys read here are
                ``max_new_tokens``, ``auto_max_new_tokens``,
                ``truncation_length``, ``ban_eos_token``, ``temperature``,
                ``top_k``, ``top_p``, ``repetition_penalty``,
                ``presence_penalty`` and ``frequency_penalty``.

        Yields:
            The reply accumulated so far, as a string.
        """
        input_ids = shared.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            truncation=False,
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int32)
        # Apply truncation_length: keep only the trailing tokens that fit.
        input_ids = input_ids[-get_max_prompt_length(state):]
        batch_input_ids = [input_ids]

        if shared.args.cpp_runner:
            # Never request more than the output budget the engine was
            # loaded with.
            max_new_tokens = min(self._CPP_RUNNER_MAX_OUTPUT_LEN, state['max_new_tokens'])
        elif state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        # -1 is passed as the end id when EOS is banned.
        end_id = -1 if state['ban_eos_token'] else shared.tokenizer.eos_token_id

        with torch.no_grad():
            generator = self.model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                max_attention_window_size=None,
                sink_token_length=None,
                end_id=end_id,
                pad_id=self._resolve_pad_id(shared.tokenizer),
                temperature=state['temperature'],
                top_k=state['top_k'],
                top_p=state['top_p'],
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=state['repetition_penalty'],
                presence_penalty=state['presence_penalty'],
                frequency_penalty=state['frequency_penalty'],
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=not shared.args.cpp_runner,
                output_sequence_lengths=True,
                return_dict=True,
                medusa_choices=None
            )

        torch.cuda.synchronize()

        cumulative_reply = ''
        # Decoding starts right after the prompt tokens.
        starting_from = batch_input_ids[0].shape[-1]

        if shared.args.cpp_runner:
            # Non-streaming: ``generator`` is a single result dict.
            reply, _ = self._decode_increment(generator, state, starting_from)
            cumulative_reply += reply
            yield cumulative_reply
        else:
            for curr_outputs in generator:
                if shared.stop_everything:
                    break

                # Each step decodes from where the previous one ended.
                reply, starting_from = self._decode_increment(curr_outputs, state, starting_from)
                cumulative_reply += reply
                yield cumulative_reply

    def generate(self, prompt, state):
        """Run generation to completion and return the final reply.

        Returns an empty string if :meth:`generate_with_streaming` yields
        nothing.
        """
        output = ''
        # Drain the stream; only the last cumulative value is needed.
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
|