mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-26 12:22:08 +01:00
Use gen_begin_reuse in exllama
This commit is contained in:
parent
239b11c94b
commit
766c760cd7
@ -41,45 +41,49 @@ class ExllamaModel:
|
|||||||
model = ExLlama(config)
|
model = ExLlama(config)
|
||||||
tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
|
tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
|
||||||
cache = ExLlamaCache(model)
|
cache = ExLlamaCache(model)
|
||||||
|
generator = ExLlamaGenerator(model, tokenizer, cache)
|
||||||
|
|
||||||
result = self()
|
result = self()
|
||||||
result.config = config
|
result.config = config
|
||||||
result.model = model
|
result.model = model
|
||||||
result.cache = cache
|
result.cache = cache
|
||||||
result.tokenizer = tokenizer
|
result.tokenizer = tokenizer
|
||||||
|
self.generator = generator
|
||||||
return result, result
|
return result, result
|
||||||
|
|
||||||
def generate(self, prompt, state, callback=None):
|
def generate(self, prompt, state, callback=None):
|
||||||
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
self.generator.settings.temperature = state['temperature']
|
||||||
generator.settings.temperature = state['temperature']
|
self.generator.settings.top_p = state['top_p']
|
||||||
generator.settings.top_p = state['top_p']
|
self.generator.settings.top_k = state['top_k']
|
||||||
generator.settings.top_k = state['top_k']
|
self.generator.settings.typical = state['typical_p']
|
||||||
generator.settings.typical = state['typical_p']
|
self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||||
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
|
||||||
if state['ban_eos_token']:
|
if state['ban_eos_token']:
|
||||||
generator.disallow_tokens([self.tokenizer.eos_token_id])
|
self.generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||||
|
else:
|
||||||
|
self.generator.disallow_tokens(None)
|
||||||
|
|
||||||
text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
|
text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state, callback=None):
|
def generate_with_streaming(self, prompt, state, callback=None):
|
||||||
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
self.generator.settings.temperature = state['temperature']
|
||||||
generator.settings.temperature = state['temperature']
|
self.generator.settings.top_p = state['top_p']
|
||||||
generator.settings.top_p = state['top_p']
|
self.generator.settings.top_k = state['top_k']
|
||||||
generator.settings.top_k = state['top_k']
|
self.generator.settings.typical = state['typical_p']
|
||||||
generator.settings.typical = state['typical_p']
|
self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||||
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
|
||||||
if state['ban_eos_token']:
|
if state['ban_eos_token']:
|
||||||
generator.disallow_tokens([self.tokenizer.eos_token_id])
|
self.generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||||
|
else:
|
||||||
|
self.generator.disallow_tokens(None)
|
||||||
|
|
||||||
generator.end_beam_search()
|
self.generator.end_beam_search()
|
||||||
ids = generator.tokenizer.encode(prompt)
|
ids = self.generator.tokenizer.encode(prompt)
|
||||||
generator.gen_begin(ids)
|
self.generator.gen_begin_reuse(ids)
|
||||||
initial_len = generator.sequence[0].shape[0]
|
initial_len = self.generator.sequence[0].shape[0]
|
||||||
for i in range(state['max_new_tokens']):
|
for _ in range(state['max_new_tokens']):
|
||||||
token = generator.gen_single_token()
|
token = self.generator.gen_single_token()
|
||||||
yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
|
yield (self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]))
|
||||||
if token.item() == generator.tokenizer.eos_token_id or shared.stop_everything:
|
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
|
||||||
break
|
break
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
def encode(self, string, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user