Add comprehensive Zig syntax test coverage

- Add support for error sets and error union types
- Implement for loops with index and payload syntax
- Add defer, break, continue, and switch statements
- Support unary expressions and array indexing
- Add unreachable expression and test declarations
- Extend AST with new type expressions (array, error union)
- Update formatter to handle all new syntax elements
- Fix formatting for switch prongs, payloads, and blocks
This commit is contained in:
Luke Wilson 2025-06-05 20:44:49 -05:00
parent 50b38254ab
commit 258b3c8e9b
4 changed files with 529 additions and 7 deletions

View File

@ -408,6 +408,23 @@ type PrefixTypeExpr struct {
Base TypeExpr
}
// ArrayTypeExpr represents an array type ([N]T).
type ArrayTypeExpr struct {
Size Expr
Elem TypeExpr
}
// SliceTypeExpr represents a slice type ([]T).
type SliceTypeExpr struct {
Elem TypeExpr
}
// ErrorUnionTypeExpr represents an error union type (E!T).
type ErrorUnionTypeExpr struct {
ErrSet TypeExpr
Type TypeExpr
}
// DocComment represents a doc comment (/// or //! lines).
// Newlines in the string automatically add more comments in the output.
type DocComment string

View File

@ -73,6 +73,11 @@ func If(cond Expr, then, els Stmt) *IfStmt {
return &IfStmt{Cond: cond, Then: then, Else: els}
}
// IfWithPayload creates an if statement with a payload.
func IfWithPayload(cond Expr, payload *Payload, then, els Stmt) *IfStmt {
return &IfStmt{Cond: cond, Payload: payload, Then: then, Else: els}
}
// NewBlock creates a block of statements.
func NewBlock(stmts ...Stmt) *Block {
return &Block{Stmts: stmts}
@ -134,3 +139,93 @@ func OptionalType(base TypeExpr) *PrefixTypeExpr {
func PointerType(base TypeExpr) *PrefixTypeExpr {
return &PrefixTypeExpr{Op: "*", Base: base}
}
// ArrayType creates an array type ([N]T).
func ArrayType(size Expr, elem TypeExpr) *ArrayTypeExpr {
return &ArrayTypeExpr{Size: size, Elem: elem}
}
// ErrorSet creates an error set declaration.
func ErrorSet(names ...string) *ErrorSetDecl {
return &ErrorSetDecl{Names: names}
}
// ErrorUnionType creates an error union type (E!T).
func ErrorUnionType(errSet, typ TypeExpr) *ErrorUnionTypeExpr {
return &ErrorUnionTypeExpr{ErrSet: errSet, Type: typ}
}
// DeferStmt creates a defer statement.
func Defer(stmt Stmt) *DeferStmt {
return &DeferStmt{Stmt: stmt}
}
// ForLoop creates a for loop statement.
func ForLoop(args []ForArg, payload *Payload, body Stmt, els Stmt) *LoopStmt {
return &LoopStmt{
Kind: "for",
Prefix: &ForPrefix{Args: args, Payload: payload},
Body: body,
Else: els,
}
}
// ForArg creates a for loop argument.
func ForArgExpr(expr Expr) ForArg {
return ForArg{Expr: expr}
}
// Payload creates a payload for control flow.
func PayloadNames(names []string, pointers []bool) *Payload {
return &Payload{Names: names, Pointers: pointers}
}
// SwitchStmt creates a switch statement.
func Switch(cond Expr, prongs ...*SwitchProng) *SwitchStmt {
return &SwitchStmt{Cond: cond, Prongs: prongs}
}
// SwitchProng creates a switch prong.
func Prong(cases []*SwitchCase, payload *Payload, expr Expr) *SwitchProng {
return &SwitchProng{Cases: cases, Payload: payload, Expr: expr}
}
// SwitchCase creates a switch case.
func Case(expr Expr, to Expr) *SwitchCase {
return &SwitchCase{Expr: expr, To: to}
}
// ElseCase creates an else case for a switch.
func ElseCase() *SwitchCase {
return &SwitchCase{IsElse: true}
}
// BreakStmt creates a break statement.
func Break(label string, value Expr) *BreakStmt {
return &BreakStmt{Label: label, Value: value}
}
// ContinueStmt creates a continue statement.
func Continue(label string) *ContinueStmt {
return &ContinueStmt{Label: label}
}
// UnaryExpr creates a unary expression.
func Unary(op string, expr Expr) *UnaryExpr {
return &UnaryExpr{Op: op, Expr: expr}
}
// IndexExpr creates an index expression.
func Index(receiver, index Expr) *IndexExpr {
return &IndexExpr{Receiver: receiver, Index: index}
}
// UnreachableExpr creates an unreachable expression.
func Unreachable() *UnreachableExpr {
return &UnreachableExpr{}
}
// TestDecl creates a test declaration.
func Test(name string, block *Block) *TestDecl {
return &TestDecl{Name: name, Block: block}
}

View File

@ -98,6 +98,7 @@ func writeDecl(f *formatter, decl Decl) {
f.writef(") ")
writeTypeExpr(f, d.ReturnType)
writeBlock(f, d.Body)
f.writef("\n")
case *GlobalVarDecl:
if d.Flags&GlobalVarConst != 0 {
f.writef("const %s = ", d.Name)
@ -109,6 +110,15 @@ func writeDecl(f *formatter, decl Decl) {
case *ContainerDecl:
f.writef("struct ")
writeStructBody(f, d)
case *ErrorSetDecl:
writeExpr(f, d)
case *TestDecl:
f.writef("test ")
if d.Name != "" {
f.writef(`"%s"`, d.Name)
}
writeBlock(f, d.Block)
f.writef("\n")
}
}
@ -133,6 +143,15 @@ func writeTypeExpr(f *formatter, typ TypeExpr) {
case *PrefixTypeExpr:
f.writef("%s", t.Op)
writeTypeExpr(f, t.Base)
case *ArrayTypeExpr:
f.writef("[")
writeExpr(f, t.Size)
f.writef("]")
writeTypeExpr(f, t.Elem)
case *ErrorUnionTypeExpr:
writeTypeExpr(f, t.ErrSet)
f.writef("!")
writeTypeExpr(f, t.Type)
case nil:
// nothing
default:
@ -198,6 +217,12 @@ func writeStmt(f *formatter, stmt Stmt) {
f.writef("if (")
writeExpr(f, s.Cond)
f.writef(")")
// Handle payload if present
if s.Payload != nil {
f.writef(" ")
writePayload(f, s.Payload)
}
// Always write the then branch as a block
if block, ok := s.Then.(*BlockStmt); ok {
@ -233,7 +258,156 @@ func writeStmt(f *formatter, stmt Stmt) {
f.writef(" ")
writeStmt(f, s.Body)
}
} else if s.Kind == "for" {
f.writef("for (")
if fp, ok := s.Prefix.(*ForPrefix); ok {
for i, arg := range fp.Args {
if i > 0 {
f.writef(", ")
}
writeExpr(f, arg.Expr)
if arg.From != nil {
if lit, ok := arg.From.(*Literal); ok && lit.Value == "" {
f.writef("..")
} else {
f.writef("...")
writeExpr(f, arg.From)
}
}
}
f.writef(")")
if fp.Payload != nil {
f.writef(" ")
writePayload(f, fp.Payload)
}
}
// Always write the body as a block
if block, ok := s.Body.(*BlockStmt); ok {
writeBlock(f, block.Block)
} else {
f.writef(" ")
writeStmt(f, s.Body)
}
}
case *DeferStmt:
if s.ErrDefer {
f.writef("errdefer")
} else {
f.writef("defer")
}
f.writef(" ")
writeStmt(f, s.Stmt)
case *BreakStmt:
f.writef("break")
if s.Label != "" {
f.writef(" :%s", s.Label)
}
if s.Value != nil {
f.writef(" ")
writeExpr(f, s.Value)
}
f.writef(";")
case *ContinueStmt:
f.writef("continue")
if s.Label != "" {
f.writef(" :%s", s.Label)
}
f.writef(";")
case *SwitchStmt:
f.writef("switch (")
writeExpr(f, s.Cond)
f.writef(") {\n")
f.indent++
for _, prong := range s.Prongs {
f.writeIndent()
writeSwitchProng(f, prong)
f.writef("\n")
}
f.indent--
f.writeIndent()
f.writef("}")
}
}
// writePayload emits a payload (|x|, |*x|, |*x, y|, etc).
func writePayload(f *formatter, payload *Payload) {
f.writef("|")
for i, name := range payload.Names {
if i > 0 {
f.writef(", ")
}
if payload.Pointers[i] {
f.writef("*")
}
f.writef("%s", name)
}
f.writef("|")
}
// writeSwitchProng emits a switch prong.
func writeSwitchProng(f *formatter, prong *SwitchProng) {
for i, c := range prong.Cases {
if i > 0 {
f.writef(", ")
}
if c.IsElse {
f.writef("else")
} else {
writeExpr(f, c.Expr)
if c.To != nil {
f.writef("...")
writeExpr(f, c.To)
}
}
}
f.writef(" => ")
// Check if the expression is actually a statement (like return or break)
if stmt, ok := prong.Expr.(Stmt); ok {
// If it's a block, write it directly without the leading space
if blockStmt, isBlock := stmt.(*BlockStmt); isBlock {
f.writef("{\n")
f.indent++
for _, s := range blockStmt.Block.Stmts {
f.writeIndent()
writeStmt(f, s)
f.writef("\n")
}
f.indent--
f.writeIndent()
f.writef("},")
} else {
// For single statements, write without the semicolon
switch s := stmt.(type) {
case *ReturnStmt:
f.writef("return")
if s.Value != nil {
f.writef(" ")
writeExpr(f, s.Value)
}
case *BreakStmt:
f.writef("break")
if s.Label != "" {
f.writef(" :%s", s.Label)
}
if s.Value != nil {
f.writef(" ")
writeExpr(f, s.Value)
}
case *ContinueStmt:
f.writef("continue")
if s.Label != "" {
f.writef(" :%s", s.Label)
}
case *ExprStmt:
writeExpr(f, s.Expr)
default:
writeStmt(f, stmt)
}
f.writef(",")
}
} else {
writeExpr(f, prong.Expr)
f.writef(",")
}
}
@ -266,14 +440,20 @@ func writeExpr(f *formatter, expr Expr) {
if e.Empty {
f.writef(".{}")
} else if len(e.Values) > 0 {
f.writef(".{")
for i, v := range e.Values {
if i > 0 {
f.writef(", ")
if len(e.Values) == 1 {
f.writef(".{")
writeExpr(f, e.Values[0])
f.writef("}")
} else {
f.writef(".{ ")
for i, v := range e.Values {
if i > 0 {
f.writef(", ")
}
writeExpr(f, v)
}
writeExpr(f, v)
f.writef(" }")
}
f.writef("}")
}
case *ContainerDecl:
if e.Kind == "struct" {
@ -289,6 +469,26 @@ func writeExpr(f *formatter, expr Expr) {
writeExpr(f, e.Left)
f.writef(" %s ", e.Op)
writeExpr(f, e.Right)
case *UnaryExpr:
f.writef("%s", e.Op)
writeExpr(f, e.Expr)
case *IndexExpr:
writeExpr(f, e.Receiver)
f.writef("[")
writeExpr(f, e.Index)
f.writef("]")
case *UnreachableExpr:
f.writef("unreachable")
case *ErrorSetDecl:
f.writef("error{\n")
f.indent++
for _, name := range e.Names {
f.writeIndent()
f.writef("%s,\n", name)
}
f.indent--
f.writeIndent()
f.writef("}")
}
}
@ -321,7 +521,6 @@ func writeStructBody(f *formatter, decl *ContainerDecl) {
f.writef("\n")
f.writeIndent()
writeDecl(f, member.Decl)
f.writef("\n")
}
}
f.indent--

View File

@ -92,6 +92,217 @@ fn main() void {
runZigASTTest(t, expected, root)
}
func TestCompoundFeatures(t *testing.T) {
expected := `const ProcessError = error{
InvalidInput,
OutOfBounds,
};
fn processData(values: [5]?i32) ProcessError!i32 {
var result: i32 = 0;
defer std.debug.print("Cleanup\n", .{});
for (values, 0..) |opt_val, idx| {
if (opt_val) |val| {
switch (val) {
-10...-1 => result += -val,
0 => continue,
1...10 => {
if (idx > 3) break;
result += val;
},
else => return ProcessError.InvalidInput,
}
} else {
if (idx == 0) return ProcessError.InvalidInput;
break;
}
}
if (values[0] == null) {
unreachable;
}
return result;
}
test "processData" {
const data = .{ 5, -3, 0, null, 10 };
const result = try processData(data);
try std.testing.expectEqual(@as(i32, 8), result);
}
`
root := &zig.Root{
ContainerMembers: []*zig.ContainerMember{
{
Decl: zig.DeclareGlobalVar("ProcessError",
zig.ErrorSet("InvalidInput", "OutOfBounds"),
zig.GlobalVarConst,
),
},
{
Decl: zig.DeclareFn(
"processData",
zig.ErrorUnionType(zig.Id("ProcessError"), zig.Id("i32")),
zig.NewBlock(
// var result: i32 = 0;
zig.DeclareVarStmt(false, []string{"result"}, zig.Id("i32"), zig.Lit("int", "0")),
// defer std.debug.print("Cleanup\n", .{});
zig.Defer(
zig.NewExprStmt(
zig.Call(
zig.FieldAccess(
zig.FieldAccess(zig.Id("std"), "debug"),
"print",
),
zig.Lit("string", `Cleanup\n`),
zig.InitList(),
),
),
),
// for (values, 0..) |opt_val, idx| { ... }
zig.ForLoop(
[]zig.ForArg{
zig.ForArgExpr(zig.Id("values")),
{Expr: zig.Lit("int", "0"), From: zig.Lit("", "")},
},
zig.PayloadNames([]string{"opt_val", "idx"}, []bool{false, false}),
zig.NewBlockStmt(
// if (opt_val) |val| { ... } else { ... }
zig.IfWithPayload(
zig.Id("opt_val"),
zig.PayloadNames([]string{"val"}, []bool{false}),
zig.NewBlockStmt(
// switch (val) { ... }
zig.Switch(
zig.Id("val"),
// -10...-1 => result += -val,
zig.Prong(
[]*zig.SwitchCase{
zig.Case(
zig.Unary("-", zig.Lit("int", "10")),
zig.Unary("-", zig.Lit("int", "1")),
),
},
nil,
zig.NewExprStmt(
zig.Binary("+=",
zig.Id("result"),
zig.Unary("-", zig.Id("val")),
),
),
),
// 0 => continue,
zig.Prong(
[]*zig.SwitchCase{
zig.Case(zig.Lit("int", "0"), nil),
},
nil,
zig.Continue(""),
),
// 1...10 => { ... }
zig.Prong(
[]*zig.SwitchCase{
zig.Case(zig.Lit("int", "1"), zig.Lit("int", "10")),
},
nil,
zig.NewBlockStmt(
zig.If(
zig.Binary(">", zig.Id("idx"), zig.Lit("int", "3")),
zig.Break("", nil),
nil,
),
zig.NewExprStmt(
zig.Binary("+=", zig.Id("result"), zig.Id("val")),
),
),
),
// else => return ProcessError.InvalidInput,
zig.Prong(
[]*zig.SwitchCase{zig.ElseCase()},
nil,
zig.Return(
zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput"),
),
),
),
),
// else block
zig.NewBlockStmt(
zig.If(
zig.Binary("==", zig.Id("idx"), zig.Lit("int", "0")),
zig.Return(zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput")),
nil,
),
zig.Break("", nil),
),
),
),
nil,
),
// if (values[0] == null) { unreachable; }
zig.If(
zig.Binary("==",
zig.Index(zig.Id("values"), zig.Lit("int", "0")),
zig.Id("null"),
),
zig.NewBlockStmt(
zig.NewExprStmt(zig.Unreachable()),
),
nil,
),
// return result;
zig.Return(zig.Id("result")),
),
[]*zig.ParamDecl{
zig.Param("values", zig.ArrayType(zig.Lit("int", "5"), zig.OptionalType(zig.Id("i32")))),
},
0,
),
},
{
Decl: zig.Test("processData",
zig.NewBlock(
// const data = [_]?i32{ 5, -3, 0, null, 10 };
zig.DeclareVarStmt(true, []string{"data"}, nil,
&zig.InitListExpr{
Values: []zig.Expr{
zig.Lit("int", "5"),
zig.Unary("-", zig.Lit("int", "3")),
zig.Lit("int", "0"),
zig.Id("null"),
zig.Lit("int", "10"),
},
},
),
// const result = try processData(data);
zig.DeclareVarStmt(true, []string{"result"}, nil,
zig.Try(zig.Call(zig.Id("processData"), zig.Id("data"))),
),
// try std.testing.expectEqual(@as(i32, 8), result);
zig.NewExprStmt(
zig.Try(
zig.Call(
zig.FieldAccess(
zig.FieldAccess(zig.Id("std"), "testing"),
"expectEqual",
),
zig.Call(
zig.Id("@as"),
zig.Id("i32"),
zig.Lit("int", "8"),
),
zig.Id("result"),
),
),
),
),
),
},
},
}
runZigASTTest(t, expected, root)
}
func TestEvenOdd(t *testing.T) {
expected := `//! Abc