mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Failed attempt at evaluating exllama_hf perplexity
This commit is contained in:
parent
e356f69b36
commit
cec5fb0ef6
@ -100,7 +100,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||
target_ids[:, :-trg_len] = -100
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = shared.model(input_ids, labels=target_ids)
|
||||
outputs = shared.model(input_ids=input_ids, labels=target_ids)
|
||||
|
||||
# loss is calculated using CrossEntropyLoss which averages over valid labels
|
||||
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
|
||||
|
@ -1,15 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import *
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
LlamaTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
@ -43,13 +38,29 @@ class ExllamaHF(PreTrainedModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
|
||||
assert len(args) == 0, 'no *args should be passed to forward'
|
||||
use_cache = kwargs['use_cache']
|
||||
use_cache = kwargs.get('use_cache', True)
|
||||
labels = kwargs.get('labels', None)
|
||||
seq = kwargs['input_ids'][0].tolist()
|
||||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||
if cache is None:
|
||||
cache = ExLlamaCache(self.ex_model)
|
||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
|
||||
|
||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, logits.shape[-1])
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
|
||||
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user