use derivatives.frink
use integrals.frink
use solvingTransformations.frink
use powerTransformations.frink
// This class solves a system of equations. It contains a graph of
// EquationNodes. Each EquationNode represents an equation in the system.
// These nodes are connected by Edges which indicate that various
// equations are interrelated by containing the same variables.
class Solver
// NOTE(review): no opening "{" for this class body (nor any brace-only line
// for its methods) is present in this copy of the file; confirm against the
// original source before relying on this listing.
// This is an array of EquationNodes that make up the graph.
var equationNodes
// This is a set of strings indicating the variables that we are not going
// to solve for (presumably units). addEquation subtracts it from each
// equation's symbols to get that equation's unknowns.
var ignoreSet
// Boolean flag indicating if we have done phase 1 simplifications
// (solveSimultaneous, run by initialize[]). Reset to false whenever the set
// of equations changes.
var initialized
// The cached final solutions for each variable, keyed by variable name
var finalSolutions
// A (currently-unused) set of aliases that will simplify equations.
var aliases
// Create a new Solver.
//   eqs        - array of equations (each of the form lhs === rhs) that
//                make up the system.
//   ignoreList - names (strings) of symbols that must never be treated as
//                unknowns.
// The ignore set is filled BEFORE any equation is added, because
// addEquation uses it to compute each equation's unknowns.
new[eqs, ignoreList=[]] :=
{
   ignoreSet = new set
   equationNodes = new array
   finalSolutions = new dict

   for name = ignoreList
      ignoreSet.put[name]

   for eq = eqs
      addEquation[eq]

   initialized = false
}
// Add an equation to the system. This creates a new EquationNode and
// automatically connects it to all of the other EquationNodes in the
// system that share its variables.
// If index is undef, the node is appended to the end of the list.
// If index is a number, the node at that index is removed (and
// disconnected) and the new node is inserted at the same position.
// An equation structurally identical to an existing one is ignored; when
// replacing, this leaves the node at index untouched.
// Any cached solutions are discarded and the system is marked as
// uninitialized.
// (public)
addEquation[eq, index=undef] :=
{
   // THINK ABOUT: Call transformExpression to canonicalize first?

   // Check for duplicate equations.
   for n = equationNodes
      if structureEquals[eq, n.getOriginalEquation[]]
      {
         println["Eliminating duplicate equation $eq"]
         return
      }

   reducedUnknowns = setDifference[getSymbols[eq], ignoreSet]
   if length[reducedUnknowns] == 0
      println["WARNING: Equation $eq has no unknowns!"]

   node = new EquationNode[eq, reducedUnknowns]

   // If replacing, take the old node out of the graph first so the new node
   // is not connected to the equation it replaces.
   if index != undef
   {
      oldNode = equationNodes.remove[index]
      disconnect[oldNode]
   }

   for other = equationNodes
   {
      // THINK ABOUT: Should we check to see if one equation is a proper
      // subset of the variables of the other and push the simpler into
      // the more complex right now?

      // This is a set of shared variables between the two equations.
      sharedVars = intersection[reducedUnknowns, other.getUnknowns[]]
      for varName = sharedVars
         connect[node, other, varName]
   }

   if index == undef
   {
      equationNodes.push[node]
   } else
   {
      equationNodes.insert[index, node]
   }

   // Discard any previously-found solutions
   finalSolutions = new dict
   initialized = false
}
// Method to initialize and simplify the system.
// If the system is not yet initialized, runs solveSimultaneous[] once.
// NOTE(review): as written, initialized is set to true only when
// solveSimultaneous[] reports a change. A system that needed no change is
// therefore processed again on the next call, while a single changing pass
// marks the system done even if more eliminations are possible. Confirm this
// is intended (a loop until no change may have been meant).
initialize[] :=
if ! initialized
// draw[]
changed = solveSimultaneous[]
// draw[]
// pushAliases[]
// draw[]
if changed
initialized = true
// Removes the node at the specified index from the graph. This removes all
// connections to that node, marks the system as uninitialized and discards
// any previously-found solutions.
// Returns the removed EquationNode.
// (public)
remove[index] :=
{
   node = equationNodes.remove[index]
   // Without this, the remaining nodes would keep Edges pointing at a node
   // that is no longer part of the system.
   disconnect[node]
   initialized = false
   finalSolutions = new dict // Discard any previously-found solutions.
   return node
}
// Disconnect the specified node from the graph. Every edge touching the
// node is removed, both from the node itself and from its neighbors.
// (private)
disconnect[node] :=
   node.disconnectAllEdges[]
