diff --git a/modules/logits.py b/modules/logits.py index d95dd8c0..6fc5bf60 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -1,9 +1,6 @@ import torch from modules import sampler_hijack, shared -from modules.exllama import ExllamaModel -from modules.exllamav2 import Exllamav2Model -from modules.llamacpp_model import LlamaCppModel from modules.logging_colors import logger from modules.text_generation import generate_reply @@ -15,9 +12,9 @@ def get_next_logits(prompt, state, use_samplers, previous): logger.error("No model is loaded! Select one in the Model tab.") return 'Error: No model is loaded1 Select one in the Model tab.', previous - is_non_hf_exllamav2 = isinstance(shared.model, Exllamav2Model) - is_non_hf_exllamav1 = isinstance(shared.model, ExllamaModel) - is_non_hf_llamacpp = isinstance(shared.model, LlamaCppModel) + is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model' + is_non_hf_exllamav1 = shared.model.__class__.__name__ == 'ExllamaModel' + is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel' if use_samplers: if any([is_non_hf_exllamav2, is_non_hf_exllamav1, is_non_hf_llamacpp]):