mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
5ddf7ea1fb
Small, non-functional changes were made to non-compliant files. These include breaking up long lines, whitespace sanitization, and unused-import removal. The maximum line length in Python files was set to a generous 125 characters, to minimize the number of changes needed in scripts and general annoyance. The "txt" prompts directory is excluded from the checks because it may contain oddly formatted files and strings for good reason. Signed-off-by: Jiri Podivin <jpodivin@gmail.com>
58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
import matplotlib.pyplot as plt
|
|
import os
|
|
import csv
|
|
|
|
# Chart data shared between calculatecorrect() and the __main__ section.
labels, numbers = [], []  # bar labels and the matching correct-answer counts
numEntries = 1  # bar count; starts at 1, leaving room for the "Human" bar appended in __main__
rows = []  # every row of the question sheet CSV, as produced by csv.reader
|
|
|
|
|
|
def bar_chart(numbers, labels, pos):
    """Display a blue bar chart of per-model Jeopardy scores.

    Args:
        numbers: bar heights.
        labels: tick label for each bar.
        pos: x positions of the bars, reused as the tick positions.
    """
    plt.bar(pos, numbers, color='blue')
    plt.xticks(ticks=pos, labels=labels)
    # Title and axis captions, applied in order.
    captions = (
        (plt.title, "Jeopardy Results by Model"),
        (plt.xlabel, "Model"),
        (plt.ylabel, "Questions Correct"),
    )
    for set_caption, text in captions:
        set_caption(text)
    plt.show()
|
|
|
|
|
|
def calculatecorrect(results_dir="./examples/jeopardy/results/",
                     qasheet_path="./examples/jeopardy/qasheet.csv",
                     ask=input):
    """Interactively grade every model's Jeopardy answers.

    Reads the question sheet CSV into the module-level ``rows`` list, then walks
    the ``.txt`` files in ``results_dir`` in sorted filename order. Each file's
    lines are echoed. A line that is exactly ``------`` (ignoring surrounding
    whitespace) closes one answer. At that point the expected answer (column 2
    of the next sheet row, starting from row 1) is printed, and the user is
    asked whether the model got it right.

    Module-level state updated:
        rows: extended with every row of the question sheet.
        labels: one entry per results file (the filename without ``.txt``).
        numbers: count of answers marked correct, parallel to ``labels``.
        numEntries: incremented once per results file.

    Args:
        results_dir: directory holding the per-model ``.txt`` result files.
        qasheet_path: path to the question sheet CSV.
        ask: zero-argument callable returning the user's reply; defaults to
            ``input``. A reply of ``y`` counts as correct. The check is
            case-insensitive and ignores surrounding whitespace.

    Raises:
        IndexError: if a results file has more ``------`` separators than
            ``rows`` can satisfy, or a sheet row has fewer than three columns.
    """
    global rows, labels, numbers, numEntries

    # The csv module expects files opened with newline="".
    with open(qasheet_path, "rt", encoding="utf-8", newline="") as sheet:
        rows.extend(csv.reader(sheet, delimiter=','))

    # os.listdir returns entries in arbitrary order; sort so the chart order is reproducible.
    for filename in sorted(os.listdir(results_dir)):
        if not filename.endswith(".txt"):
            continue
        labels.append(filename[:-4])
        numEntries += 1
        i = 1  # rows[0] is skipped (presumably the CSV header row)
        totalcorrect = 0
        with open(os.path.join(results_dir, filename), "rt", encoding="utf-8") as results:
            for line in results:
                if line.strip() != "------":
                    print(line)
                    continue
                print("Correct answer: " + rows[i][2] + "\n")
                i += 1
                print("Did the AI get the question right? (y/n)")
                if ask().strip().lower() == "y":
                    totalcorrect += 1
        numbers.append(totalcorrect)
|
|
|
|
|
|
if __name__ == '__main__':
    # Grade each model's results interactively; this fills labels, numbers and numEntries.
    calculatecorrect()
    # numEntries started at 1, so this range already includes a slot for the extra bar added below.
    pos = [*range(numEntries)]
    # Append a fixed score of 48.11 as a "Human" comparison bar.
    labels += ["Human"]
    numbers += [48.11]
    bar_chart(numbers, labels, pos)
    for summary in (labels, numbers):
        print(summary)
|