// Connect the two specified equations by the specified variable.
// The same Edge object is registered on both nodes so either side can
// traverse it (see getEdges / getOtherNode). Returns the new Edge.
// (private)
connect[n1 is EquationNode, n2 is EquationNode, varName is string] :=
{
   e = new Edge[n1, n2, varName]
   n1.addEdge[e]
   n2.addEdge[e]
   return e
}
// Number of EquationNodes currently held by the system.
// (public)
getEquationCount[] :=
{
   count = length[equationNodes]
   return count
}
// Looks up the EquationNode stored at the given zero-based index.
// (public)
getEquationNode[index] :=
{
   return equationNodes@index
}
// Returns an array with each element in the array being an array
//    [ unknowns, index ]
// for every equation in the system, ordered with the simplest equations
// (those with the fewest unknowns) first. unknowns is a set; index is the
// node's position in equationNodes.
getEquationsSortedByComplexity[] :=
{
   list = new array
   for i = 0 to length[equationNodes] - 1
      list.push[[equationNodes@i.getUnknowns[], i]]

   // Return sort's result directly rather than relying on in-place sorting.
   return sort[list, {|a,b| length[a@0] <=> length[b@0]}]
}
// Set of unknowns (ignored symbols excluded) belonging to the node at the
// given index.
getUnknowns[index] :=
{
   node = equationNodes@index
   return node.getUnknowns[]
}
// Prints out the state of the solver for debugging: one line per node,
// giving its index, its original equation and each of its edges as
// [variableName,otherNodeIndex].
// (public)
dump[] :=
{
   last = getEquationCount[]-1
   for i = 0 to last
   {
      node = getEquationNode[i]
      print["$i\t" + node.getOriginalEquation[] + "\t"]
      for e = node.getEdges[]
      {
         other = e.getOtherNode[node]
         print["[" + e.getVariableName[] + "," + getNodeIndex[other] +"] "]
      }
      // Terminate this node's line; otherwise all nodes run together.
      println[]
   }
}
// Draw a representation of the system.
// g is the graphics object to draw into; the drawing occupies the
// rectangle (left, top)-(right, bottom).
// Each node is drawn as an ellipse centered at getPosition[i, ...]. Each
// edge is drawn once (only from the lower-indexed endpoint) as a line whose
// endpoints are randomly jittered by up to w, presumably so that several
// edges between the same two nodes stay distinguishable. The equations are
// listed as text, one per row, starting at the vertical middle (cy).
// NOTE(review): fillEllipseCenter and drawEllipseCenter are called with the
// same current color. If a contrasting fill was intended, a color change
// appears to be missing; confirm.
draw[g is graphics, left=0, top=0, right=1, bottom=2] :=
last = getEquationCount[]-1
width = right-left
height = bottom-top
g.font["Serif", "italic", height/30]
cy = top + height/2
w = 0.7 width/20;
for i = 0 to last
node = getEquationNode[i]
[x,y] = getPosition[i, left, top, right, bottom]
for e = node.getEdges[]
oi = getNodeIndex[e.getOtherNode[node]]
if (i < oi)
[ox, oy] = getPosition[oi, left, top,right,bottom]
g.line[x+randomFloat[-w,w],y+randomFloat[-w,w],ox+randomFloat[-w, w],oy+randomFloat[-w,w]]
g.fillEllipseCenter[x,y,width/10, height/2/10]
g.drawEllipseCenter[x,y,width/10, height/2/10]
g.text["$i: " + node.getOriginalEquation[], left, i * height/20 + cy, "left", "top"]
// Computes where node number index is drawn. Nodes are spaced evenly
// around a ring whose center sits horizontally mid-way across the drawing
// area and a quarter of the way down it. Node 0 is at the top of the ring.
// Returns [x, y].
getPosition[index, left, top, right, bottom] :=
{
   angleStep = circle / getEquationCount[]
   spanX = right-left
   spanY = bottom-top
   centerX = left + spanX/2
   centerY = top + spanY/4
   ringRadius = 1/2 (.9) spanX
   return [ringRadius sin[index angleStep] + centerX, ringRadius * -cos[index angleStep] + centerY]
}
// Draws the system into a new graphics object, displays it, and returns
// the graphics object.
draw[] :=
{
   g = new graphics
   draw[g]
   g.show[]
   return g
}
// Gets the node index given the node.
// Returns the zero-based index of node in equationNodes. If the node is
// not part of this system, returns the string "Node not found!" (dump[]
// concatenates the result into text).
getNodeIndex[node is EquationNode] :=
{
   // Valid indices are 0 .. length-1. Looping up to length itself read one
   // element past the end of the array when the node was absent.
   last = length[equationNodes] - 1
   for i = 0 to last
      if equationNodes@i == node
         return i

   return "Node not found!"
}
// Repeatedly pushes simpler equations into more complex ones, stopping at
// the first pass of pushSimplerOnce[] that reports no change.
pushSimpler[] :=
{
   // dump[]
   do
   {
      changedSystem = pushSimplerOnce[]
      // dump[]
   } while changedSystem
}
// Pushes the simpler equations into the more complex
// equations. This returns true if changes were made to the system,
// false otherwise.
// The equations are visited in order of increasing complexity. nodeI and
// nodeJ below are INDICES into equationNodes (from
// getEquationsSortedByComplexity). When I's unknowns are a proper subset of
// J's, each variable they share (simplest first, per pickSimplestVariable)
// is solved for in I. Each solution is substituted into J's original
// equation, and the result is prettified (and transformed if it was not
// fully reduced).
// NOTE(review): as written in this copy, newEq is computed but never stored
// back into the system, and nothing is removed despite the "Removing node"
// comment. The statements that replace node J appear to be missing; confirm
// against the original.
pushSimplerOnce[] :=
sortedEqs = getEquationsSortedByComplexity[]
last = length[sortedEqs]-1
for i=0 to last-1
[unknownsI, nodeI] = sortedEqs@i
for j=i+1 to last
[unknownsJ, nodeJ] = sortedEqs@j
if isProperSubset[unknownsI, unknownsJ]
// println["Node $nodeI is a proper subset of $nodeJ"]
simpleVarList = pickSimplestVariable[nodeI, nodeJ]
for [simpleVar, count] = simpleVarList
sols = equationNodes@nodeI.getSolutions[simpleVar]
if length[sols] < 1
println["Ugh. Has " + length[sols] + " solutions in pushSimplerOnce. Maybe improve pickSimplestVariable? Solutions are: $sols, equations are " + equationNodes@nodeI.getOriginalEquation[] + ", " + equationNodes@nodeJ.getOriginalEquation[] ]
} else
origEq = equationNodes@nodeJ.getOriginalEquation[]
// println["Removing node $nodeJ"]
if (length[sols] > 1)
println["Warning: pushSimplerOnce split into " + length[sols] + " solutions: $sols"]
// Each solution of node I yields its own rewritten version of node J.
for sol = sols
// println["Replacing $simpleVar in $origEq with " + child[sol,1]]
newEq = replaceVar[origEq, simpleVar, child[sol,1]]
[newEq, fullyReduced] = prettify[newEq]
if !fullyReduced
newEq = transformExpression[newEq]
return true
return false // We made no changes.
// Pushes aliases like a == b or a == 2 b around the system so that
// the system is as simple and disconnected as possible.
// An "alias" here is an equation whose left and right sides each contain
// exactly one unknown (ignoring ignoreSet). Such an equation is solved for
// the lexicographically greater of the two names, and that single solution
// is substituted into every other equation via replaceAll.
// NOTE(review): i and size are decremented after replaceAll as if node i
// had been removed, but no removal is visible in this copy; a call such as
// remove[i] appears to be missing. Its only visible call site (in
// initialize[]) is commented out.
pushAliases[] :=
size = length[equationNodes]
i = 0
while i<size
node = equationNodes@i
eq = equationNodes@i.getOriginalEquation[]
left = child[eq, 0]
lu = getSymbols[left]
lu = setDifference[lu, ignoreSet]
if length[lu] == 1
right = child[eq, 1]
ru = getSymbols[right]
ru = setDifference[ru, ignoreSet]
if length[ru] == 1
ls = array[lu]@0
rs = array[ru]@0
first = ls > rs ? ls : rs
sols = node.getSolutions[first]
if length[sols] != 1
println["Ugh. Has " + length[sols] + " solutions in pushAliases. Solutions are:\n" + sols]
} else
println["Replacing " + sols@0]
replaceAll[sols@0, i]
i = i - 1 // Don't increment i
size = size - 1
i = i + 1
// Simplifies an equation to the form a === 10 if only one variable
// (not counting symbols in ignoreSet) remains, by solving for it.
// returns [equation, reduced]
// where reduced is a boolean flag indicating if the equation has been
// simplified to the form above (i.e. solving produced exactly one
// solution). Callers destructure exactly these two values.
// NOTE(review): remainingSymbol is computed but not returned. Because the
// braces are missing in this copy, it is unclear which "if" the "} else"
// belongs to. Confirm whether transformExpression runs when the equation
// does not have exactly one unknown, or when solving did not produce exactly
// one solution.
prettify[eq] :=
reduced = false
solvedUnknowns = getSymbols[eq]
solvedUnknowns = setDifference[solvedUnknowns, ignoreSet]
remainingSymbol = undef
if length[solvedUnknowns] == 1
remainingSymbol = array[solvedUnknowns]@0
prettified = array[solveSingle[eq, remainingSymbol]]
if length[prettified] == 1
reduced = true
eq = prettified@0
} else
eq = transformExpression[eq]
return [eq, reduced]
// Lists the variables shared by the nodes at index1 and index2, simplest
// first. Node 1 should be the simpler node.
// Returns an array of [variableName, count] pairs, where count is the sum
// of the second elements reported by getSymbolsByComplexity for that
// variable in both equations (presumably an occurrence count). The array
// is sorted by ascending count and is empty if the nodes share no unknowns.
pickSimplestVariable[index1, index2] :=
{
   node1 = equationNodes@index1
   node2 = equationNodes@index2

   // Named "shared" rather than "intersection" so the local does not
   // shadow the intersection function used to compute it.
   shared = intersection[node1.getUnknowns[], node2.getUnknowns[]]

   // Seed the results (and their initial order) from node 1's symbols.
   results = new array
   sortedUnknowns = getSymbolsByComplexity[node1.getOriginalEquation[]]
   //println["Sorted unknowns is $sortedUnknowns"]
   for [unknown, count] = sortedUnknowns
      if shared.contains[unknown]
         results.push[[unknown, count]]

   // Add node 2's counts to the matching entries.
   sortedUnknowns = getSymbolsByComplexity[node2.getOriginalEquation[]]
   for [unknown, count] = sortedUnknowns
      if shared.contains[unknown]
         for i = 0 to length[results]-1
            if results@i@0 == unknown
               results@i@1 = results@i@1 + count

   sort[results, {|a,b| a@1 <=> b@1}]
   return results
}
// Returns true if the node at index1 contains unknowns which are a
// proper subset of the unknowns of the node at index2. This means that
// node 1 is a "simpler" version of equation 2, and its values should be
// substituted into equation 2.
isSimpler[index1, index2] :=
{
   return isProperSubset[equationNodes@index1.getUnknowns[], equationNodes@index2.getUnknowns[]]
}
// Eliminate simultaneous equations in the system.
// returns true if the system has been changed.
// For every pair of equations (i < j) that share two or more unknowns, the
// shared variables are tried in the order given by pickSimplestVariable.
// The first variable for which node I (or, failing that, node J) has
// exactly one solution is substituted into all other equations via
// replaceAll. The "fell through" warning is printed when no shared variable
// could be substituted (presumably; it depends on where "break JLOOP"
// lands).
// NOTE(review): "break JLOOP" refers to a label that is not visible in this
// copy; presumably it labels the "for j" loop. Confirm against the original.
solveSimultaneous[] :=
changed = false
// TODO: Sort by simplest equations?
size = length[equationNodes]
for i=0 to size-2
for j = i+1 to size-1
nodeI = equationNodes@i
nodeJ = equationNodes@j
ui = nodeI.getUnknowns[]
uj = nodeJ.getUnknowns[]
sharedUnknowns = intersection[ui, uj]
// println[nodeI.getOriginalEquation[]]
// println[nodeJ.getOriginalEquation[]]
// println["$i: $ui\t$j: $uj"]
// println["$i $j Shared unknowns are $sharedUnknowns"]
if length[sharedUnknowns] >= 2
varsToReplace = pickSimplestVariable[i, j]
// println["varsToReplace is $varsToReplace"]
for [varToReplace, count] = varsToReplace
skipNode = i
solution = nodeI.getSolutions[varToReplace]
if length[solution] != 1
// Didn't find single solution, try solving
// and replacing from the other node.
solution = nodeJ.getSolutions[varToReplace]
skipNode = j
if length[solution] == 1
replaceAll[solution@0, skipNode]
// dump[]
changed = true
break JLOOP
println["Ugh. SolveSimultaneous fell through without replacing. Equations were " + nodeI.getOriginalEquation[] + " and " + nodeJ.getOriginalEquation[]]
return changed
// Replace the specified symbol, recursively, in all equations except
// the index specified.
// solution is an equation of the form sym === rep. In every other node's
// original equation, sym is replaced by rep. Each changed equation is
// prettified and replaces the node at the same index (via addEquation).
// If prettify reduced it to a single-variable solution, that solution is
// itself propagated recursively, skipping the node it came from.
// NOTE(review): the "Warning: ... did not get solution" message appears to
// belong to an "else" of the structureEquals check; that "else" is not
// visible in this copy.
// (private)
replaceAll[solution, skipIndex] :=
size = length[equationNodes]
sym = child[solution,0]
rep = child[solution,1] // Right-hand-side of solution
// Substitute result into other equations.
for k = 0 to size-1
if k != skipIndex
orig = equationNodes@k.getOriginalEquation[]
subst = substituteExpression[orig, sym, rep]
// println["orig is $orig, sym is $sym, solution is $solution, rep is $rep, subst is $subst"]
if orig != subst // and length[getSymbols[subst]] <= length[getSymbols[orig]]
[subst, eqSolved] = prettify[subst]
subst2 = transformExpression[subst] // THINK ABOUT: Do this?
if structureEquals[_a === _b, subst2] // and structureEquals[child[subst2,0],sym] and ! expressionContains[child[subst2,1], sym]
subst = subst2
println["Warning: In replaceAll, did not get solution. Input was $solution, output was $subst2"]
// println["Substituted $sym to " + rep + " in $orig, result is $subst"]
addEquation[subst, k] // Replace equation.
if eqSolved
// println["Going to recursively replace $sym"]
replaceAll[subst, k] // Recursively replace others
// Collects every unknown (ignored symbols excluded) appearing in any
// equation of the system, as a set of names.
getAllUnknowns[] :=
{
   unknownSet = new set
   for eqNode = equationNodes
      for name = eqNode.getUnknowns[]
         unknownSet.put[name]
   return unknownSet
}
// Solves for all variables in the system.
// Returns a flat array containing every solution (equations of the form
// var === expression) found by solveFor for each unknown reported by
// getAllUnknowns[].
solveAll[] :=
{
   results = new array
   for u = getAllUnknowns[]
      for eq = solveFor[u]
         results.push[eq]
   return results
}
// Solves the system for the specified variable name.
// Initializes (simplifies) the system first if necessary. Returns an array
// of solutions of the form varName === expression, gathered from every
// equation containing varName. Results are cached per variable name until
// the system changes.
// (public)
solveFor[varName] :=
{
   if ! initialized
      initialize[]

   // Test for the key explicitly: a missing entry (undef) or an empty cached
   // array must not be used as a boolean condition.
   if finalSolutions.containsKey[varName]
      return finalSolutions@varName

   results = new array
   size = getEquationCount[]
   for i=0 to size-1
      if getUnknowns[i].contains[varName]
      {
         partialResults = solveNodeForVariable[i, varName]
         for r = partialResults
            results.push[r]
      }

   // Cache results.
   finalSolutions@varName = results
   return results
}
// Solve for the specified variable name, substituting the list of
// arguments. Args is an array of ["varname", value] pairs.
// The answer will be returned symbolically as an array of equations of the
// form
//    varName === solution
// with constants and units still intact. Overconstrained solutions are
// removed by eliminateOverconstrained.
solveForSymbolic[varName, args] :=
{
   results = new array
   for sol = solveFor[varName]
   {
      substituted = sol
      for [arg, val] = args
      {
         sym = constructExpression["Symbol", arg]
         substituted = substituteExpression[substituted, sym, val]
      }
      // THINK ABOUT: Transform expression here to simplify?
      // substituted = transformExpression[substituted]
      results.push[substituted]
   }

   return eliminateOverconstrained[results, false, false]
}
// Solve for the specified variable name, substituting the list of
// arguments (an array of ["varname", value] pairs). The result is a list
// of evaluated solutions. Values that are conformal to and equal to an
// already-collected value are returned only once.
solveFor[varName, args] :=
{
   results = new array
   for sol = solveForSymbolic[varName, args]
   {
      evaluated = eval[child[sol,1]]
      exists = false
      for r = results
         if (evaluated conforms r) and (evaluated == r)
            exists = true

      if ! exists
         results.push[evaluated]
   }

   return results
}
// Entry point for the recursive search for solutions of varName, starting
// from the node at the given index. The recursive overload enumerates all
// permutations of substitutions through the graph; this wrapper only seeds
// it with the node's own solutions, an empty set of used edges and a cache
// (a fresh one unless cachedSolutions is supplied). The combined results
// are transformed and then have duplicates removed.
// (Private method)
solveNodeForVariable[index, varName, cachedSolutions = undef] :=
{
   cache = cachedSolutions == undef ? new dict : cachedSolutions
   startNode = getEquationNode[index]
   seeds = startNode.getSolutions[varName]
   // println["Solutions for $varName are $seeds"]
   found = solveNodeForVariable[startNode, varName, seeds, new set, cache]
   return eliminateOverconstrained[transformExpression[found], true, false]
}
// The recursive (private) call to solve for the particular variable.
//   node            - the EquationNode being expanded.
//   varName         - name (string) of the variable being solved for.
//   inEqs           - this node's solutions for varName (var === expr).
//   usedEdges       - set of Edges already followed on this path; they are
//                     excluded from further expansion.
//   cachedSolutions - dict: node -> (dict: varName -> (dict: usedEdges ->
//                     results)), read here and filled by putCache.
// Every non-empty subset of the node's remaining edges is enumerated with a
// binary counter (states). For each selected edge, the neighbor is solved
// recursively for that edge's variable, and those solutions are substituted
// into inEqs. If varName then appears on the right side, the equation is
// re-solved with solveSingle.
// NOTE(review): several statements appear to be missing from this copy:
// the body of the "for e = edges" filter, the body of the first
// "for j ... if states@j" loop (presumably adding edges to newUsedEdges),
// the remaining arguments of the recursive solveNodeForVariable call, and
// the statements that push new solutions onto results. Confirm against the
// original before changing this method.
solveNodeForVariable[node, varName, inEqs, usedEdges, cachedSolutions] :=
// print["Solving for $varName in " + getNodeIndex[node] + ", {"]
// for e = usedEdges
// print[e.getVariableName[] + " "]
// println["}"]
// Return partial solution from cache if possible.
if cachedSolutions.containsKey[node]
varDict = cachedSolutions@node
if varDict.containsKey[varName]
edgeDict = varDict@varName
if edgeDict.containsKey[usedEdges]
return edgeDict@usedEdges
results = inEqs.shallowCopy[]
edges = setDifference[node.getEdges[], usedEdges]
for e = edges
if e.getVariableName[] == varName
len = length[edges]
if (len == 0) // No more replacements to do.
putCache[cachedSolutions, node, varName, usedEdges, results]
return results
// Set up states array to enumerate through permutations.
// states@j == true means edge j is substituted in the current permutation.
states = new array
for i=0 to len-2
states@i = false
states@(len-1) = true // Skip all-false state (no replacements)
i = len-1
edgeArray = array[edges]
//newUsedEdges = union[node.getEdges[], usedEdges]
while i >= 0
newUsedEdges = usedEdges.shallowCopy[]
for j = 0 to len-1
if states@j
// Perform replacements on each edge
for j = 0 to len-1
edge = edgeArray@j
// newUsedEdges.put[edge]
// Mark this edge as used.
replacingVar = edge.getVariableName[]
if states@j
replacingSymbol = edge.getSymbol[]
otherNode = edge.getOtherNode[node]
// newGlobalUsedNodes.put[newUsedNodesHere]
// Recursively solve the other node for the variable
// represented by this edge.
repList = solveNodeForVariable[otherNode,
cachedSolutions ]
// println["repList is $repList"]
for repWithFull = repList
repWith = child[repWithFull, 1] // Get right-hand-side
for eq = inEqs
res = substituteExpression[eq, replacingSymbol, repWith]
// println["Replacing $replacingVar with $repWith in $eq, result is $res"]
// Check to see if the variable we're solving for occurs on the right
rightSyms = getSymbols[child[res,1]]
if rightSyms.contains[varName]
//println["WARNING: Right side contains $varName in $res"]
res2 = solveSingle[res, varName]
//println["Re-solving: $res2"]
// TODO: This may return a whole lot of solutions.
// We need to evaluate each one and push them all
// onto the solutions list.
varSymbol = constructExpression["Symbol", varName]
// Accept only results of the form varName === expr with varName absent
// from expr.
for subR = array[res2]
if structureEquals[_a === _b, subR] and structureEquals[child[subR,0],varSymbol] and ! expressionContains[child[subR,1], varSymbol]
//println["Re-solving successful."]
} else
// println["WARNING: Right side contains $varName in $res"]
println["Re-solving FAILED: $res2"]
// println["Re-solving FAILED."]
} else
// res = transformExpression[res]
// Advance to next binary state
flipped = false
i = len-1
while i>=0 and !flipped
// Enter next state
if states@i == false
states@i = true
flipped = true
} else
{ // Carry
states@i = false
i = i - 1
// i now contains the last index flipped. If i < 0, we're done
results = eliminateOverconstrained[results, true, false]
putCache[cachedSolutions, node, varName, usedEdges, results]
return results
// This function eliminates overconstrained equations. For example, a
// system containing the solutions a===1/2 c r and a===c d^-1 r^2 is
// overconstrained because a value can always be obtained with the first
// equation. The second is not necessary, and could lead to
// inconsistent results. When comparing right-hand sides, this method
// ignores any symbols in ignoreSet (these are probably units).
//   eqArray  - array of solutions of the form var === expr.
//   dupsOnly - if true, only exact duplicates are removed: for equations
//              with the same left side that are structurally equal, only
//              the last copy is kept. If false, an equation is also removed
//              when another equation with the same left side has right-side
//              unknowns that are a proper subset of its own.
//   debug    - if true, print the reason for each removal.
// Returns a new array; eqArray itself is not modified.
// NOTE(review): the "do" that opens the loop ending in "} while ..." and
// the increment of j are not visible in this copy; as written, j never
// advances. Confirm against the original.
eliminateOverconstrained[eqArray, dupsOnly, debug=false] :=
size = length[eqArray]
unknowns = new array
lefts = new array
for i = 0 to size-1
lefts@i = child[eqArray@i, 0]
unknowns@i = setDifference[getSymbols[child[eqArray@i,1]], ignoreSet]
res = new array
// Check for duplicates.
for i=0 to size-1
remove = false
j = 0
if i != j and structureEquals[lefts@i, lefts@j]
remove = (i<j and structureEquals[eqArray@i, eqArray@j]) or ((! dupsOnly) and isProperSubset[unknowns@j, unknowns@i])
if remove
if debug
println[eqArray@j + " is a proper subset or match of " + eqArray@i]
} while (j < size) and ! remove
if (! remove)
res.push[eqArray@i] // If we got here, no j is a proper subset of i.
return res
// Stores vals in the three-level cache
//    cachedSolutions@node@varName@usedEdges
// creating the intermediate dictionaries on demand. Declared as a
// class-level method because it only touches its arguments.
class putCache[cachedSolutions, node, varName, usedEdges, vals] :=
{
   if cachedSolutions.containsKey[node]
   {
      nodeDict = cachedSolutions@node
   } else
   {
      nodeDict = new dict
      cachedSolutions@node = nodeDict
   }

