AST equals and hashCode

This commit is contained in:
2023-08-21 00:58:05 -07:00
parent d92dc8d904
commit 1de73ed855
14 changed files with 163 additions and 3 deletions

View File

@ -42,8 +42,8 @@ interface NodeVisitor<T> {
else -> throw RuntimeException("Unknown Node") else -> throw RuntimeException("Unknown Node")
} }
fun visitNodes(vararg nodes: Node): List<T> = fun visitNodes(vararg nodes: Node?): List<T> =
nodes.map { visit(it) } nodes.filterNotNull().map { visit(it) }
fun visitAll(vararg nodeLists: List<Node>): List<T> = fun visitAll(vararg nodeLists: List<Node>): List<T> =
nodeLists.asSequence().flatten().map { visit(it) }.toList() nodeLists.asSequence().flatten().map { visit(it) }.toList()

View File

@ -4,4 +4,15 @@ import gay.pizza.pork.ast.NodeType
class BooleanLiteral(val value: Boolean) : Expression() { class BooleanLiteral(val value: Boolean) : Expression() {
override val type: NodeType = NodeType.BooleanLiteral override val type: NodeType = NodeType.BooleanLiteral
override fun equals(other: Any?): Boolean {
if (other !is BooleanLiteral) return false
return other.value == value
}
override fun hashCode(): Int {
var result = value.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,16 @@ class Define(val symbol: Symbol, val value: Expression) : Expression() {
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitNodes(symbol, value) visitor.visitNodes(symbol, value)
override fun equals(other: Any?): Boolean {
if (other !is Define) return false
return other.symbol == symbol && other.value == value
}
override fun hashCode(): Int {
var result = symbol.hashCode()
result = 31 * result + value.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,9 @@ class FunctionCall(val symbol: Symbol, val arguments: List<Expression>) : Expres
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitAll(listOf(symbol), arguments) visitor.visitAll(listOf(symbol), arguments)
override fun equals(other: Any?): Boolean {
if (other !is FunctionCall) return false
return other.symbol == symbol && other.arguments == arguments
}
} }

View File

@ -1,6 +1,7 @@
package gay.pizza.pork.ast.nodes package gay.pizza.pork.ast.nodes
import gay.pizza.pork.ast.NodeType import gay.pizza.pork.ast.NodeType
import gay.pizza.pork.ast.NodeVisitor
class If( class If(
val condition: Expression, val condition: Expression,
@ -8,4 +9,22 @@ class If(
val elseExpression: Expression? = null val elseExpression: Expression? = null
) : Expression() { ) : Expression() {
override val type: NodeType = NodeType.If override val type: NodeType = NodeType.If
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitNodes(condition, thenExpression, elseExpression)
override fun equals(other: Any?): Boolean {
if (other !is If) return false
return other.condition == condition &&
other.thenExpression == thenExpression &&
other.elseExpression == elseExpression
}
override fun hashCode(): Int {
var result = condition.hashCode()
result = 31 * result + thenExpression.hashCode()
result = 31 * result + (elseExpression?.hashCode() ?: 0)
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -3,9 +3,28 @@ package gay.pizza.pork.ast.nodes
import gay.pizza.pork.ast.NodeType import gay.pizza.pork.ast.NodeType
import gay.pizza.pork.ast.NodeVisitor import gay.pizza.pork.ast.NodeVisitor
class InfixOperation(val left: Expression, val op: InfixOperator, val right: Expression) : Expression() { class InfixOperation(
val left: Expression,
val op: InfixOperator,
val right: Expression
) : Expression() {
override val type: NodeType = NodeType.InfixOperation override val type: NodeType = NodeType.InfixOperation
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitNodes(left, right) visitor.visitNodes(left, right)
override fun equals(other: Any?): Boolean {
if (other !is InfixOperation) return false
return other.op == op &&
other.left == left &&
other.right == right
}
override fun hashCode(): Int {
var result = left.hashCode()
result = 31 * result + op.hashCode()
result = 31 * result + right.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -4,4 +4,15 @@ import gay.pizza.pork.ast.NodeType
class IntLiteral(val value: Int) : Expression() { class IntLiteral(val value: Int) : Expression() {
override val type: NodeType = NodeType.IntLiteral override val type: NodeType = NodeType.IntLiteral
override fun equals(other: Any?): Boolean {
if (other !is IntLiteral) return false
return other.value == value
}
override fun hashCode(): Int {
var result = value
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,16 @@ class Lambda(val arguments: List<Symbol>, val expressions: List<Expression>) : E
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitAll(arguments, expressions) visitor.visitAll(arguments, expressions)
override fun equals(other: Any?): Boolean {
if (other !is Lambda) return false
return other.arguments == arguments && other.expressions == expressions
}
override fun hashCode(): Int {
var result = arguments.hashCode()
result = 31 * result + expressions.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,15 @@ class ListLiteral(val items: List<Expression>) : Expression() {
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitAll(items) visitor.visitAll(items)
override fun equals(other: Any?): Boolean {
if (other !is ListLiteral) return false
return other.items == items
}
override fun hashCode(): Int {
var result = items.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,15 @@ class Parentheses(val expression: Expression) : Expression() {
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitNodes(expression) visitor.visitNodes(expression)
override fun equals(other: Any?): Boolean {
if (other !is Parentheses) return false
return other.expression == expression
}
override fun hashCode(): Int {
var result = expression.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -1,7 +1,23 @@
package gay.pizza.pork.ast.nodes package gay.pizza.pork.ast.nodes
import gay.pizza.pork.ast.NodeType import gay.pizza.pork.ast.NodeType
import gay.pizza.pork.ast.NodeVisitor
class PrefixOperation(val op: PrefixOperator, val expression: Expression) : Expression() { class PrefixOperation(val op: PrefixOperator, val expression: Expression) : Expression() {
override val type: NodeType = NodeType.PrefixOperation override val type: NodeType = NodeType.PrefixOperation
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitNodes(expression)
override fun equals(other: Any?): Boolean {
if (other !is PrefixOperation) return false
return other.op == op && other.expression == expression
}
override fun hashCode(): Int {
var result = op.hashCode()
result = 31 * result + expression.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,15 @@ class Program(val expressions: List<Expression>) : Node() {
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitAll(expressions) visitor.visitAll(expressions)
override fun equals(other: Any?): Boolean {
if (other !is Program) return false
return other.expressions == expressions
}
override fun hashCode(): Int {
var result = expressions.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -4,4 +4,15 @@ import gay.pizza.pork.ast.NodeType
class Symbol(val id: String) : Node() { class Symbol(val id: String) : Node() {
override val type: NodeType = NodeType.Symbol override val type: NodeType = NodeType.Symbol
override fun equals(other: Any?): Boolean {
if (other !is Symbol) return false
return other.id == id
}
override fun hashCode(): Int {
var result = id.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }

View File

@ -8,4 +8,15 @@ class SymbolReference(val symbol: Symbol) : Expression() {
override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> = override fun <T> visitChildren(visitor: NodeVisitor<T>): List<T> =
visitor.visitNodes(symbol) visitor.visitNodes(symbol)
override fun equals(other: Any?): Boolean {
if (other !is SymbolReference) return false
return other.symbol == symbol
}
override fun hashCode(): Int {
var result = symbol.hashCode()
result = 31 * result + type.hashCode()
return result
}
} }