Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding custom Datapoints to canvas. #125

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion index.html
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,21 @@ <h4>Output</h4>
<div id="linechart"></div>
</div>
<div id="heatmap"></div>
<div id="select-platform" class="ui-paintPlatform">
<div class="label">
Select color that you would like to paint with.
</div>
<label class="mdl-checkbox mdl-js-checkbox mdl-js-ripple-effect">
<input type="radio" name="color-choose" id="select-orange" value="1" checked>
<span class="mdl-checkbox__label label">Orange</span>
</label>
<label class="mdl-checkbox mdl-js-checkbox mdl-js-ripple-effect">
<input type="radio" name="color-choose" id="select-blue" value="-1">
<span class="mdl-checkbox__label label">Blue</span>
</label>
</div>
<div style="float:left;margin-top:20px">
<div style="display:flex; align-items:center;">

<!-- Gradient color scale -->
<div class="label" style="width:105px; margin-right: 10px">
Colors shows data, neuron and weight values.
Expand Down
4 changes: 2 additions & 2 deletions src/heatmap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export class HeatMap {
position: "relative",
top: `-${padding}px`,
left: `-${padding}px`
});
})
this.canvas = container.append("canvas")
.attr("width", numSamples)
.attr("height", numSamples)
Expand All @@ -97,7 +97,7 @@ export class HeatMap {
.style("position", "absolute")
.style("top", `${padding}px`)
.style("left", `${padding}px`);

if (!this.settings.noSvg) {
this.svg = container.append("svg").attr({
"width": width,
Expand Down
67 changes: 55 additions & 12 deletions src/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ let HIDABLE_CONTROLS = [
["Noise level", "noise"],
["Batch size", "batchSize"],
["# of hidden layers", "numHiddenLayers"],
["Paint Platform", "paintPlatform"],
];

class Player {
Expand Down Expand Up @@ -166,8 +167,6 @@ let colorScale = d3.scale.linear<string, number>()
.range(["#f59322", "#e8eaeb", "#0877bd"])
.clamp(true);
let iter = 0;
let trainData: Example2D[] = [];
let testData: Example2D[] = [];
let network: nn.Node[][] = null;
let lossTrain = 0;
let lossTest = 0;
Expand Down Expand Up @@ -264,11 +263,44 @@ function makeGUI() {
reset();
});

// For changing state on different selections
d3.select("#select-orange").on("change", function() {
state.editColor = this.checked ? -1 : 1
state.serialize()
userHasInteracted()
});

d3.select("#select-blue").on("change", function() {
state.editColor = this.checked ? 1 : -1
state.serialize()
userHasInteracted()
});

// On drag, we want to paint our canvas with the dots.
let dragBehavior = d3.behavior.drag().on("drag", function() {
let isVisible = d3.select("#select-platform").style("display") === "block"
if(state.problem === Problem.CLASSIFICATION && isVisible) {
let [x, y] = d3.mouse(this)
let label = state.editColor
let padding = 20
let maxScale = 5.0
let factor = 23.07
x -= padding
y -= padding
x = x/factor - maxScale
y = maxScale - y/factor
state.trainData.push({x, y, label})
heatMap.updatePoints(state.trainData);
}
});

d3.select("#heatmap").call(dragBehavior);

let showTestData = d3.select("#show-test-data").on("change", function() {
state.showTestData = this.checked;
state.serialize();
userHasInteracted();
heatMap.updateTestPoints(state.showTestData ? testData : []);
heatMap.updateTestPoints(state.showTestData ? state.testData : []);
});
// Check/uncheck the checkbox according to the current state.
showTestData.property("checked", state.showTestData);
Expand Down Expand Up @@ -355,6 +387,7 @@ function makeGUI() {

let problem = d3.select("#problem").on("change", function() {
state.problem = problems[this.value];
togglePaintSelection();
generateData();
drawDatasetThumbnails();
parametersChanged = true;
Expand Down Expand Up @@ -908,7 +941,7 @@ function constructInput(x: number, y: number): number[] {

function oneStep(): void {
iter++;
trainData.forEach((point, i) => {
state.trainData.forEach((point, i) => {
let input = constructInput(point.x, point.y);
nn.forwardProp(network, input);
nn.backProp(network, point.label, nn.Errors.SQUARE);
Expand All @@ -917,8 +950,8 @@ function oneStep(): void {
}
});
// Compute the loss.
lossTrain = getLoss(network, trainData);
lossTest = getLoss(network, testData);
lossTrain = getLoss(network, state.trainData);
lossTest = getLoss(network, state.testData);
updateUI();
}

Expand Down Expand Up @@ -949,6 +982,11 @@ function reset(onStartup=false) {
d3.select("#layers-label").text("Hidden layer" + suffix);
d3.select("#num-layers").text(state.numHiddenLayers);

togglePaintSelection()
// Correct radio button on reset
let radioColor = state.editColor === - 1 ? "#select-orange" : "#select-blue";
d3.select(radioColor).attr("checked", "checked")

// Make a simple network.
iter = 0;
let numInputs = constructInput(0 , 0).length;
Expand All @@ -957,8 +995,8 @@ function reset(onStartup=false) {
nn.Activations.LINEAR : nn.Activations.TANH;
network = nn.buildNetwork(shape, state.activation, outputActivation,
state.regularization, constructInputIds(), state.initZero);
lossTrain = getLoss(network, trainData);
lossTest = getLoss(network, testData);
lossTrain = getLoss(network, state.trainData);
lossTest = getLoss(network, state.testData);
drawNetwork(network);
updateUI(true);
};
Expand Down Expand Up @@ -1064,6 +1102,11 @@ function hideControls() {
.attr("href", window.location.href);
}

function togglePaintSelection() {
let visiblity = state.problem === Problem.CLASSIFICATION ? "" : "none"
d3.select("#select-platform").style("display", visiblity);
}

function generateData(firstTime = false) {
if (!firstTime) {
// Change the seed.
Expand All @@ -1081,10 +1124,10 @@ function generateData(firstTime = false) {
shuffle(data);
// Split into train and test data.
let splitIndex = Math.floor(data.length * state.percTrainData / 100);
trainData = data.slice(0, splitIndex);
testData = data.slice(splitIndex);
heatMap.updatePoints(trainData);
heatMap.updateTestPoints(state.showTestData ? testData : []);
state.trainData = data.slice(0, splitIndex);
state.testData = data.slice(splitIndex);
heatMap.updatePoints(state.trainData);
heatMap.updateTestPoints(state.showTestData ? state.testData : []);
}

let firstInteraction = true;
Expand Down
6 changes: 5 additions & 1 deletion src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ export class State {
{name: "tutorial", type: Type.STRING},
{name: "problem", type: Type.OBJECT, keyMap: problems},
{name: "initZero", type: Type.BOOLEAN},
{name: "hideText", type: Type.BOOLEAN}
{name: "hideText", type: Type.BOOLEAN},
{name: "editColor", type: Type.NUMBER}
];

[key: string]: any;
Expand Down Expand Up @@ -160,8 +161,11 @@ export class State {
sinX = false;
cosY = false;
sinY = false;
editColor = -1;
dataset: dataset.DataGenerator = dataset.classifyCircleData;
regDataset: dataset.DataGenerator = dataset.regressPlane;
trainData: dataset.Example2D[] = [];
testData: dataset.Example2D[] = [];
seed: string;

/**
Expand Down