   if nodeDict.containsKey[varName]
   {
      varDict = nodeDict@varName
   } else
   {
      varDict = new dict
      nodeDict@varName = varDict
   }

   varDict@usedEdges = vals
}
// This is an experimental function that uses Frink's multi-input capability
// to interactively plug numbers into a solution.
// It sorts all unknowns, solves for each of them, and then prompts for
// values with input[] using the prompt list opts.
// NOTE(review): the body of "if length[sols] != 1" (presumably adding u to
// opts) and any use of vals are not visible in this copy. As written, opts
// is always empty and the entered values are discarded.
interact[] :=
allUnknowns = sort[array[getAllUnknowns[]]]
opts = new array
for u = allUnknowns
sols = solveFor[u]
if length[sols] != 1
vals = input["Enter values: ", opts]
// This is a node in the graph that represents an equation. It stores various
// information about the variables stored in the equation and its solutions
// for those variables. Users will not create these directly, but rather
// call methods on the Solver class to create these nodes and connect them
// properly.
class EquationNode
{
   // The original equation (of the form lhs === rhs).
   var origEq

   // A set of unknowns (variable-name strings) in the equation, with the
   // solver's ignored symbols already removed.
   var unknowns

   // A set of Edges that connect this node to other nodes.
   var edges

   // Solution cache: a dictionary whose key is the variable name (as a
   // string) and whose value is a SolutionPart. A variable with no entry
   // has not been solved for yet.
   var solvedDict

