This commit is contained in:
oobabooga 2024-02-13 16:05:41 -08:00
parent 193548edce
commit 069ed7c6ef
2 changed files with 5 additions and 5 deletions

View File

@ -36,7 +36,7 @@ def eval_with_progress(self, tokens: Sequence[int]):
progress_bar = range(0, len(tokens), self.n_batch) progress_bar = range(0, len(tokens), self.n_batch)
for i in progress_bar: for i in progress_bar:
batch = tokens[i : min(len(tokens), i + self.n_batch)] batch = tokens[i: min(len(tokens), i + self.n_batch)]
n_past = self.n_tokens n_past = self.n_tokens
n_tokens = len(batch) n_tokens = len(batch)
self._batch.set_batch( self._batch.set_batch(
@ -44,16 +44,16 @@ def eval_with_progress(self, tokens: Sequence[int]):
) )
self._ctx.decode(self._batch) self._ctx.decode(self._batch)
# Save tokens # Save tokens
self.input_ids[n_past : n_past + n_tokens] = batch self.input_ids[n_past: n_past + n_tokens] = batch
# Save logits # Save logits
rows = n_tokens rows = n_tokens
cols = self._n_vocab cols = self._n_vocab
offset = ( offset = (
0 if self.context_params.logits_all else n_tokens - 1 0 if self.context_params.logits_all else n_tokens - 1
) # NOTE: Only save the last token logits if logits_all is False ) # NOTE: Only save the last token logits if logits_all is False
self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[ self.scores[n_past + offset: n_past + n_tokens, :].reshape(-1)[
: :
] = self._ctx.get_logits()[offset * cols : rows * cols] ] = self._ctx.get_logits()[offset * cols: rows * cols]
# Update n_tokens # Update n_tokens
self.n_tokens += n_tokens self.n_tokens += n_tokens

View File

@ -125,7 +125,7 @@ def random_preset(state):
for cat in params_and_values: for cat in params_and_values:
choices = list(params_and_values[cat].keys()) choices = list(params_and_values[cat].keys())
if shared.args.loader is not None: if shared.args.loader is not None:
choices = [x for x in choices if loader_contains(sampler)] choices = [x for x in choices if loader_contains(x)]
if len(choices) > 0: if len(choices) > 0:
choice = random.choice(choices) choice = random.choice(choices)