mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
Fix repeated tokens with exllama
This commit is contained in:
parent
766c760cd7
commit
cbd63eeeff
@@ -51,21 +51,7 @@ class ExllamaModel:
|
|||||||
self.generator = generator
|
self.generator = generator
|
||||||
return result, result
|
return result, result
|
||||||
|
|
||||||
def generate(self, prompt, state, callback=None):
|
def generate_with_streaming(self, prompt, state):
|
||||||
self.generator.settings.temperature = state['temperature']
|
|
||||||
self.generator.settings.top_p = state['top_p']
|
|
||||||
self.generator.settings.top_k = state['top_k']
|
|
||||||
self.generator.settings.typical = state['typical_p']
|
|
||||||
self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
|
||||||
if state['ban_eos_token']:
|
|
||||||
self.generator.disallow_tokens([self.tokenizer.eos_token_id])
|
|
||||||
else:
|
|
||||||
self.generator.disallow_tokens(None)
|
|
||||||
|
|
||||||
text = self.generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
|
|
||||||
return text
|
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state, callback=None):
|
|
||||||
self.generator.settings.temperature = state['temperature']
|
self.generator.settings.temperature = state['temperature']
|
||||||
self.generator.settings.top_p = state['top_p']
|
self.generator.settings.top_p = state['top_p']
|
||||||
self.generator.settings.top_k = state['top_k']
|
self.generator.settings.top_k = state['top_k']
|
||||||
@@ -86,5 +72,12 @@ class ExllamaModel:
|
|||||||
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
|
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def generate(self, prompt, state):
|
||||||
|
output = ''
|
||||||
|
for output in self.generate_with_streaming(prompt, state):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
def encode(self, string, **kwargs):
|
||||||
return self.tokenizer.encode(string)
|
return self.tokenizer.encode(string)
|
||||||
|
Loading…
Reference in New Issue
Block a user