   // Create a new equation node. The equation should contain a ===
   // expression. The set reducedUnknowns is the unknowns with the
   // ignoreSet removed.
   new[eq, reducedUnknowns is set] :=
   {
      origEq = eq
      unknowns = reducedUnknowns
      edges = new set
      solvedDict = new dict
   }

   // Add a new edge that connects to another node.
   addEdge[e is Edge] :=
      edges.put[e]

   // Returns the primary equation for this node.
   getOriginalEquation[] := origEq

   // Return the set of unknowns in this equation.
   getUnknowns[] := unknowns

   // Disconnect any edges that connect this node to other nodes. Each
   // neighbor drops its edges to this node, then this node forgets all of
   // its own edges.
   disconnectAllEdges[] :=
   {
      for e = edges
      {
         other = e.getOtherNode[this]
         if other != undef
            other.removeEdgesTo[this]
      }
      edges = new set
   }

   // Remove any edges that connect to the specified EquationNode.
   // This is NOT recursive and should only be called from
   // disconnectAllEdges.
   // (private method)
   removeEdgesTo[node is EquationNode] :=
   {
      // Collect first, then remove, so the set is never modified while it
      // is being iterated.
      doomed = new array
      for e = edges
         if e.connectsTo[node]
            doomed.push[e]

      for e = doomed
         edges.remove[e]
   }

