From 258b3c8e9bb87444ddf1a45892e1e1d16cec91fa Mon Sep 17 00:00:00 2001 From: Luke Wilson Date: Thu, 5 Jun 2025 20:44:49 -0500 Subject: [PATCH] Add comprehensive Zig syntax test coverage - Add support for error sets and error union types - Implement for loops with index and payload syntax - Add defer, break, continue, and switch statements - Support unary expressions and array indexing - Add unreachable expression and test declarations - Extend AST with new type expressions (array, error union) - Update formatter to handle all new syntax elements - Fix formatting for switch prongs, payloads, and blocks --- internal/zig/ast.go | 17 +++ internal/zig/asthelpers.go | 95 +++++++++++++++++ internal/zig/zig.go | 213 +++++++++++++++++++++++++++++++++++-- internal/zig/zig_test.go | 211 ++++++++++++++++++++++++++++++++++++ 4 files changed, 529 insertions(+), 7 deletions(-) diff --git a/internal/zig/ast.go b/internal/zig/ast.go index d68cfc3..2006770 100644 --- a/internal/zig/ast.go +++ b/internal/zig/ast.go @@ -408,6 +408,23 @@ type PrefixTypeExpr struct { Base TypeExpr } +// ArrayTypeExpr represents an array type ([N]T). +type ArrayTypeExpr struct { + Size Expr + Elem TypeExpr +} + +// SliceTypeExpr represents a slice type ([]T). +type SliceTypeExpr struct { + Elem TypeExpr +} + +// ErrorUnionTypeExpr represents an error union type (E!T). +type ErrorUnionTypeExpr struct { + ErrSet TypeExpr + Type TypeExpr +} + // DocComment represents a doc comment (/// or //! lines). // Newlines in the string automatically add more comments in the output. type DocComment string diff --git a/internal/zig/asthelpers.go b/internal/zig/asthelpers.go index 2bea368..4397da1 100644 --- a/internal/zig/asthelpers.go +++ b/internal/zig/asthelpers.go @@ -73,6 +73,11 @@ func If(cond Expr, then, els Stmt) *IfStmt { return &IfStmt{Cond: cond, Then: then, Else: els} } +// IfWithPayload creates an if statement with a payload. +func IfWithPayload(cond Expr, payload *Payload, then, els Stmt) *IfStmt { + return &IfStmt{Cond: cond, Payload: payload, Then: then, Else: els} +} + // NewBlock creates a block of statements. func NewBlock(stmts ...Stmt) *Block { return &Block{Stmts: stmts} @@ -134,3 +139,93 @@ func OptionalType(base TypeExpr) *PrefixTypeExpr { func PointerType(base TypeExpr) *PrefixTypeExpr { return &PrefixTypeExpr{Op: "*", Base: base} } + +// ArrayType creates an array type ([N]T). +func ArrayType(size Expr, elem TypeExpr) *ArrayTypeExpr { + return &ArrayTypeExpr{Size: size, Elem: elem} +} + +// ErrorSet creates an error set declaration. +func ErrorSet(names ...string) *ErrorSetDecl { + return &ErrorSetDecl{Names: names} +} + +// ErrorUnionType creates an error union type (E!T). +func ErrorUnionType(errSet, typ TypeExpr) *ErrorUnionTypeExpr { + return &ErrorUnionTypeExpr{ErrSet: errSet, Type: typ} +} + +// DeferStmt creates a defer statement. +func Defer(stmt Stmt) *DeferStmt { + return &DeferStmt{Stmt: stmt} +} + +// ForLoop creates a for loop statement. +func ForLoop(args []ForArg, payload *Payload, body Stmt, els Stmt) *LoopStmt { + return &LoopStmt{ + Kind: "for", + Prefix: &ForPrefix{Args: args, Payload: payload}, + Body: body, + Else: els, + } +} + +// ForArg creates a for loop argument. +func ForArgExpr(expr Expr) ForArg { + return ForArg{Expr: expr} +} + +// Payload creates a payload for control flow. +func PayloadNames(names []string, pointers []bool) *Payload { + return &Payload{Names: names, Pointers: pointers} +} + +// SwitchStmt creates a switch statement. +func Switch(cond Expr, prongs ...*SwitchProng) *SwitchStmt { + return &SwitchStmt{Cond: cond, Prongs: prongs} +} + +// SwitchProng creates a switch prong. +func Prong(cases []*SwitchCase, payload *Payload, expr Expr) *SwitchProng { + return &SwitchProng{Cases: cases, Payload: payload, Expr: expr} +} + +// SwitchCase creates a switch case. +func Case(expr Expr, to Expr) *SwitchCase { + return &SwitchCase{Expr: expr, To: to} +} + +// ElseCase creates an else case for a switch. +func ElseCase() *SwitchCase { + return &SwitchCase{IsElse: true} +} + +// BreakStmt creates a break statement. +func Break(label string, value Expr) *BreakStmt { + return &BreakStmt{Label: label, Value: value} +} + +// ContinueStmt creates a continue statement. +func Continue(label string) *ContinueStmt { + return &ContinueStmt{Label: label} +} + +// UnaryExpr creates a unary expression. +func Unary(op string, expr Expr) *UnaryExpr { + return &UnaryExpr{Op: op, Expr: expr} +} + +// IndexExpr creates an index expression. +func Index(receiver, index Expr) *IndexExpr { + return &IndexExpr{Receiver: receiver, Index: index} +} + +// UnreachableExpr creates an unreachable expression. +func Unreachable() *UnreachableExpr { + return &UnreachableExpr{} +} + +// TestDecl creates a test declaration. +func Test(name string, block *Block) *TestDecl { + return &TestDecl{Name: name, Block: block} +} diff --git a/internal/zig/zig.go b/internal/zig/zig.go index ee5052d..4489ccf 100644 --- a/internal/zig/zig.go +++ b/internal/zig/zig.go @@ -98,6 +98,7 @@ func writeDecl(f *formatter, decl Decl) { f.writef(") ") writeTypeExpr(f, d.ReturnType) writeBlock(f, d.Body) + f.writef("\n") case *GlobalVarDecl: if d.Flags&GlobalVarConst != 0 { f.writef("const %s = ", d.Name) @@ -109,6 +110,15 @@ func writeDecl(f *formatter, decl Decl) { case *ContainerDecl: f.writef("struct ") writeStructBody(f, d) + case *ErrorSetDecl: + writeExpr(f, d) + case *TestDecl: + f.writef("test ") + if d.Name != "" { + f.writef(`"%s"`, d.Name) + } + writeBlock(f, d.Block) + f.writef("\n") } } @@ -133,6 +143,15 @@ func writeTypeExpr(f *formatter, typ TypeExpr) { case *PrefixTypeExpr: f.writef("%s", t.Op) writeTypeExpr(f, t.Base) + case *ArrayTypeExpr: + f.writef("[") + writeExpr(f, t.Size) + f.writef("]") + writeTypeExpr(f, t.Elem) + case *ErrorUnionTypeExpr: + writeTypeExpr(f, t.ErrSet) + f.writef("!") + writeTypeExpr(f, t.Type) case nil: // nothing default: @@ -198,6 +217,12 @@ func writeStmt(f *formatter, stmt Stmt) { f.writef("if (") writeExpr(f, s.Cond) f.writef(")") + + // Handle payload if present + if s.Payload != nil { + f.writef(" ") + writePayload(f, s.Payload) + } // Always write the then branch as a block if block, ok := s.Then.(*BlockStmt); ok { @@ -233,7 +258,156 @@ func writeStmt(f *formatter, stmt Stmt) { f.writef(" ") writeStmt(f, s.Body) } + } else if s.Kind == "for" { + f.writef("for (") + if fp, ok := s.Prefix.(*ForPrefix); ok { + for i, arg := range fp.Args { + if i > 0 { + f.writef(", ") + } + writeExpr(f, arg.Expr) + if arg.From != nil { + if lit, ok := arg.From.(*Literal); ok && lit.Value == "" { + f.writef("..") + } else { + f.writef("...") + writeExpr(f, arg.From) + } + } + } + f.writef(")") + if fp.Payload != nil { + f.writef(" ") + writePayload(f, fp.Payload) + } + } + // Always write the body as a block + if block, ok := s.Body.(*BlockStmt); ok { + writeBlock(f, block.Block) + } else { + f.writef(" ") + writeStmt(f, s.Body) + } } + case *DeferStmt: + if s.ErrDefer { + f.writef("errdefer") + } else { + f.writef("defer") + } + f.writef(" ") + writeStmt(f, s.Stmt) + case *BreakStmt: + f.writef("break") + if s.Label != "" { + f.writef(" :%s", s.Label) + } + if s.Value != nil { + f.writef(" ") + writeExpr(f, s.Value) + } + f.writef(";") + case *ContinueStmt: + f.writef("continue") + if s.Label != "" { + f.writef(" :%s", s.Label) + } + f.writef(";") + case *SwitchStmt: + f.writef("switch (") + writeExpr(f, s.Cond) + f.writef(") {\n") + f.indent++ + for _, prong := range s.Prongs { + f.writeIndent() + writeSwitchProng(f, prong) + f.writef("\n") + } + f.indent-- + f.writeIndent() + f.writef("}") + } + } + +// writePayload emits a payload (|x|, |*x|, |*x, y|, etc). +func writePayload(f *formatter, payload *Payload) { + f.writef("|") + for i, name := range payload.Names { + if i > 0 { + f.writef(", ") + } + if payload.Pointers[i] { + f.writef("*") + } + f.writef("%s", name) + } + f.writef("|") +} + +// writeSwitchProng emits a switch prong. +func writeSwitchProng(f *formatter, prong *SwitchProng) { + for i, c := range prong.Cases { + if i > 0 { + f.writef(", ") + } + if c.IsElse { + f.writef("else") + } else { + writeExpr(f, c.Expr) + if c.To != nil { + f.writef("...") + writeExpr(f, c.To) + } + } + } + f.writef(" => ") + // Check if the expression is actually a statement (like return or break) + if stmt, ok := prong.Expr.(Stmt); ok { + // If it's a block, write it directly without the leading space + if blockStmt, isBlock := stmt.(*BlockStmt); isBlock { + f.writef("{\n") + f.indent++ + for _, s := range blockStmt.Block.Stmts { + f.writeIndent() + writeStmt(f, s) + f.writef("\n") + } + f.indent-- + f.writeIndent() + f.writef("},") + } else { + // For single statements, write without the semicolon + switch s := stmt.(type) { + case *ReturnStmt: + f.writef("return") + if s.Value != nil { + f.writef(" ") + writeExpr(f, s.Value) + } + case *BreakStmt: + f.writef("break") + if s.Label != "" { + f.writef(" :%s", s.Label) + } + if s.Value != nil { + f.writef(" ") + writeExpr(f, s.Value) + } + case *ContinueStmt: + f.writef("continue") + if s.Label != "" { + f.writef(" :%s", s.Label) + } + case *ExprStmt: + writeExpr(f, s.Expr) + default: + writeStmt(f, stmt) + } + f.writef(",") + } + } else { + writeExpr(f, prong.Expr) + f.writef(",") } } @@ -266,14 +440,20 @@ func writeExpr(f *formatter, expr Expr) { if e.Empty { f.writef(".{}") } else if len(e.Values) > 0 { - f.writef(".{") - for i, v := range e.Values { - if i > 0 { - f.writef(", ") + if len(e.Values) == 1 { + f.writef(".{") + writeExpr(f, e.Values[0]) + f.writef("}") + } else { + f.writef(".{ ") + for i, v := range e.Values { + if i > 0 { + f.writef(", ") + } + writeExpr(f, v) } - writeExpr(f, v) + f.writef(" }") } - f.writef("}") } case *ContainerDecl: if e.Kind == "struct" { @@ -289,6 +469,26 @@ func writeExpr(f *formatter, expr Expr) { writeExpr(f, e.Left) f.writef(" %s ", e.Op) writeExpr(f, e.Right) + case *UnaryExpr: + f.writef("%s", e.Op) + writeExpr(f, e.Expr) + case *IndexExpr: + writeExpr(f, e.Receiver) + f.writef("[") + writeExpr(f, e.Index) + f.writef("]") + case *UnreachableExpr: + f.writef("unreachable") + case *ErrorSetDecl: + f.writef("error{\n") + f.indent++ + for _, name := range e.Names { + f.writeIndent() + f.writef("%s,\n", name) + } + f.indent-- + f.writeIndent() + f.writef("}") } } @@ -321,7 +521,6 @@ func writeStructBody(f *formatter, decl *ContainerDecl) { f.writef("\n") f.writeIndent() writeDecl(f, member.Decl) - f.writef("\n") } } f.indent-- diff --git a/internal/zig/zig_test.go b/internal/zig/zig_test.go index a9f2d38..dfbfbad 100644 --- a/internal/zig/zig_test.go +++ b/internal/zig/zig_test.go @@ -92,6 +92,217 @@ fn main() void { runZigASTTest(t, expected, root) } +func TestCompoundFeatures(t *testing.T) { + expected := `const ProcessError = error{ + InvalidInput, + OutOfBounds, +}; + +fn processData(values: [5]?i32) ProcessError!i32 { + var result: i32 = 0; + defer std.debug.print("Cleanup\n", .{}); + for (values, 0..) |opt_val, idx| { + if (opt_val) |val| { + switch (val) { + -10...-1 => result += -val, + 0 => continue, + 1...10 => { + if (idx > 3) break; + result += val; + }, + else => return ProcessError.InvalidInput, + } + } else { + if (idx == 0) return ProcessError.InvalidInput; + break; + } + } + if (values[0] == null) { + unreachable; + } + return result; +} + +test "processData" { + const data = .{ 5, -3, 0, null, 10 }; + const result = try processData(data); + try std.testing.expectEqual(@as(i32, 8), result); +} +` + + root := &zig.Root{ + ContainerMembers: []*zig.ContainerMember{ + { + Decl: zig.DeclareGlobalVar("ProcessError", + zig.ErrorSet("InvalidInput", "OutOfBounds"), + zig.GlobalVarConst, + ), + }, + { + Decl: zig.DeclareFn( + "processData", + zig.ErrorUnionType(zig.Id("ProcessError"), zig.Id("i32")), + zig.NewBlock( + // var result: i32 = 0; + zig.DeclareVarStmt(false, []string{"result"}, zig.Id("i32"), zig.Lit("int", "0")), + // defer std.debug.print("Cleanup\n", .{}); + zig.Defer( + zig.NewExprStmt( + zig.Call( + zig.FieldAccess( + zig.FieldAccess(zig.Id("std"), "debug"), + "print", + ), + zig.Lit("string", `Cleanup\n`), + zig.InitList(), + ), + ), + ), + // for (values, 0..) |opt_val, idx| { ... } + zig.ForLoop( + []zig.ForArg{ + zig.ForArgExpr(zig.Id("values")), + {Expr: zig.Lit("int", "0"), From: zig.Lit("", "")}, + }, + zig.PayloadNames([]string{"opt_val", "idx"}, []bool{false, false}), + zig.NewBlockStmt( + // if (opt_val) |val| { ... } else { ... } + zig.IfWithPayload( + zig.Id("opt_val"), + zig.PayloadNames([]string{"val"}, []bool{false}), + zig.NewBlockStmt( + // switch (val) { ... } + zig.Switch( + zig.Id("val"), + // -10...-1 => result += -val, + zig.Prong( + []*zig.SwitchCase{ + zig.Case( + zig.Unary("-", zig.Lit("int", "10")), + zig.Unary("-", zig.Lit("int", "1")), + ), + }, + nil, + zig.NewExprStmt( + zig.Binary("+=", + zig.Id("result"), + zig.Unary("-", zig.Id("val")), + ), + ), + ), + // 0 => continue, + zig.Prong( + []*zig.SwitchCase{ + zig.Case(zig.Lit("int", "0"), nil), + }, + nil, + zig.Continue(""), + ), + // 1...10 => { ... } + zig.Prong( + []*zig.SwitchCase{ + zig.Case(zig.Lit("int", "1"), zig.Lit("int", "10")), + }, + nil, + zig.NewBlockStmt( + zig.If( + zig.Binary(">", zig.Id("idx"), zig.Lit("int", "3")), + zig.Break("", nil), + nil, + ), + zig.NewExprStmt( + zig.Binary("+=", zig.Id("result"), zig.Id("val")), + ), + ), + ), + // else => return ProcessError.InvalidInput, + zig.Prong( + []*zig.SwitchCase{zig.ElseCase()}, + nil, + zig.Return( + zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput"), + ), + ), + ), + ), + // else block + zig.NewBlockStmt( + zig.If( + zig.Binary("==", zig.Id("idx"), zig.Lit("int", "0")), + zig.Return(zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput")), + nil, + ), + zig.Break("", nil), + ), + ), + ), + nil, + ), + // if (values[0] == null) { unreachable; } + zig.If( + zig.Binary("==", + zig.Index(zig.Id("values"), zig.Lit("int", "0")), + zig.Id("null"), + ), + zig.NewBlockStmt( + zig.NewExprStmt(zig.Unreachable()), + ), + nil, + ), + // return result; + zig.Return(zig.Id("result")), + ), + []*zig.ParamDecl{ + zig.Param("values", zig.ArrayType(zig.Lit("int", "5"), zig.OptionalType(zig.Id("i32")))), + }, + 0, + ), + }, + { + Decl: zig.Test("processData", + zig.NewBlock( + // const data = [_]?i32{ 5, -3, 0, null, 10 }; + zig.DeclareVarStmt(true, []string{"data"}, nil, + &zig.InitListExpr{ + Values: []zig.Expr{ + zig.Lit("int", "5"), + zig.Unary("-", zig.Lit("int", "3")), + zig.Lit("int", "0"), + zig.Id("null"), + zig.Lit("int", "10"), + }, + }, + ), + // const result = try processData(data); + zig.DeclareVarStmt(true, []string{"result"}, nil, + zig.Try(zig.Call(zig.Id("processData"), zig.Id("data"))), + ), + // try std.testing.expectEqual(@as(i32, 8), result); + zig.NewExprStmt( + zig.Try( + zig.Call( + zig.FieldAccess( + zig.FieldAccess(zig.Id("std"), "testing"), + "expectEqual", + ), + zig.Call( + zig.Id("@as"), + zig.Id("i32"), + zig.Lit("int", "8"), + ), + zig.Id("result"), + ), + ), + ), + ), + ), + }, + }, + } + + runZigASTTest(t, expected, root) +} + func TestEvenOdd(t *testing.T) { expected := `//! Abc