mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Failed attempt at evaluating exllama_hf perplexity
This commit is contained in:
parent
e356f69b36
commit
cec5fb0ef6
@ -100,7 +100,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
|||||||
target_ids[:, :-trg_len] = -100
|
target_ids[:, :-trg_len] = -100
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = shared.model(input_ids, labels=target_ids)
|
outputs = shared.model(input_ids=input_ids, labels=target_ids)
|
||||||
|
|
||||||
# loss is calculated using CrossEntropyLoss which averages over valid labels
|
# loss is calculated using CrossEntropyLoss which averages over valid labels
|
||||||
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
|
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
|
||||||
|
@ -1,15 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import *
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from torch.nn import CrossEntropyLoss
|
||||||
GenerationConfig,
|
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||||
LlamaTokenizer,
|
|
||||||
PretrainedConfig,
|
|
||||||
PreTrainedModel
|
|
||||||
)
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
@ -43,13 +38,29 @@ class ExllamaHF(PreTrainedModel):
|
|||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
|
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
|
||||||
assert len(args) == 0, 'no *args should be passed to forward'
|
assert len(args) == 0, 'no *args should be passed to forward'
|
||||||
use_cache = kwargs['use_cache']
|
use_cache = kwargs.get('use_cache', True)
|
||||||
|
labels = kwargs.get('labels', None)
|
||||||
seq = kwargs['input_ids'][0].tolist()
|
seq = kwargs['input_ids'][0].tolist()
|
||||||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = ExLlamaCache(self.ex_model)
|
cache = ExLlamaCache(self.ex_model)
|
||||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
|
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
|
||||||
|
|
||||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
|
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, logits.shape[-1])
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
|
return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user