mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
Lint
This commit is contained in:
parent
193548edce
commit
069ed7c6ef
@ -36,7 +36,7 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||||||
progress_bar = range(0, len(tokens), self.n_batch)
|
progress_bar = range(0, len(tokens), self.n_batch)
|
||||||
|
|
||||||
for i in progress_bar:
|
for i in progress_bar:
|
||||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
batch = tokens[i: min(len(tokens), i + self.n_batch)]
|
||||||
n_past = self.n_tokens
|
n_past = self.n_tokens
|
||||||
n_tokens = len(batch)
|
n_tokens = len(batch)
|
||||||
self._batch.set_batch(
|
self._batch.set_batch(
|
||||||
@ -44,16 +44,16 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||||||
)
|
)
|
||||||
self._ctx.decode(self._batch)
|
self._ctx.decode(self._batch)
|
||||||
# Save tokens
|
# Save tokens
|
||||||
self.input_ids[n_past : n_past + n_tokens] = batch
|
self.input_ids[n_past: n_past + n_tokens] = batch
|
||||||
# Save logits
|
# Save logits
|
||||||
rows = n_tokens
|
rows = n_tokens
|
||||||
cols = self._n_vocab
|
cols = self._n_vocab
|
||||||
offset = (
|
offset = (
|
||||||
0 if self.context_params.logits_all else n_tokens - 1
|
0 if self.context_params.logits_all else n_tokens - 1
|
||||||
) # NOTE: Only save the last token logits if logits_all is False
|
) # NOTE: Only save the last token logits if logits_all is False
|
||||||
self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
|
self.scores[n_past + offset: n_past + n_tokens, :].reshape(-1)[
|
||||||
:
|
:
|
||||||
] = self._ctx.get_logits()[offset * cols : rows * cols]
|
] = self._ctx.get_logits()[offset * cols: rows * cols]
|
||||||
# Update n_tokens
|
# Update n_tokens
|
||||||
self.n_tokens += n_tokens
|
self.n_tokens += n_tokens
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ def random_preset(state):
|
|||||||
for cat in params_and_values:
|
for cat in params_and_values:
|
||||||
choices = list(params_and_values[cat].keys())
|
choices = list(params_and_values[cat].keys())
|
||||||
if shared.args.loader is not None:
|
if shared.args.loader is not None:
|
||||||
choices = [x for x in choices if loader_contains(sampler)]
|
choices = [x for x in choices if loader_contains(x)]
|
||||||
|
|
||||||
if len(choices) > 0:
|
if len(choices) > 0:
|
||||||
choice = random.choice(choices)
|
choice = random.choice(choices)
|
||||||
|
Loading…
Reference in New Issue
Block a user