Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emmanuel confusion matrix #833

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions discojs/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ export class Validator<D extends DataType> {
/** infer every line of the dataset and check that it is as labelled */
async *test(
dataset: Dataset<DataFormat.Raw[D]>,
): AsyncGenerator<boolean, void> {
): AsyncGenerator<{ result: boolean; predicted: DataFormat.ModelEncoded[D][1]; truth : DataFormat.ModelEncoded[D][1] }, void> {
const results = (await processing.preprocess(this.task, dataset))
.batch(this.task.trainingInformation.batchSize)
.map(async (batch) =>
(await this.#model.predict(batch.map(([inputs, _]) => inputs)))
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => inferred === truth),
)
.map(([inferred, truth]) => ({ result: inferred === truth, predicted: inferred, truth : truth })),
)
.unbatch();

for await (const e of results) yield e;
Expand Down
111 changes: 97 additions & 14 deletions webapp/src/components/testing/TestSteps.vue
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@
</div>
</div>

<div v-if="confusionMatrix && confusionMatrix.matrix.length > 0" class="p-4 mx-auto lg:w-1/2 h-full bg-white dark:bg-slate-950 rounded-md">
<h4 class="p-4 text-lg font-semibold text-slate-500 dark:text-slate-300">
Confusion Matrix
</h4>
<table class="min-w-full divide-y divide-slate-600 dark:divide-slate-400 text-center">
<thead>
<tr>
<th class="pl-6 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider text-center border-r-gray-600 dark:border-r-gray-400 border-r-2 diagonal-header">
<span class="">Label \ Prediction</span>
</th>
<th v-for="(label, index) in confusionMatrix.matrix[0]" :key="'header-' + index" class="px-6 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider">
{{ confusionMatrix.labels.get(index) }}
</th>
</tr>
</thead>
<tbody>
<tr v-for="(row, rowIndex) in confusionMatrix.matrix" :key="'row-' + rowIndex">
<td class="px-6 py-4 whitespace-nowrap text-sm font-medium text-gray-800 dark:text-gray-200 border-r-gray-600 dark:border-r-gray-400 border-r-2">
{{ confusionMatrix.labels.get(rowIndex) }}
</td>
<td v-for="(value, colIndex) in row" :key="'col-' + colIndex" class="px-6 py-4 whitespace-nowrap text-sm dark:text-gray-300 text-gray-700">
{{ value }}
</td>
</tr>
</tbody>
</table>
</div>

<div v-if="tested !== undefined">
<div class="mx-auto lg:w-1/2 text-center pb-8">
<CustomButton @click="saveCsv()"> download as csv </CustomButton>
Expand Down Expand Up @@ -142,24 +170,24 @@ const props = defineProps<{
interface Tested {
image: List<{
input: { filename: string; image: ImageData };
output: { truth: string; correct: boolean };
output: { truth: string; correct: boolean; predicted : number, label : number };
}>;
tabular: {
labels: {
input: List<string>;
output: { truth: string; correct: string };
output: { truth: string; correct: string, label : string };
};
results: List<{
input: List<string>;
output: { truth: string; correct: boolean };
output: { truth: string; correct: boolean; predicted : number, label : number };
}>;
};
// TODO what to show?
text: List<{ output: { correct: boolean } }>;
}

const dataset = ref<LabeledDataset[D]>();
const generator = ref<AsyncGenerator<boolean, void>>();
const generator = ref<AsyncGenerator<{result : boolean, predicted : number; truth : number}, void>>();
const tested = ref<Tested[D]>();

const visitedSamples = computed<number>(() => {
Expand All @@ -177,9 +205,60 @@ const visitedSamples = computed<number>(() => {
}
}
});

const confusionMatrix = computed<{labels : Map<number, string>, matrix : number[][]} | undefined>(() => {
if (tested.value === undefined) return undefined;
const labels = new Set<number>();
const mapLabels = new Map<number, string>();

// get all the labels
switch (props.task.trainingInformation.dataType) {
case "image":
(tested.value as Tested["image"]).forEach(({ output }) => {
labels.add(output.label);
labels.add(output.predicted);
mapLabels.set(output.label, output.truth);
});
break;
case "text":
return undefined;
case "tabular":
(tested.value as Tested["tabular"]).results.forEach(({ output }) => {
labels.add(output.label);
labels.add(output.predicted);
mapLabels.set(output.label, output.truth);
});
break;
default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}
const size = Math.max(labels.size, Math.max(...Array.from(labels)));
// Initialize the confusion matrix
const matrix = Array.from({ length: size }, () => Array(size).fill(0));

switch (props.task.trainingInformation.dataType) {
case "image":
(tested.value as Tested["image"]).map(
( {output} ) => matrix[output.predicted][output.label] = matrix[output.predicted][output.label] + 1,
);
break;
//case "text":
// return undefined;
case "tabular":
return undefined;
default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}
return {labels : mapLabels, matrix : matrix};
})

const currentAccuracy = computed<string>(() => {
if (tested.value === undefined) return "0";

if (tested.value === undefined) return "0";
let hits: number | undefined;
switch (props.task.trainingInformation.dataType) {
case "image":
Expand Down Expand Up @@ -250,7 +329,7 @@ async function startImageTest(
generator.value = validator.test(
dataset.map(({ image, label }) => [image, label] as [Image, string]),
);
for await (const [{ filename, image, label }, correct] of dataset.zip(
for await (const [{ filename, image, label }, {result, predicted, truth}] of dataset.zip(
toRaw(generator.value),
)) {
results = results.push({
Expand All @@ -264,7 +343,9 @@ async function startImageTest(
},
output: {
truth: label,
correct,
correct: result,
predicted: predicted,
label : truth,
},
});

Expand Down Expand Up @@ -295,9 +376,9 @@ async function startTabularTest(
let results: Tested["tabular"]["results"] = List();
try {
generator.value = validator.test(dataset);
for await (const [row, correct] of dataset.zip(toRaw(generator.value))) {
const truth = row[outputColumn];
if (truth === undefined)
for await (const [row, {result, predicted, truth}] of dataset.zip(toRaw(generator.value))) {
const truth_label = row[outputColumn];
if (truth_label === undefined)
throw new Error("row doesn't have expected output column");

results = results.push({
Expand All @@ -308,8 +389,10 @@ async function startTabularTest(
return ret;
}),
output: {
truth,
correct,
truth: truth_label,
correct: result,
predicted : predicted,
label : truth,
},
});

Expand All @@ -330,8 +413,8 @@ async function startTextTest(

try {
generator.value = validator.test(dataset);
for await (const correct of toRaw(generator.value)) {
results = results.push({ output: { correct } });
for await (const {result} of toRaw(generator.value)) {
results = results.push({ output: {correct : result} });
tested.value = results as Tested[D];
}
} finally {
Expand Down