Make the gallery extension work on colab

This commit is contained in:
oobabooga 2023-02-26 12:37:26 -03:00
parent 756cba2edc
commit 3333f94c30
2 changed files with 28 additions and 12 deletions

View File

@ -4,6 +4,8 @@ from pathlib import Path
import gradio as gr import gradio as gr
from modules.html_generator import image_to_base64
def generate_html(): def generate_html():
css = """ css = """
@ -29,10 +31,14 @@ def generate_html():
background-color: gray; background-color: gray;
} }
.character-gallery td { .character-gallery .image-td {
text-align: center; width: 150px;
vertical-align: middle;
} }
.character-gallery .character-td {
text-align: center !important;
}
""" """
table_html = f'<style>{css}</style><div class="character-gallery"><table>' table_html = f'<style>{css}</style><div class="character-gallery"><table>'
@ -41,15 +47,25 @@ def generate_html():
for file in Path("characters").glob("*"): for file in Path("characters").glob("*"):
if file.name.endswith(".json"): if file.name.endswith(".json"):
json_name = file.name json_name = file.name
image_name = file.name.replace(".json", "") character = file.name.replace(".json", "")
table_html += "<tr>" table_html += "<tr>"
if Path(f"characters/{image_name}.png").exists(): image_html = "<div class='placeholder'></div>"
image_html = f'<img src="file/characters/{image_name}.png">'
elif Path(f"characters/{image_name}.jpg").exists(): for i in [
image_html = f'<img src="file/characters/{image_name}.jpg">' f"characters/{character}.png",
else: f"characters/{character}.jpg",
image_html = "<div class='placeholder'></div>" f"characters/{character}.jpeg",
table_html += f"<td>{image_html}</td><td>{image_name}</td>" ]:
path = Path(i)
if path.exists():
try:
image_html = f'<img src="data:image/png;base64,{image_to_base64(path)}">'
break
except:
continue
table_html += f'<td class="image-td"=>{image_html}</td><td class="character-td">{character}</td>'
table_html += "</tr>" table_html += "</tr>"
table_html += "</table></div>" table_html += "</table></div>"

View File

@ -200,7 +200,7 @@ def image_to_base64(path):
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache): if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
img = Image.open(path) img = Image.open(path)
img.thumbnail((100, 100)) img.thumbnail((200, 200))
img_buffer = BytesIO() img_buffer = BytesIO()
img.convert('RGB').save(img_buffer, format='PNG') img.convert('RGB').save(img_buffer, format='PNG')
image_cache[path] = [mtime, base64.b64encode(img_buffer.getvalue()).decode("utf-8")] image_cache[path] = [mtime, base64.b64encode(img_buffer.getvalue()).decode("utf-8")]