diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 0a86b4fd..d5ebbb76 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -104,7 +104,7 @@ class MirostatLogitsWarper(LogitsWarper): break # Normalize the probabilities of the remaining words - prob_topk = torch.softmax(sorted_logits, dim=0) + prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda') prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')