Use gen_begin_reuse in exllama

This commit is contained in:
oobabooga 2023-06-17 18:00:10 -03:00
parent 239b11c94b
commit 766c760cd7

View File

@ -41,45 +41,49 @@ class ExllamaModel:
model = ExLlama(config) model = ExLlama(config)
tokenizer = ExLlamaTokenizer(str(tokenizer_model_path)) tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
cache = ExLlamaCache(model) cache = ExLlamaCache(model)
generator = ExLlamaGenerator(model, tokenizer, cache)
result = self() result = self()
result.config = config result.config = config
result.model = model result.model = model
result.cache = cache result.cache = cache
result.tokenizer = tokenizer result.tokenizer = tokenizer
self.generator = generator
return result, result return result, result
def generate(self, prompt, state, callback=None): def generate(self, prompt, state, callback=None):
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache) self.generator.settings.temperature = state['temperature']
generator.settings.temperature = state['temperature'] self.generator.settings.top_p = state['top_p']
generator.settings.top_p = state['top_p'] self.generator.settings.top_k = state['top_k']
generator.settings.top_k = state['top_k'] self.generator.settings.typical = state['typical_p']
generator.settings.typical = state['typical_p'] self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
if state['ban_eos_token']: if state['ban_eos_token']:
generator.disallow_tokens([self.tokenizer.eos_token_id]) self.generator.disallow_tokens([self.tokenizer.eos_token_id])
else:
self.generator.disallow_tokens(None)
text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens']) text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
return text return text
def generate_with_streaming(self, prompt, state, callback=None):
    """Yield the decoded completion for ``prompt`` as it grows token by token.

    Post-commit form of the method: it drives the already-constructed
    ``self.generator`` and starts generation with ``gen_begin_reuse``
    instead of ``gen_begin``.

    Args:
        prompt: Text encoded with the generator's tokenizer.
        state: Mapping read for the keys ``temperature``, ``top_p``,
            ``top_k``, ``typical_p``, ``repetition_penalty``,
            ``ban_eos_token`` and ``max_new_tokens``.
        callback: Accepted for interface compatibility; not used here.

    Yields:
        After each generated token, the decoded text of every token
        produced so far (everything past the initial sequence length).
    """
    generator = self.generator

    settings = generator.settings
    settings.temperature = state['temperature']
    settings.top_p = state['top_p']
    settings.top_k = state['top_k']
    settings.typical = state['typical_p']
    settings.token_repetition_penalty_max = state['repetition_penalty']

    # The generator is reused across calls, so the disallowed-token list is
    # reset explicitly when the EOS ban is off.
    if state['ban_eos_token']:
        generator.disallow_tokens([self.tokenizer.eos_token_id])
    else:
        generator.disallow_tokens(None)

    generator.end_beam_search()

    ids = generator.tokenizer.encode(prompt)
    # NOTE(review): gen_begin_reuse presumably keeps the part of the cache
    # that matches the previous sequence -- confirm against exllama.
    generator.gen_begin_reuse(ids)
    # Length of the sequence before any new token is generated; used to
    # slice off everything that is not newly generated text.
    initial_len = generator.sequence[0].shape[0]

    for _ in range(state['max_new_tokens']):
        token = generator.gen_single_token()
        yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
        # Stop after yielding when EOS was produced or a global stop was requested.
        if token.item() == generator.tokenizer.eos_token_id or shared.stop_everything:
            break
def encode(self, string, **kwargs): def encode(self, string, **kwargs):