Improved multimodal error message

oobabooga 2023-09-17 09:22:16 -07:00
parent 37e2980e05
commit 763ea3bcb2
2 changed files with 10 additions and 7 deletions

extensions/multimodal/README.md

@@ -11,10 +11,10 @@ https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b
 To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
 ```
-python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat
-python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat
-python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat
-python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat
+python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b
+python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b
+python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b
+python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b
 ```
 There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`:

extensions/multimodal/pipelines/llava/llava.py

@@ -56,10 +56,13 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     @staticmethod
     def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
-        if hasattr(shared.model.model, 'embed_tokens'):
-            func = shared.model.model.embed_tokens
-        else:
-            func = shared.model.model.model.embed_tokens  # AutoGPTQ case
+        for attr in ['', 'model', 'model.model', 'model.model.model']:
+            tmp = getattr(shared.model, attr, None) if attr != '' else shared.model
+            if tmp is not None and hasattr(tmp, 'embed_tokens'):
+                func = tmp.embed_tokens
+                break
+        else:
+            raise ValueError('The embed_tokens method has not been found for this loader.')
         return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
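
One subtlety in the new lookup: `getattr` treats its name argument as a literal attribute name, so `getattr(shared.model, 'model.model', None)` always returns the default rather than descending two levels, meaning only the `''` and `'model'` entries can match as written. Below is a minimal sketch of a lookup that resolves the dotted paths segment by segment, assuming the same wrapper layout; `find_embed_tokens` is a hypothetical helper, not part of this commit:

```python
import functools
from typing import Any, Callable, Optional


def find_embed_tokens(model: Any) -> Optional[Callable]:
    """Return the first embed_tokens found on `model` or a nested wrapper."""
    # Loaders wrap the underlying transformers model differently (e.g.
    # AutoGPTQ adds an extra .model level), so probe progressively deeper
    # attribute paths.
    for path in ('', 'model', 'model.model', 'model.model.model'):
        try:
            # Walk the dotted path one segment at a time; a single getattr
            # call cannot resolve a name like 'model.model'.
            tmp = functools.reduce(getattr, path.split('.'), model) if path else model
        except AttributeError:
            continue  # this wrapper level does not exist for the current loader
        if hasattr(tmp, 'embed_tokens'):
            return tmp.embed_tokens
    return None
```

With a helper like this, the pipeline body would reduce to `func = find_embed_tokens(shared.model)`, raising the same `ValueError` when it returns `None`.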