diff --git a/modules/evaluate.py b/modules/evaluate.py
index 629d3e94..3e555a3e 100644
--- a/modules/evaluate.py
+++ b/modules/evaluate.py
@@ -100,7 +100,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
             target_ids[:, :-trg_len] = -100
 
             with torch.no_grad():
-                outputs = shared.model(input_ids, labels=target_ids)
+                outputs = shared.model(input_ids=input_ids, labels=target_ids)
 
                 # loss is calculated using CrossEntropyLoss which averages over valid labels
                 # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py
index 64de7a5f..a1b73bed 100644
--- a/modules/exllama_hf.py
+++ b/modules/exllama_hf.py
@@ -1,15 +1,10 @@
 import os
-import sys
 from pathlib import Path
-from typing import *
+from typing import Any, Dict, Optional, Union
 
 import torch
-from transformers import (
-    GenerationConfig,
-    LlamaTokenizer,
-    PretrainedConfig,
-    PreTrainedModel
-)
+from torch.nn import CrossEntropyLoss
+from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from modules import shared
@@ -43,13 +38,29 @@ class ExllamaHF(PreTrainedModel):
     def __call__(self, *args, **kwargs):
         # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
         assert len(args) == 0, 'no *args should be passed to forward'
-        use_cache = kwargs['use_cache']
+        use_cache = kwargs.get('use_cache', True)
+        labels = kwargs.get('labels', None)
         seq = kwargs['input_ids'][0].tolist()
         cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
         if cache is None:
             cache = ExLlamaCache(self.ex_model)
             self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
+
         logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, logits.shape[-1])
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
-        return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=cache if use_cache else None)
 
     @classmethod
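
A note on the two hunks for reviewers: the `evaluate.py` change switches to a keyword argument because `ExllamaHF.__call__` asserts that no positional `*args` reach `forward`, so `shared.model(input_ids, ...)` would trip that assertion for ExLlama-HF models. The new loss block in `exllama_hf.py` mirrors the shifted cross-entropy that `transformers` computes inside `LlamaForCausalLM.forward`: logits at position `i` are scored against the token at position `i + 1`, and `CrossEntropyLoss`'s default `ignore_index` of `-100` skips the context tokens that `calculate_perplexity` masks with `target_ids[:, :-trg_len] = -100`. A minimal sketch of that pattern in isolation (toy shapes, not code from this patch):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy batch: logits over a 5-token vocabulary for a 4-token sequence
logits = torch.randn(1, 4, 5)                 # [batch, seq_len, vocab]
labels = torch.tensor([[-100, -100, 3, 1]])   # -100 marks masked (context) positions

# Shift so that tokens < n predict n: logits at position i score the token at i + 1
shift_logits = logits[..., :-1, :].contiguous()   # predictions for positions 1..3
shift_labels = labels[..., 1:].contiguous()       # targets 1..3; the remaining -100 is ignored

# Flatten and average; ignore_index defaults to -100, matching the masking in evaluate.py
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, logits.shape[-1]), shift_labels.view(-1))
perplexity = torch.exp(loss)                      # perplexity = exp(mean NLL over valid labels)
```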
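
One behavior worth flagging: with `labels` set, this path still only requests logits for the final token (`seq[-1:]` after the `preprocess_only` pass), so `shift_logits` ends up covering zero positions while `shift_labels` spans the whole window, and the flattened shapes no longer line up. A possible labels-aware split, assuming exllama's `ExLlama.forward` accepts `last_id_only=False` to return logits for every input position (that flag is an assumption about the exllama API, not something this diff establishes):

```python
# Sketch of a labels-aware forward for ExllamaHF.__call__, not code from this patch.
if labels is None:
    # Generation path: warm the cache on seq[:-1], then fetch last-token logits only
    if cache is None:
        cache = ExLlamaCache(self.ex_model)
        self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)

    logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
else:
    # Evaluation path: logits for every position so the shift against labels lines up
    if cache is None:
        cache = ExLlamaCache(self.ex_model)

    logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), cache, last_id_only=False).to(kwargs['input_ids'].device)
```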