From bea04d681025d64643526321c345bc51db626448 Mon Sep 17 00:00:00 2001 From: Luke Wilson Date: Thu, 5 Jun 2025 22:35:18 -0500 Subject: [PATCH] Add comprehensive test system for Go-to-Zig compiler - Implement end-to-end test runner for compilation and behavior tests - Add test cases for basic print functionality - Refactor translator to use proper AST generation - Remove redundant programs directory in favor of tests --- Makefile | 12 +- cmd/testrunner/main.go | 228 +++++++++++++++++++++++++++ hello.zig | 5 + internal/main.go | 135 ++++++++++++---- tests/basic/hello.expected | 1 + {programs => tests/basic}/hello.go | 4 +- tests/basic/multiple_prints.expected | 1 + tests/basic/multiple_prints.go | 8 + tests/basic/print_escape.expected | 5 + tests/basic/print_escape.go | 8 + 10 files changed, 370 insertions(+), 37 deletions(-) create mode 100644 cmd/testrunner/main.go create mode 100644 hello.zig create mode 100644 tests/basic/hello.expected rename {programs => tests/basic}/hello.go (50%) create mode 100644 tests/basic/multiple_prints.expected create mode 100644 tests/basic/multiple_prints.go create mode 100644 tests/basic/print_escape.expected create mode 100644 tests/basic/print_escape.go diff --git a/Makefile b/Makefile index efb04a2..885bb8a 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,15 @@ run: go run internal/main.go -o hello.zig programs/hello.go && zig run hello.zig -test: +test: test-unit test-integration + +test-unit: go test ./internal/zig + +test-integration: + go run ./cmd/testrunner + +test-quick: + go run ./cmd/testrunner + +.PHONY: run test test-unit test-integration test-quick diff --git a/cmd/testrunner/main.go b/cmd/testrunner/main.go new file mode 100644 index 0000000..21ffea8 --- /dev/null +++ b/cmd/testrunner/main.go @@ -0,0 +1,228 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" +) + +type TestResult struct { + Name string + Passed bool + Output string + Error string + Elapsed time.Duration +} + +type TestCase struct { + GoFile string + ExpectedFile string + ErrorFile string + StdinFile string + ArgsFile string +} + +func main() { + // Find all test cases + testCases, err := findTestCases("tests") + if err != nil { + fmt.Fprintf(os.Stderr, "Error finding tests: %v\n", err) + os.Exit(1) + } + + if len(testCases) == 0 { + fmt.Println("No tests found") + os.Exit(0) + } + + fmt.Printf("Running %d tests...\n\n", len(testCases)) + + // Run tests in parallel + results := runTests(testCases) + + // Print results + passed := 0 + failed := 0 + for _, result := range results { + if result.Passed { + fmt.Printf("✓ %s (%v)\n", result.Name, result.Elapsed) + passed++ + } else { + fmt.Printf("✗ %s (%v)\n", result.Name, result.Elapsed) + fmt.Printf(" Error: %s\n", result.Error) + if result.Output != "" { + fmt.Printf(" Output:\n%s\n", indent(result.Output, " ")) + } + failed++ + } + } + + fmt.Printf("\n%d passed, %d failed\n", passed, failed) + + if failed > 0 { + os.Exit(1) + } +} + +func findTestCases(dir string) ([]TestCase, error) { + var testCases []TestCase + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if strings.HasSuffix(path, ".go") { + base := strings.TrimSuffix(path, ".go") + testCase := TestCase{ + GoFile: path, + ExpectedFile: base + ".expected", + ErrorFile: base + ".error", + StdinFile: base + ".stdin", + ArgsFile: base + ".args", + } + + // Check if either .expected or .error exists + if fileExists(testCase.ExpectedFile) || fileExists(testCase.ErrorFile) { + testCases = append(testCases, testCase) + } + } + + return nil + }) + + return testCases, err +} + +func runTests(testCases []TestCase) []TestResult { + results := make([]TestResult, len(testCases)) + var wg sync.WaitGroup + + for i, tc := range testCases { + wg.Add(1) + go func(idx int, testCase TestCase) { + defer wg.Done() + results[idx] = runTest(testCase) + }(i, tc) + } + + wg.Wait() + return results +} + +func runTest(tc TestCase) TestResult { + start := time.Now() + testName := filepath.Base(strings.TrimSuffix(tc.GoFile, ".go")) + + result := TestResult{ + Name: testName, + Elapsed: time.Since(start), + } + + // Create temporary directory for test outputs + tempDir, err := os.MkdirTemp("", "go-zig-test-*") + if err != nil { + result.Error = fmt.Sprintf("Failed to create temp dir: %v", err) + return result + } + defer os.RemoveAll(tempDir) + + zigFile := filepath.Join(tempDir, testName+".zig") + + // Compile Go to Zig + compileCmd := exec.Command("go", "run", "./internal/main.go", "-o", zigFile, tc.GoFile) + var compileOut bytes.Buffer + compileCmd.Stderr = &compileOut + compileCmd.Stdout = &compileOut + + if err := compileCmd.Run(); err != nil { + // Check if this is an expected compilation error + if fileExists(tc.ErrorFile) { + expectedError, _ := os.ReadFile(tc.ErrorFile) + actualError := strings.TrimSpace(compileOut.String()) + expectedErrorStr := strings.TrimSpace(string(expectedError)) + + if actualError == expectedErrorStr { + result.Passed = true + } else { + result.Error = fmt.Sprintf("Expected error:\n%s\n\nActual error:\n%s", expectedErrorStr, actualError) + } + } else { + result.Error = fmt.Sprintf("Compilation failed: %v\nOutput: %s", err, compileOut.String()) + } + result.Elapsed = time.Since(start) + return result + } + + // If we expected a compilation error but didn't get one + if fileExists(tc.ErrorFile) { + result.Error = "Expected compilation to fail, but it succeeded" + result.Elapsed = time.Since(start) + return result + } + + // Run the Zig program + runCmd := exec.Command("zig", "run", zigFile) + + // Set up stdin if provided + if fileExists(tc.StdinFile) { + stdinData, _ := os.ReadFile(tc.StdinFile) + runCmd.Stdin = bytes.NewReader(stdinData) + } + + // Set up args if provided + if fileExists(tc.ArgsFile) { + argsData, _ := os.ReadFile(tc.ArgsFile) + args := strings.Fields(string(argsData)) + runCmd.Args = append(runCmd.Args, args...) + } + + // Capture both stdout and stderr (Zig's debug.print goes to stderr) + output, err := runCmd.CombinedOutput() + if err != nil { + result.Error = fmt.Sprintf("Execution failed: %v\nOutput: %s", err, string(output)) + result.Output = string(output) + result.Elapsed = time.Since(start) + return result + } + + // Compare output with expected + if fileExists(tc.ExpectedFile) { + expectedOutput, _ := os.ReadFile(tc.ExpectedFile) + actualOutput := string(output) + + if actualOutput == string(expectedOutput) { + result.Passed = true + } else { + result.Error = fmt.Sprintf("Output mismatch.\nExpected:\n%s\nActual:\n%s", + string(expectedOutput), actualOutput) + result.Output = actualOutput + } + } else { + // No expected file, just check it runs without error + result.Passed = true + } + + result.Elapsed = time.Since(start) + return result +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func indent(s string, prefix string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + if line != "" { + lines[i] = prefix + line + } + } + return strings.Join(lines, "\n") +} \ No newline at end of file diff --git a/hello.zig b/hello.zig new file mode 100644 index 0000000..14c62ac --- /dev/null +++ b/hello.zig @@ -0,0 +1,5 @@ +const std = @import("std"); + +pub fn main() void { + std.debug.print("Hello, world\n", .{}); +} diff --git a/internal/main.go b/internal/main.go index ff45b38..4376a5b 100644 --- a/internal/main.go +++ b/internal/main.go @@ -8,6 +8,8 @@ import ( "go/token" "os" "strings" + + "git.frop.prof/luke/go-zig-compiler/internal/zig" ) var ( @@ -36,57 +38,122 @@ func main() { panic(err) } - output, err := generate(f) + zigRoot, err := translateToZig(f) if err != nil { panic(err) } + outputFile, err := os.Create(*outputFilepath) if err != nil { panic(err) } - _, err = outputFile.WriteString(output) + defer outputFile.Close() + + err = zig.Write(outputFile, zigRoot) if err != nil { panic(err) } - fmt.Printf("%v:\n", *outputFilepath) - fmt.Println("--------------------") - fmt.Println(output) } -func generate(f *ast.File) (string, error) { - sb := new(strings.Builder) - - def := f.Decls[0].(*ast.FuncDecl) - - if def.Name.Name != "main" { - return "", fmt.Errorf("must have main") +func translateToZig(f *ast.File) (*zig.Root, error) { + // Create the root AST node + root := &zig.Root{ + ContainerMembers: []*zig.ContainerMember{}, } - sb.WriteString(`const std = @import("std");`) - sb.WriteString("\npub fn main() void {\n") + // Add the std import + root.ContainerMembers = append(root.ContainerMembers, &zig.ContainerMember{ + Decl: zig.DeclareGlobalVar("std", + zig.Call(zig.Id("@import"), zig.StringLit("std")), + zig.GlobalVarConst, + ), + }) - stmt := def.Body.List[0].(*ast.ExprStmt) - call := stmt.X.(*ast.CallExpr) - fn := call.Fun.(*ast.Ident) - - if fn.Name == "print" { - sb.WriteString(fmt.Sprintf(`std.debug.print(`)) - - args := call.Args - for _, arg := range args { - if s, ok := arg.(*ast.BasicLit); ok { - sb.WriteString(fmt.Sprintf("%s", s.Value)) - } else { - panic("WTF") + // Find and translate the main function + for _, decl := range f.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok && fn.Name.Name == "main" { + mainFunc, err := translateMainFunction(fn) + if err != nil { + return nil, err } + root.ContainerMembers = append(root.ContainerMembers, &zig.ContainerMember{ + Decl: mainFunc, + }) } - - sb.WriteString(", .{});\n") - } else { - return "", fmt.Errorf("expected printf") } - sb.WriteString("}\n") - - return sb.String(), nil + return root, nil +} + +func translateMainFunction(fn *ast.FuncDecl) (*zig.FnDecl, error) { + // Create the main function + stmts := []zig.Stmt{} + + // Translate each statement in the function body + for _, stmt := range fn.Body.List { + zigStmt, err := translateStatement(stmt) + if err != nil { + return nil, err + } + if zigStmt != nil { + stmts = append(stmts, zigStmt) + } + } + + return zig.DeclareFn( + "main", + zig.Id("void"), + zig.NewBlock(stmts...), + nil, + zig.FnExport, + ), nil +} + +func translateStatement(stmt ast.Stmt) (zig.Stmt, error) { + switch s := stmt.(type) { + case *ast.ExprStmt: + // Handle expression statements (like function calls) + expr, err := translateExpression(s.X) + if err != nil { + return nil, err + } + return zig.NewExprStmt(expr), nil + default: + return nil, fmt.Errorf("unsupported statement type: %T", stmt) + } +} + +func translateExpression(expr ast.Expr) (zig.Expr, error) { + switch e := expr.(type) { + case *ast.CallExpr: + // Handle function calls + if ident, ok := e.Fun.(*ast.Ident); ok && ident.Name == "print" { + // Translate print() to std.debug.print() + args := []zig.Expr{} + + // First argument is the format string + if len(e.Args) > 0 { + if lit, ok := e.Args[0].(*ast.BasicLit); ok && lit.Kind == token.STRING { + // Remove quotes and use the string value + args = append(args, zig.StringLit(strings.Trim(lit.Value, `"`))) + } else { + return nil, fmt.Errorf("print() requires a string literal argument") + } + } + + // Second argument is always .{} for now + args = append(args, zig.InitList()) + + return zig.Call( + zig.FieldAccess( + zig.FieldAccess(zig.Id("std"), "debug"), + "print", + ), + args..., + ), nil + } + return nil, fmt.Errorf("unsupported function call: %v", e.Fun) + default: + return nil, fmt.Errorf("unsupported expression type: %T", expr) + } } diff --git a/tests/basic/hello.expected b/tests/basic/hello.expected new file mode 100644 index 0000000..af5626b --- /dev/null +++ b/tests/basic/hello.expected @@ -0,0 +1 @@ +Hello, world! diff --git a/programs/hello.go b/tests/basic/hello.go similarity index 50% rename from programs/hello.go rename to tests/basic/hello.go index fdc7ec1..90fc7da 100644 --- a/programs/hello.go +++ b/tests/basic/hello.go @@ -1,5 +1,5 @@ package main func main() { - print("Hello, world\n") -} + print("Hello, world!\n") +} \ No newline at end of file diff --git a/tests/basic/multiple_prints.expected b/tests/basic/multiple_prints.expected new file mode 100644 index 0000000..9c43f1e --- /dev/null +++ b/tests/basic/multiple_prints.expected @@ -0,0 +1 @@ +First Second diff --git a/tests/basic/multiple_prints.go b/tests/basic/multiple_prints.go new file mode 100644 index 0000000..e8883d8 --- /dev/null +++ b/tests/basic/multiple_prints.go @@ -0,0 +1,8 @@ +package main + +func main() { + print("First") + print(" ") + print("Second") + print("\n") +} \ No newline at end of file diff --git a/tests/basic/print_escape.expected b/tests/basic/print_escape.expected new file mode 100644 index 0000000..907567d --- /dev/null +++ b/tests/basic/print_escape.expected @@ -0,0 +1,5 @@ +Hello World! +Line 1 +Line 2 +Quote: "test" +Backslash: \ diff --git a/tests/basic/print_escape.go b/tests/basic/print_escape.go new file mode 100644 index 0000000..4340053 --- /dev/null +++ b/tests/basic/print_escape.go @@ -0,0 +1,8 @@ +package main + +func main() { + print("Hello\tWorld!\n") + print("Line 1\nLine 2\n") + print("Quote: \"test\"\n") + print("Backslash: \\\n") +} \ No newline at end of file