   // Returns a set of all Edge objects.
   getEdges[] := edges

   // Gets the solutions to this equation for the specified variable, as an
   // array of equations of the form var === expression (possibly empty).
   // The result comes from the cache if present; otherwise the equation is
   // solved and the result cached.
   getSolutions[varName] :=
   {
      if ! solvedDict.containsKey[varName]
         addSolutions[varName, solveSingle[origEq, varName]]

      part = solvedDict@varName
      return part.solutions
   }

   // Add one or more solutions for the specified variable to a node.
   // The equation may be a single equation or a list of equations.
   // Only equations of the form var === solution, where solution does not
   // contain the variable, are stored. Anything else is reported and
   // skipped. The variable always gets a SolutionPart entry, even if no
   // solution is accepted.
   addSolutions[varName, equation] :=
   {
      if solvedDict.containsKey[varName]
      {
         sp = solvedDict@varName
      } else
      {
         sp = new SolutionPart[varName]
         solvedDict@varName = sp
      }

      sym = sp.symbol
      for eq = flatten[array[equation]]
      {
         if structureEquals[_a === _b, eq] and structureEquals[child[eq,0],sym] and ! expressionContains[child[eq,1], sym]
         {
            sp.addSolution[eq]
         } else
         {
            println["Could not solve $origEq for $varName! Solution was $eq"]
         }
      }
   }
}
// This represents an edge between two EquationNodes in a graph. It
// contains a variable name which defines the variable that connects the
// two nodes. Users will not create these directly, but rather
// call methods on the Solver class to create these nodes and connect them
// properly.
class Edge
{
   // Name (string) of the variable shared by the two connected equations.
   var varName

