Skip to content

Commit

Permalink
Refactor plots (#91)
Browse files Browse the repository at this point in the history
* Trying to wrap charts in chartcontainer with context provider, but something's not working yet
* Don't use prop destructuring; seems to fix issue with context
* Distinguish between ChartContainer and Chart so legend can reuse context; lineplot now works perfectly with context, very nice
* Also use contextmanager in skewT plot, very clean. +Add hover to lineplot + make legend width work
* Move chartdata interface to chartcontainer
* formatting
  • Loading branch information
Peter9192 authored Nov 29, 2024
1 parent 1d20761 commit 4337353
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 262 deletions.
86 changes: 43 additions & 43 deletions apps/class-solid/src/components/plots/Axes.tsx
Original file line number Diff line number Diff line change
@@ -1,87 +1,81 @@
// Code generated by AI and checked/modified for correctness

import type { ScaleLinear } from "d3";
import * as d3 from "d3";
import { For } from "solid-js";
import { useChartContext } from "./ChartContainer";

interface AxisProps {
scale: ScaleLinear<number, number>;
transform?: string;
tickCount?: number;
type AxisProps = {
type?: "linear" | "log";
domain?: () => [number, number]; // TODO: is this needed for reactivity?
label?: string;
tickValues?: number[];
tickFormat?: (n: number | { valueOf(): number }) => string;
decreasing?: boolean;
}

const ticks = (props: AxisProps) => {
const domain = props.scale.domain();
const generateTicks = (domain = [0, 1], tickCount = 5) => {
const step = (domain[1] - domain[0]) / (tickCount - 1);
return [...Array(10).keys()].map((i) => domain[0] + i * step);
};

const values = props.tickValues
? props.tickValues.filter((x) => x >= domain[0] && x <= domain[1])
: generateTicks(domain, props.tickCount);
return values.map((value) => ({ value, position: props.scale(value) }));
};

export const AxisBottom = (props: AxisProps) => {
const [chart, updateChart] = useChartContext();
props.domain && chart.scaleX.domain(props.domain());

if (props.type === "log") {
const range = chart.scaleX.range();
const domain = chart.scaleX.range();
updateChart("scaleX", d3.scaleLog().domain(domain).range(range));
}

const format = props.tickFormat ? props.tickFormat : d3.format(".3g");
const ticks = props.tickValues || generateTicks(chart.scaleX.domain());
return (
<g transform={props.transform}>
<line
x1={props.scale.range()[0]}
x2={props.scale.range()[1]}
y1="0"
y2="0"
stroke="currentColor"
/>
<For each={ticks(props)}>
<g transform={`translate(0,${chart.innerHeight - 0.5})`}>
<line x1="0" x2={chart.innerWidth} y1="0" y2="0" stroke="currentColor" />
<For each={ticks}>
{(tick) => (
<g transform={`translate(${tick.position}, 0)`}>
<g transform={`translate(${chart.scaleX(tick)}, 0)`}>
<line y2="6" stroke="currentColor" />
<text y="9" dy="0.71em" text-anchor="middle">
{format(tick.value)}
{format(tick)}
</text>
</g>
)}
</For>
<text x={props.scale.range()[1]} y="9" dy="2em" text-anchor="end">
<text x={chart.innerWidth} y="9" dy="2em" text-anchor="end">
{props.label}
</text>
</g>
);
};

export const AxisLeft = (props: AxisProps) => {
const [chart, updateChart] = useChartContext();
props.domain && chart.scaleY.domain(props.domain());

if (props.type === "log") {
const range = chart.scaleY.range();
const domain = chart.scaleY.domain();
updateChart("scaleY", () => d3.scaleLog().range(range).domain(domain));
}

const ticks = props.tickValues || generateTicks(chart.scaleY.domain());
const format = props.tickFormat ? props.tickFormat : d3.format(".0f");
const yAnchor = props.decreasing ? 0 : 1;
return (
<g transform={props.transform}>
<g transform="translate(-0.5,0)">
<line
x1={0}
x2={0}
y1={props.scale.range()[0]}
y2={props.scale.range()[1]}
y1={chart.scaleY.range()[0]}
y2={chart.scaleY.range()[1]}
stroke="currentColor"
/>
<For each={ticks(props)}>
<For each={ticks}>
{(tick) => (
<g transform={`translate(0, ${tick.position})`}>
<g transform={`translate(0, ${chart.scaleY(tick)})`}>
<line x2="-6" stroke="currentColor" />
<text x="-9" dy="0.32em" text-anchor="end">
{format(tick.value)}
{format(tick)}
</text>
</g>
)}
</For>
<text
y={props.scale.range()[yAnchor]}
text-anchor="end"
transform="translate(-45, 0) rotate(-90)"
>
<text y="0" text-anchor="end" transform="translate(-45, 0) rotate(-90)">
{props.label}
</text>
</g>
Expand All @@ -103,3 +97,9 @@ export function getNiceAxisLimits(data: number[]): [number, number] {

return [niceMin, niceMax];
}

/** Generate evenly space tick values for a linear scale */
const generateTicks = (domain = [0, 1], tickCount = 5) => {
const step = (domain[1] - domain[0]) / (tickCount - 1);
return [...Array(10).keys()].map((i) => domain[0] + i * step);
};
9 changes: 0 additions & 9 deletions apps/class-solid/src/components/plots/Base.tsx

This file was deleted.

91 changes: 91 additions & 0 deletions apps/class-solid/src/components/plots/ChartContainer.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import * as d3 from "d3";
import type { JSX } from "solid-js";
import { createContext, useContext } from "solid-js";
import { type SetStoreFunction, createStore } from "solid-js/store";

interface Chart {
width: number;
height: number;
margin: [number, number, number, number];
innerWidth: number;
innerHeight: number;
scaleX: d3.ScaleLinear<number, number> | d3.ScaleLogarithmic<number, number>;
scaleY: d3.ScaleLinear<number, number> | d3.ScaleLogarithmic<number, number>;
}
type SetChart = SetStoreFunction<Chart>;
const ChartContext = createContext<[Chart, SetChart]>();

/** Container and context manager for chart + legend */
export function ChartContainer(props: {
children: JSX.Element;
width?: number;
height?: number;
margin?: [number, number, number, number];
}) {
const width = props.width || 500;
const height = props.height || 500;
const margin = props.margin || [20, 20, 35, 55];
const [marginTop, marginRight, marginBottom, marginLeft] = margin;
const innerHeight = height - marginTop - marginBottom;
const innerWidth = width - marginRight - marginLeft;
const [chart, updateChart] = createStore<Chart>({
width,
height,
margin,
innerHeight,
innerWidth,
scaleX: d3.scaleLinear().range([0, innerWidth]),
scaleY: d3.scaleLinear().range([innerHeight, 0]),
});
return (
<ChartContext.Provider value={[chart, updateChart]}>
<figure>{props.children}</figure>
</ChartContext.Provider>
);
}

/** Container for chart elements such as axes and lines */
export function Chart(props: { children: JSX.Element; title?: string }) {
const [chart, updateChart] = useChartContext();
const title = props.title || "Default chart";
const [marginTop, _, __, marginLeft] = chart.margin;

return (
<svg
width={chart.width}
height={chart.height}
class="text-slate-500 text-xs tracking-wide"
>
<title>{title}</title>
<g transform={`translate(${marginLeft},${marginTop})`}>
{props.children}
{/* Line along right edge of plot
<line
x1={chart.innerWidth - 0.5}
x2={chart.innerWidth - 0.5}
y1="0"
y2={chart.innerHeight}
stroke="#dfdfdf"
stroke-width="0.75px"
fill="none"
/> */}
</g>
</svg>
);
}

export function useChartContext() {
const context = useContext(ChartContext);
if (!context) {
throw new Error(
"useChartContext must be used within a ChartProvider; typically by wrapping your components in a ChartContainer.",
);
}
return context;
}
export interface ChartData<T> {
label: string;
color: string;
linestyle: string;
data: T[];
}
9 changes: 5 additions & 4 deletions apps/class-solid/src/components/plots/Legend.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import { For } from "solid-js";
import { cn } from "~/lib/utils";
import type { ChartData } from "./Base";
import type { ChartData } from "./ChartContainer";
import { useChartContext } from "./ChartContainer";

export interface LegendProps<T> {
entries: () => ChartData<T>[];
width: string;
}

export function Legend<T>(props: LegendProps<T>) {
const [chart, updateChart] = useChartContext();

return (
// {/* Legend */}
<div
class={cn(
"flex flex-wrap justify-end text-sm tracking-tight",
props.width,
`w-[${chart.width}px]`,
)}
>
<For each={props.entries()}>
Expand Down
91 changes: 34 additions & 57 deletions apps/class-solid/src/components/plots/LinePlot.tsx
Original file line number Diff line number Diff line change
@@ -1,78 +1,55 @@
import * as d3 from "d3";
import { For } from "solid-js";
import { For, createSignal } from "solid-js";
import { AxisBottom, AxisLeft, getNiceAxisLimits } from "./Axes";
import type { ChartData } from "./Base";
import type { ChartData } from "./ChartContainer";
import { Chart, ChartContainer, useChartContext } from "./ChartContainer";
import { Legend } from "./Legend";

export interface Point {
x: number;
y: number;
}

function Line(d: ChartData<Point>) {
const [chart, updateChart] = useChartContext();
const [hovered, setHovered] = createSignal(false);

const l = d3.line<Point>(
(d) => chart.scaleX(d.x),
(d) => chart.scaleY(d.y),
);
return (
<path
onMouseEnter={() => setHovered(true)}
onMouseLeave={() => setHovered(false)}
fill="none"
stroke={d.color}
stroke-dasharray={d.linestyle}
stroke-width={hovered() ? 5 : 3}
d={l(d.data) || ""}
>
<title>{d.label}</title>
</path>
);
}

export default function LinePlot({
data,
xlabel,
ylabel,
}: { data: () => ChartData<Point>[]; xlabel?: string; ylabel?: string }) {
// TODO: Make responsive
// const margin = [30, 40, 20, 45]; // reference from skew-T
const [marginTop, marginRight, marginBottom, marginLeft] = [20, 20, 35, 55];
const width = 500;
const height = 500;
const w = 500 - marginRight - marginLeft;
const h = 500 - marginTop - marginBottom;

const xLim = () =>
getNiceAxisLimits(data().flatMap((d) => d.data.flatMap((d) => d.x)));
const yLim = () =>
getNiceAxisLimits(data().flatMap((d) => d.data.flatMap((d) => d.y)));
const scaleX = () => d3.scaleLinear(xLim(), [0, w]);
const scaleY = () => d3.scaleLinear(yLim(), [h, 0]);

const l = d3.line<Point>(
(d) => scaleX()(d.x),
(d) => scaleY()(d.y),
);

return (
<figure>
<Legend entries={data} width={`w-[${width}px]`} />
{/* Plot */}
<svg
width={width}
height={height}
class="text-slate-500 text-xs tracking-wide"
>
<g transform={`translate(${marginLeft},${marginTop})`}>
<title>Vertical profile plot</title>
{/* Axes */}
<AxisBottom
scale={scaleX()}
transform={`translate(0,${h - 0.5})`}
label={xlabel}
/>
<AxisLeft
scale={scaleY()}
transform="translate(-0.5,0)"
label={ylabel}
/>

{/* Line */}
<For each={data()}>
{(d) => (
<path
fill="none"
stroke={d.color}
stroke-dasharray={d.linestyle}
stroke-width="3"
d={l(d.data) || ""}
>
<title>{d.label}</title>
</path>
)}
</For>
</g>
</svg>
</figure>
<ChartContainer>
<Legend entries={data} />
<Chart title="Vertical profile plot">
<AxisBottom domain={xLim} label={xlabel} />
<AxisLeft domain={yLim} label={ylabel} />
<For each={data()}>{(d) => Line(d)}</For>
</Chart>
</ChartContainer>
);
}
Loading

0 comments on commit 4337353

Please sign in to comment.