package zig_test import ( "fmt" "strings" "testing" "git.frop.prof/luke/go-zig-compiler/internal/zig" ) func Expect(expected, actual string) error { if expected != actual { message := fmt.Sprintf("\nExpected:\n%v\nActual:\n%v", expected, actual) message += fmt.Sprintf("\n%q", actual) // Find the first difference (handle rune/byte mismatch) minLen := len(expected) if len(actual) < minLen { minLen = len(actual) } for i := 0; i < minLen; i++ { if expected[i] != actual[i] { message += fmt.Sprintf("\nDifference at index %d: expected '%c', got '%c'", i, expected[i], actual[i]) break } } if len(expected) != len(actual) { message += fmt.Sprintf("\nLength mismatch: expected %d bytes, got %d bytes", len(expected), len(actual)) } return fmt.Errorf("%s", message) } return nil } // runZigASTTest is a helper to check if the Zig AST renders as expected. func runZigASTTest(t *testing.T, expected string, root *zig.Root) { t.Helper() sb := new(strings.Builder) err := zig.Write(sb, root) if err != nil { t.FailNow() } actual := sb.String() if err = Expect(expected, actual); err != nil { t.Error(err) } } func TestHelloWorld(t *testing.T) { expected := `//! Hello, world! const std = @import("std"); fn main() void { std.debug.print("Hello, world!\n", .{}); } ` root := &zig.Root{ ContainerDocComment: "Hello, world!", ContainerMembers: []*zig.ContainerMember{ { Decl: zig.DeclareGlobalVar("std", zig.Call( zig.Id("@import"), zig.StringLit("std"), ), zig.GlobalVarConst), }, { Decl: zig.DeclareFn( "main", zig.Id("void"), zig.NewBlock( zig.NewExprStmt( zig.Call( zig.FieldAccess( zig.FieldAccess(zig.Id("std"), "debug"), "print", ), zig.StringLit(`Hello, world!\n`), zig.InitList(), ), ), ), nil, // params 0, // flags ), }, }, } runZigASTTest(t, expected, root) } func TestCompoundFeatures(t *testing.T) { expected := `const ProcessError = error{ InvalidInput, OutOfBounds, }; fn processData(values: [5]?i32) ProcessError!i32 { var result: i32 = 0; defer std.debug.print("Cleanup\n", .{}); for (values, 0..) |opt_val, idx| { if (opt_val) |val| { switch (val) { -10...-1 => result += -val, 0 => continue, 1...10 => { if (idx > 3) break; result += val; }, else => return ProcessError.InvalidInput, } } else { if (idx == 0) return ProcessError.InvalidInput; break; } } if (values[0] == null) { unreachable; } return result; } test "processData" { const data = .{ 5, -3, 0, null, 10 }; const result = try processData(data); try std.testing.expectEqual(@as(i32, 8), result); } ` root := &zig.Root{ ContainerMembers: []*zig.ContainerMember{ { Decl: zig.DeclareGlobalVar("ProcessError", zig.ErrorSet("InvalidInput", "OutOfBounds"), zig.GlobalVarConst, ), }, { Decl: zig.DeclareFn( "processData", zig.ErrorUnionType(zig.Id("ProcessError"), zig.Id("i32")), zig.NewBlock( // var result: i32 = 0; zig.DeclareVarStmt(false, []string{"result"}, zig.Id("i32"), zig.IntLit("0")), // defer std.debug.print("Cleanup\n", .{}); zig.Defer( zig.NewExprStmt( zig.Call( zig.FieldAccess( zig.FieldAccess(zig.Id("std"), "debug"), "print", ), zig.StringLit(`Cleanup\n`), zig.InitList(), ), ), ), // for (values, 0..) |opt_val, idx| { ... } zig.ForLoop( []zig.ForArg{ zig.ForArgExpr(zig.Id("values")), {Expr: zig.IntLit("0"), From: zig.IntLit("")}, }, zig.PayloadNames([]string{"opt_val", "idx"}, []bool{false, false}), zig.NewBlockStmt( // if (opt_val) |val| { ... } else { ... } zig.IfWithPayload( zig.Id("opt_val"), zig.PayloadNames([]string{"val"}, []bool{false}), zig.NewBlockStmt( // switch (val) { ... } zig.Switch( zig.Id("val"), // -10...-1 => result += -val, zig.Prong( []*zig.SwitchCase{ zig.Case( zig.Unary("-", zig.IntLit("10")), zig.Unary("-", zig.IntLit("1")), ), }, nil, zig.NewExprStmt( zig.Binary("+=", zig.Id("result"), zig.Unary("-", zig.Id("val")), ), ), ), // 0 => continue, zig.Prong( []*zig.SwitchCase{ zig.Case(zig.IntLit("0"), nil), }, nil, zig.Continue(""), ), // 1...10 => { ... } zig.Prong( []*zig.SwitchCase{ zig.Case(zig.IntLit("1"), zig.IntLit("10")), }, nil, zig.NewBlockStmt( zig.If( zig.Binary(">", zig.Id("idx"), zig.IntLit("3")), zig.Break("", nil), nil, ), zig.NewExprStmt( zig.Binary("+=", zig.Id("result"), zig.Id("val")), ), ), ), // else => return ProcessError.InvalidInput, zig.Prong( []*zig.SwitchCase{zig.ElseCase()}, nil, zig.Return( zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput"), ), ), ), ), // else block zig.NewBlockStmt( zig.If( zig.Binary("==", zig.Id("idx"), zig.IntLit("0")), zig.Return(zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput")), nil, ), zig.Break("", nil), ), ), ), nil, ), // if (values[0] == null) { unreachable; } zig.If( zig.Binary("==", zig.Index(zig.Id("values"), zig.IntLit("0")), zig.Id("null"), ), zig.NewBlockStmt( zig.NewExprStmt(zig.Unreachable()), ), nil, ), // return result; zig.Return(zig.Id("result")), ), []*zig.ParamDecl{ zig.Param("values", zig.ArrayType(zig.IntLit("5"), zig.OptionalType(zig.Id("i32")))), }, 0, ), }, { Decl: zig.Test("processData", zig.NewBlock( // const data = [_]?i32{ 5, -3, 0, null, 10 }; zig.DeclareVarStmt(true, []string{"data"}, nil, &zig.InitListExpr{ Values: []zig.Expr{ zig.IntLit("5"), zig.Unary("-", zig.IntLit("3")), zig.IntLit("0"), zig.Id("null"), zig.IntLit("10"), }, }, ), // const result = try processData(data); zig.DeclareVarStmt(true, []string{"result"}, nil, zig.Try(zig.Call(zig.Id("processData"), zig.Id("data"))), ), // try std.testing.expectEqual(@as(i32, 8), result); zig.NewExprStmt( zig.Try( zig.Call( zig.FieldAccess( zig.FieldAccess(zig.Id("std"), "testing"), "expectEqual", ), zig.Call( zig.Id("@as"), zig.Id("i32"), zig.IntLit("8"), ), zig.Id("result"), ), ), ), ), ), }, }, } runZigASTTest(t, expected, root) } func TestEvenOdd(t *testing.T) { expected := `//! Abc const std = @import("std"); pub fn main() !void { const stdout = std.io.getStdOut().writer(); var i: i32 = 1; while (i <= 5) : (i += 1) { if (i % 2 == 0) { try stdout.writeAll("even: {d}\n", .{i}); } else { try stdout.writeAll("odd: {d}\n", .{i}); } } } ` root := &zig.Root{ ContainerDocComment: "Abc", ContainerMembers: []*zig.ContainerMember{ { Decl: zig.DeclareGlobalVar("std", zig.Call( zig.Id("@import"), zig.StringLit("std"), ), zig.GlobalVarConst, ), }, { Decl: zig.DeclareFn( "main", zig.Id("!void"), zig.NewBlock( // const stdout = std.io.getStdOut().writer(); zig.DeclareVarStmt(true, []string{"stdout"}, nil, zig.Call( zig.FieldAccess( zig.Call( zig.FieldAccess( zig.FieldAccess(zig.Id("std"), "io"), "getStdOut", ), ), "writer", ), ), ), // var i: i32 = 1; zig.DeclareVarStmt(false, []string{"i"}, zig.Id("i32"), zig.IntLit("1")), // while (i <= 5) : (i += 1) { ... } zig.WhileLoop( zig.Binary("<=", zig.Id("i"), zig.IntLit("5")), zig.Binary("+=", zig.Id("i"), zig.IntLit("1")), nil, zig.NewBlockStmt( &zig.IfStmt{ Cond: zig.Binary("==", zig.Binary("%", zig.Id("i"), zig.IntLit("2")), zig.IntLit("0"), ), Then: zig.NewBlockStmt( zig.NewExprStmt( zig.Try( zig.Call( zig.FieldAccess(zig.Id("stdout"), "writeAll"), zig.StringLit(`even: {d}\n`), zig.InitList(zig.Id("i")), ), ), ), ), Else: zig.NewBlockStmt( zig.NewExprStmt( zig.Try( zig.Call( zig.FieldAccess(zig.Id("stdout"), "writeAll"), zig.StringLit(`odd: {d}\n`), zig.InitList(zig.Id("i")), ), ), ), ), }, ), nil, ), ), nil, zig.FnExport, ), }, }, } runZigASTTest(t, expected, root) } func TestStructWithFieldsAndMethod(t *testing.T) { expected := `const MyStruct = struct { const Self = @This(); data: ?*u8, count: u32, pub fn reset(self: Self) void { self.count = 0; } }; ` root := &zig.Root{ ContainerMembers: []*zig.ContainerMember{ { Decl: zig.DeclareGlobalVar("MyStruct", zig.StructDecl( zig.Field("Self", nil, nil, zig.Call(zig.Id("@This"))), zig.Field("data", zig.OptionalType(zig.PointerType(zig.Id("u8"))), nil, nil), zig.Field("count", zig.Id("u32"), nil, nil), zig.Method(zig.DeclareFn( "reset", zig.Id("void"), zig.NewBlock( zig.NewExprStmt( zig.Binary("=", zig.FieldAccess(zig.Id("self"), "count"), zig.IntLit("0"), ), ), ), []*zig.ParamDecl{ zig.Param("self", zig.Id("Self")), }, zig.FnExport, )), ), zig.GlobalVarConst, ), }, }, } runZigASTTest(t, expected, root) }