   // The same variable as a Symbol expression, built once in the
   // constructor so callers can reuse it for substitutions.
   var varSymbol

   // The two EquationNodes joined by this edge. Every method treats them
   // symmetrically.
   var node1
   var node2

   // Build an Edge joining n1 and n2 through the variable called name.
   new[n1 is EquationNode, n2 is EquationNode, name is string] :=
   {
      varName = name
      varSymbol = constructExpression["Symbol", name]
      node1 = n1
      node2 = n2
   }

   // Name of the connecting variable, as a string.
   getVariableName[] :=
   {
      return varName
   }

   // The connecting variable as a Symbol expression.
   getSymbol[] :=
   {
      return varSymbol
   }

   // Given one endpoint, return the opposite endpoint. If oneNode is not an
   // endpoint of this edge, print a warning and return undef.
   getOtherNode[oneNode is EquationNode] :=
   {
      if oneNode == node1
         return node2

      if oneNode == node2
         return node1

      println["getOtherNode: Matching node not found!"]
      return undef
   }

   // True when node is either endpoint of this edge.
   connectsTo[node is EquationNode] :=
      (node == node2) or (node == node1)
}
// This is a helper class that contains the solutions for a particular
// variable.
class SolutionPart
{
   var symbol    // The variable stored as a symbol.
   var solutions // An array of solutions in the form
                 //    a === b + c

