mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-29 21:50:16 +01:00
62 lines
2.3 KiB
Python
62 lines
2.3 KiB
Python
import os
|
|
import json
|
|
|
|
def create_graph(lora_path, lora_name):
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.ticker import ScalarFormatter
|
|
|
|
peft_model_path = f'{lora_path}/training_graph.json'
|
|
image_model_path = f'{lora_path}/training_graph.png'
|
|
# Check if the JSON file exists
|
|
if os.path.exists(peft_model_path):
|
|
# Load data from JSON file
|
|
with open(peft_model_path, 'r') as file:
|
|
data = json.load(file)
|
|
# Extract x, y1, and y2 values
|
|
x = [item['epoch'] for item in data]
|
|
y1 = [item['learning_rate'] for item in data]
|
|
y2 = [item['loss'] for item in data]
|
|
|
|
# Create the line chart
|
|
fig, ax1 = plt.subplots(figsize=(10, 6))
|
|
|
|
|
|
# Plot y1 (learning rate) on the first y-axis
|
|
ax1.plot(x, y1, 'b-', label='Learning Rate')
|
|
ax1.set_xlabel('Epoch')
|
|
ax1.set_ylabel('Learning Rate', color='b')
|
|
ax1.tick_params('y', colors='b')
|
|
|
|
# Create a second y-axis
|
|
ax2 = ax1.twinx()
|
|
|
|
# Plot y2 (loss) on the second y-axis
|
|
ax2.plot(x, y2, 'r-', label='Loss')
|
|
ax2.set_ylabel('Loss', color='r')
|
|
ax2.tick_params('y', colors='r')
|
|
|
|
# Set the y-axis formatter to display numbers in scientific notation
|
|
ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
|
|
ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
|
|
|
|
# Add grid
|
|
ax1.grid(True)
|
|
|
|
# Combine the legends for both plots
|
|
lines, labels = ax1.get_legend_handles_labels()
|
|
lines2, labels2 = ax2.get_legend_handles_labels()
|
|
ax2.legend(lines + lines2, labels + labels2, loc='best')
|
|
|
|
# Set the title
|
|
plt.title(f'{lora_name} LR and Loss vs Epoch')
|
|
|
|
# Save the chart as an image
|
|
plt.savefig(image_model_path)
|
|
|
|
print(f"Graph saved in {image_model_path}")
|
|
else:
|
|
print(f"File 'training_graph.json' does not exist in the {lora_path}")
|
|
|
|
except ImportError:
|
|
print("matplotlib is not installed. Please install matplotlib to create PNG graphs") |