summaryrefslogtreecommitdiffstats
path: root/src/com/craftinginterpreters/tool/GenerateAst.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/com/craftinginterpreters/tool/GenerateAst.java')
-rw-r--r--src/com/craftinginterpreters/tool/GenerateAst.java88
1 files changed, 88 insertions, 0 deletions
diff --git a/src/com/craftinginterpreters/tool/GenerateAst.java b/src/com/craftinginterpreters/tool/GenerateAst.java
new file mode 100644
index 0000000..5e79d2f
--- /dev/null
+++ b/src/com/craftinginterpreters/tool/GenerateAst.java
@@ -0,0 +1,88 @@
+package com.craftinginterpreters.tool;
+
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Arrays;
+import java.util.List;
+
+public class GenerateAst {
+ public static void main(String[] args) throws IOException {
+ if (args.length != 1) {
+ System.err.println("Usage: generate_ast <output directory>");
+ System.exit(64);
+ }
+
+ String outputDir = args[0];
+
+ defineAst(outputDir, "Expr", Arrays.asList("Binary : Expr left, Token operator, Expr right",
+ "Grouping : Expr expression", "Literal : Object value", "Unary : Token operator, Expr right"));
+ }
+
+ private static void defineAst(String outputDir, String baseName, List<String> types) throws IOException {
+ String path = outputDir + "/" + baseName + ".java";
+ PrintWriter writer = new PrintWriter(path);
+
+ writer.println("package com.craftinginterpreters.lox;");
+ writer.println();
+ writer.println("import java.util.List;");
+ writer.println();
+ writer.println("abstract class " + baseName + " {");
+
+ defineVisitor(writer, baseName, types);
+
+ for (String type : types) {
+ String className = type.split(":")[0].trim();
+ String fields = type.split(":")[1].trim();
+ defineType(writer, baseName, className, fields);
+ }
+
+ // The base accept() method
+ writer.println();
+ writer.println(" abstract <R> R accept(Visitor<R> visitor);");
+
+ writer.println("}");
+ writer.close();
+ }
+
+ private static void defineVisitor(PrintWriter writer, String baseName, List<String> types) {
+ writer.println(" interface Visitor<R> {");
+
+ for (String type : types) {
+ String typeName = type.split(":")[0].trim();
+ writer.println(" R visit" + typeName + baseName + "(" + typeName + " " + baseName.toLowerCase() + ");");
+ }
+
+ writer.println(" }");
+ }
+
+ private static void defineType(PrintWriter writer, String baseName, String className, String fieldList) {
+ writer.println(" static class " + className + " extends " + baseName + " {");
+
+ // Constructor
+ writer.println(" " + className + "(" + fieldList + ") {");
+
+ // Store parameters in fields.
+ String[] fields = fieldList.split(", ");
+ for (String field : fields) {
+ String name = field.split(" ")[1];
+ writer.println(" this." + name + " = " + name + ";");
+ }
+
+ writer.println(" }");
+
+ // Visitor pattern.
+ writer.println();
+ writer.println(" @Override");
+ writer.println(" <R> R accept(Visitor<R> visitor) {");
+ writer.println(" return visitor.visit" + className + baseName + "(this);");
+ writer.println(" }");
+
+ // Fields
+ writer.println();
+ for (String field : fields) {
+ writer.println(" final " + field + ";");
+ }
+
+ writer.println(" }");
+ }
+}