   // Construct a new SolutionPart given the string that represents the
   // variable.
   new[symbolString] :=
   {
      symbol = constructExpression["Symbol", symbolString]
      solutions = new array
   }

   // Appends a solution to the list. This does not do any checking; that
   // is performed by EquationNode.addSolutions before calling this.
   addSolution[eq] :=
      solutions.push[eq]
}
// Solves a single equation eq for the variable whose name is the string
// symbol. A solve[eq, symbol] call expression is built (so that no
// symbolic-mode warning is issued) and then rewritten by transformExpression
// with whatever transformation rules are loaded.
// TODO: Verify that the equation was solved appropriately?
solveSingle[eq, symbol] :=
{
   target = constructExpression["Symbol", symbol]
   // solveEq = solve[eq, target] would warn outside symbolic mode.
   call = constructExpression["FunctionCall", ["solve", eq, target]]
   return transformExpression[call]
}
// Returns a copy of eq in which every occurrence of the variable named
// varName (a string) is replaced by value. eq is not evaluated here; you
// should probably wrap literal "eq" and "value" arguments in noEval[] if
// you're not passing them in from a variable.
// THINK ABOUT: Transform expression here to simplify?
replaceVar[eq, varName, value] :=
   substituteExpression[eq, constructExpression["Symbol", varName], value]
