// -----------------------------------------------------------------------------
// Model Memory Calculator page (React + d3). Default export: `Calculator`.
//
// NOTE(review): in this copy of the file the JSX markup appears to have been
// stripped (e.g. `return ;`, bare heading text, orphaned `0} />` fragments) and
// the code is collapsed onto a few very long lines, so a `//` comment swallows
// whatever follows it on the same physical line. The notes below describe only
// what the remaining code shows; element structure and the props handed to the
// chart components are not visible here. `MODELS` and `DEVICES` are imported,
// but their use is not visible in this copy.
//
// chooseColor(precision)
//   Looks the precision up in a fixed table ('32-bit' #e45f5b, '16-bit' #ffc068,
//   '8-bit' #71cce9, '4-bit' #383d95) and falls back to 'gray' when the lookup
//   is falsy.
//
// calculateStandardMemory(modelParams, precision)
//   Returns modelParams * factor, where factor is 4 / 2 / 1 / 0.5 for
//   32 / 16 / 8 / 4-bit. The code comments label the result as GB (presumably
//   modelParams is expressed in billions of parameters; confirm with callers).
//   Logs the result with console.log on every call.
//
// calculatePrefillMemory(modelParams, hiddenSize, numLayers, intermediateSize, precision)
//   Returns the sum of three terms, each labelled GB in the code comments:
//     model      = modelParams * factor (same table as above)
//     activation = (maxChunkSize * 2 * max(2 * intermediateSize, 4 * hiddenSize)) / 1e9,
//                  with maxChunkSize fixed at 512
//     per-input  = (4 * hiddenSize * numLayers) / 1e9
//   Logs the total, the activation term and the per-input term.
//
// ModelSizeBarChart(props)
//   A useEffect draws into the element held by chartRef, sized 600 x 50, and
//   scales the svg to 1.02 on mouseover and back to 1 on mouseout. Drawing
//   happens only when modelSize > 0 and largestModelSize > 0; in that case
//   existing children are removed first.
//   - If modelSize + activationMemorySize exceeds largestModelSize: a dashed
//     outline plus the text 'Out of Memory'.
//   - Otherwise: a filled rect for modelSize in the precision colour, an
//     optional '#a4b8e0' rect for activationMemorySize (when > 0), and, when
//     deviceMemorySet is true, an outlined rect for the remaining capacity.
//   NOTE(review): when the guard fails nothing is removed, so a previously
//   drawn bar would stay on screen. Confirm whether that is intended.
//
// InferenceRuntimeLineChart({ availableMemory, memoryPerInput })
//   A useEffect runs only when memoryPerInput > 0 and at least one
//   availableMemory value is > 0. It clears the svg, then draws axes over
//   sequence length 0..4096 (x) and batch size 0..128 (y), axis labels, a
//   legend for the four precisions, and one line per availableMemory entry.
//   Each line plots batchSize = memory / (seqLength * memoryPerInput) for
//   seqLength from d3.range(1, 4096, 1), keeping only points where
//   1 < batchSize <= 128 and seqLength > 1; the curve is d3.curveBasis with
//   stroke width 4. Hovering a line shows the tooltip element (tooltipRef)
//   with the pointer position inverted through both scales and rounded with
//   toFixed(0).
//   Zoom (scale extent 0.5..10) rescales both axes and sets the zoom transform
//   on every `path` selected under the svg.
//   NOTE(review): that `selectAll('path')` is not restricted to the data
//   lines; if the axis generator also emits path elements they would receive
//   the same transform. Verify against the d3-axis output.
//
// PrefillChunkingCalculator(props)
//   Returns null when any of deviceMemory, modelParams, hiddenSize, numLayers
//   or intermediateSize is falsy. Recomputes the activation term with a
//   literal 512, duplicating maxChunkSize in calculatePrefillMemory, and calls
//   calculatePrefillMemory once per precision while mapping over the four
//   precisions.
//
// StandardCalculator(props)
//   Returns null when any of deviceMemory, modelParams, hiddenSize or
//   numLayers is falsy; maps over the four precisions and shows
//   calculateStandardMemory(...).toFixed(2) next to deviceMemory.
//
// Calculator
//   Holds modelParams, hiddenSize, numLayers, intermediateSize and deviceMemory
//   in state (all initialised to null), plus isPrefillChunking (false) and the
//   two selection-tab flags (true). Manual inputs store Number(e.target.value);
//   the intermediate-size input is rendered only when isPrefillChunking is
//   true. isPrefillChunking also selects which calculator is rendered.
// -----------------------------------------------------------------------------
import React, { useState, useEffect, useRef } from 'react'; import * as d3 from 'd3'; import { useTheme } from '../context/themeContext'; import MODELS from '../utils/models'; import DEVICES from '../utils/devices'; type Precision = '32-bit' | '16-bit' | '8-bit' | '4-bit'; interface ModelSizeBarChartProps { modelSize: number; // in GB largestModelSize: number; // largest model in full precision (32-bit) modelPrecision: Precision; deviceMemorySet: boolean; activationMemorySize?: number; } interface InferenceRuntimeLineChartProps { availableMemory: AvailableMemory; // in GB memoryPerInput: number; // in GB } interface LineChartData { seqLength: number; batchSize: number; } interface AvailableMemory { '4-bit': number; '8-bit': number; '16-bit': number; '32-bit': number; } // Utility to determine color based on precision function chooseColor(precision: Precision) { const colors = { '32-bit': '#e45f5b', '16-bit': '#ffc068', '8-bit': '#71cce9', '4-bit': '#383d95', }; return colors[precision] || 'gray'; } // Calculate standard memory (model size based on precision only) function calculateStandardMemory(modelParams: number, precision: Precision): number { const precisionFactor = { '32-bit': 4, '16-bit': 2, '8-bit': 1, '4-bit': 0.5, }; const memory = modelParams * precisionFactor[precision]; // GB console.log(`[Standard] ${precision.toUpperCase()} Memory:`, memory); return memory; } // Calculate prefill chunking memory (model size + activation + input memory) function calculatePrefillMemory( modelParams: number, hiddenSize: number, numLayers: number, intermediateSize: number, precision: Precision ): number { const precisionFactor = { '32-bit': 4, '16-bit': 2, '8-bit': 1, '4-bit': 0.5, }; // Max Chunk Size - adjustable in the future const maxChunkSize = 512; // Calculate each memory component const modelMemorySize = modelParams * precisionFactor[precision]; // GB const activationMemorySize = (maxChunkSize * 2 * Math.max(2 * intermediateSize, 4 * hiddenSize)) / 1_000_000_000; 
// GB const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000; // GB // Combine all components const totalMemory = modelMemorySize + activationMemorySize + memoryPerInput; console.log(`[Prefill] ${precision.toUpperCase()} Memory:`, totalMemory); console.log(`[Prefill] Activation Memory:`, activationMemorySize); console.log(`[Prefill] Memory Per Input:`, memoryPerInput); return totalMemory; } // Bar chart for model footprint (shared by both standard and prefill chunking calculators) function ModelSizeBarChart({ modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize = 0, }: ModelSizeBarChartProps) { const { theme } = useTheme(); const chartRef = useRef(null); const width = 600; const height = 50; useEffect(() => { if (modelSize > 0 && largestModelSize > 0) { d3.select(chartRef.current).selectAll('*').remove(); const svg = d3.select(chartRef.current) .attr('width', width) .attr('height', height) .style('animation', 'fadeIn 0.3s ease-in-out') // Inline animation .style('transition', 'transform 0.3s ease-in-out') // Hover effect .on('mouseover', function () { d3.select(this).style('transform', 'scale(1.02)'); }) .on('mouseout', function () { d3.select(this).style('transform', 'scale(1)'); }); const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width]); if (modelSize + activationMemorySize > largestModelSize) { svg .append('rect') .attr('x', 0) .attr('y', 0) .attr('width', width) .attr('height', height) .attr('fill', 'transparent') .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26') .style('stroke-dasharray', '4, 4') .style('stroke-width', '2px'); svg .append('text') .attr('x', width / 2) .attr('y', height / 2) .attr('text-anchor', 'middle') .attr('alignment-baseline', 'middle') .attr('fill', theme === 'dark' ? 
'#f9fafb' : '#181f26') .text('Out of Memory'); } else { svg .append('rect') .attr('x', 0) .attr('y', 0) .attr('width', xScale(modelSize)) .attr('height', height) .attr('fill', chooseColor(modelPrecision)); if (activationMemorySize > 0) { svg .append('rect') .attr('x', xScale(modelSize)) .attr('y', 0) .attr('width', xScale(activationMemorySize)) .attr('height', height) .attr('fill', '#a4b8e0'); } if (deviceMemorySet) { svg .append('rect') .attr('x', xScale(modelSize + activationMemorySize)) .attr('y', 0) .attr('width', xScale(largestModelSize - (modelSize + activationMemorySize))) .attr('height', height) .attr('fill', 'transparent') .style('stroke', chooseColor(modelPrecision)) .style('stroke-width', '2px'); } } } }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme]); return ; } // Line chart for inference runtime (shared by both standard and prefill chunking calculators) function InferenceRuntimeLineChart({ availableMemory, memoryPerInput }: InferenceRuntimeLineChartProps) { const { theme } = useTheme(); const chartRef = useRef(null); const tooltipRef = useRef(null); // Ref for the tooltip const maxSeqLength = 4096; const maxBatchSize = 128; useEffect(() => { if (memoryPerInput > 0 && Object.values(availableMemory).some((val) => val > 0)) { const margin = { top: 20, right: 20, bottom: 50, left: 50 }; const width = 600 - margin.left - margin.right; const height = 400 - margin.top - margin.bottom; const svg = d3.select(chartRef.current); svg.selectAll('*').remove(); const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width]); const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0]); const xAxis = d3.axisBottom(xScale); const yAxis = d3.axisLeft(yScale); const zoom = d3.zoom() .scaleExtent([0.5, 10]) .translateExtent([[-width, -height], [2 * width, 2 * height]]) .on('zoom', (event) => { const transform = event.transform; svg.select('.x-axis').call(xAxis.scale(transform.rescaleX(xScale))); 
svg.select('.y-axis').call(yAxis.scale(transform.rescaleY(yScale))); svg.selectAll('path').attr('transform', transform); }); svg .attr('width', width + margin.left + margin.right) .attr('height', height + margin.top + margin.bottom) .append('g') .attr('transform', `translate(${margin.left}, ${margin.top})`) .call(zoom); svg.append('g').attr('class', 'x-axis').attr('transform', `translate(${margin.left}, ${height + margin.top})`).call(xAxis); svg.append('g').attr('class', 'y-axis').attr('transform', `translate(${margin.left}, ${margin.top})`).call(yAxis); svg.append('text') .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`) .style('text-anchor', 'middle') .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .text('Sequence Length'); svg.append('text') .attr('transform', `rotate(-90)`) .attr('y', 0) .attr('x', 0 - height / 2 - margin.top) .attr('dy', '1em') .style('text-anchor', 'middle') .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .text('Batch Size'); // Adding legend for precisions const precisions = [ { name: '32-bit', color: '#e45f5b' }, { name: '16-bit', color: '#ffc068' }, { name: '8-bit', color: '#71cce9' }, { name: '4-bit', color: '#383d95' }, ]; const legend = svg .append('g') .attr('class', 'legend') .attr('transform', `translate(${width - 20}, 20)`); precisions.forEach((precision, index) => { const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`); legendItem.append('rect') .attr('x', 10) .attr('y', 10) .attr('width', 10) .attr('height', 10) .style('fill', precision.color); legendItem.append('text') .attr('x', 30) .attr('y', 16) .text(precision.name) .style('font-size', '16px') .style('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .attr('alignment-baseline', 'middle'); }); legend.append('rect') .attr('class', 'legend-box') .attr('width', 80) .attr('height', precisions.length * 30) .style('fill', 'none') .style('stroke-width', '1px') .style('stroke', theme === 'dark' ? 
'#f9fafb' : '#181f26'); const tooltip = d3.select(tooltipRef.current) .style('position', 'absolute') .style('padding', '8px') .style('border-radius', '4px') .style('pointer-events', 'none') .style('opacity', 0) .style('transition', 'opacity 0.3s ease-in-out, transform 0.3s ease-in-out') .style('background-color', 'rgba(0, 0, 0, 0.75)') .style('color', 'white') .style('font-size', '14px'); for (const [precision, memory] of Object.entries(availableMemory)) { const sequenceLengths = d3.range(1, maxSeqLength, 1) .map((seqLength) => ({ seqLength, batchSize: memory / (seqLength * memoryPerInput), })) .filter((d) => d.batchSize <= maxBatchSize && d.batchSize > 1 && d.seqLength > 1); const lineGroup = svg.append('g').attr('transform', `translate(${margin.left}, ${margin.top})`); const line = d3.line() .x((d) => xScale(d.seqLength)) .y((d) => yScale(d.batchSize)) .curve(d3.curveBasis); lineGroup.append('path') .datum(sequenceLengths) .attr('fill', 'none') .attr('stroke', chooseColor(precision as Precision)) .attr('stroke-width', 4) .attr('d', line) .on('mouseover', () => { tooltip.style('opacity', 1) .style('transform', 'translateY(-10px)'); }) .on('mousemove', (event) => { tooltip.selectAll('text').remove(); const [x, y] = d3.pointer(event); const xValue = xScale.invert(x); const yValue = yScale.invert(y); tooltip.html(`Sequence Length: ${xValue.toFixed(0)}
Batch Size: ${yValue.toFixed(0)}`) .style('left', event.pageX + 10 + 'px') .style('top', event.pageY + 10 + 'px'); }) .on('mouseout', () => { tooltip.style('opacity', 0); }); } } }, [availableMemory, memoryPerInput, theme]); return ( <>
); } // Prefill Chunking Calculator with Updated Logic and Precision Adjustment function PrefillChunkingCalculator({ deviceMemory, modelParams, hiddenSize, numLayers, intermediateSize, }: { deviceMemory: number; modelParams: number; hiddenSize: number; numLayers: number; intermediateSize: number; }) { if (!deviceMemory || !modelParams || !hiddenSize || !numLayers || !intermediateSize) { return null; } // Calculate activation memory size based on intermediate size and hidden size const activationMemorySize = (512 * 2 * (Math.max(2 * intermediateSize, 4 * hiddenSize))) / 1_000_000_000; return ( <> {/* Model Footprint with Prefill Chunking */}
Model Footprint with Prefill Chunking
{(['32-bit', '16-bit', '8-bit', '4-bit'] as Precision[]).map((precision) => { const totalMemory = calculatePrefillMemory( modelParams, hiddenSize, numLayers, intermediateSize, precision ); return (
{precision.toUpperCase()}
0} activationMemorySize={activationMemorySize} // Updated to pass activation memory size />
{totalMemory.toFixed(2)} / {deviceMemory} GB
); })}
{/* Inference Runtime with Prefill Chunking */}
Maximum Batch Size / Sequence Length with Prefill Chunking
); } // Standard Model Memory Calculator (unchanged) function StandardCalculator({ deviceMemory, modelParams, hiddenSize, numLayers, }: { deviceMemory: number; modelParams: number; hiddenSize: number; numLayers: number; }) { if (!deviceMemory || !modelParams || !hiddenSize || !numLayers) { return null; } return ( <> {/* Model Footprint */}
Model Footprint
{(['32-bit', '16-bit', '8-bit', '4-bit'] as Precision[]).map((precision) => (
{precision.toUpperCase()}
0} />
{calculateStandardMemory(modelParams, precision).toFixed(2)} / {deviceMemory} GB
))}
{/* Maximum Batch Size / Sequence Length */}
Maximum Batch Size / Sequence Length
); } // Main Calculator Page const Calculator = () => { const [modelParams, setModelParams] = useState(null); const [hiddenSize, setHiddenSize] = useState(null); const [numLayers, setNumLayers] = useState(null); const [intermediateSize, setIntermediateSize] = useState(null); const [deviceMemory, setDeviceMemory] = useState(null); const [isPrefillChunking, setIsPrefillChunking] = useState(false); const [modelSelectionTab, setModelSelectionTab] = useState(true); const [deviceSelectionTab, setDeviceSelectionTab] = useState(true); return (
{/* Toggle Between Standard and Prefill Chunking */}
{/* Model and Device Selection */}
Model Memory Calculator
Use our Model Memory Calculator to help you estimate the memory footprint of your model and the maximum batch size/sequence length combination you can run on your device.
{/* Model Selection */}
Model
{modelSelectionTab ? ( <> ) : ( <> setModelParams(Number(e.target.value))} /> setHiddenSize(Number(e.target.value))} /> setNumLayers(Number(e.target.value))} /> {isPrefillChunking && ( <> setIntermediateSize(Number(e.target.value))} /> )} )}
{/* Device Selection */}
Device
{deviceSelectionTab ? ( <> ) : ( <> setDeviceMemory(Number(e.target.value))} /> )}
{/* Render Appropriate Calculator Based on Toggle */} {isPrefillChunking ? ( // eslint-disable-next-line ) : ( // eslint-disable-next-line )}
); }; export default Calculator;