Improve Zig AST developer experience

- Convert string-based type discrimination to type-safe enums
  - ContainerKind (struct, enum, union, opaque)
  - LiteralKind (int, float, string, char)
  - LoopKind (for, while)
- Remove duplicate AST nodes (consolidated init lists, removed unused types)
- Add comprehensive helper functions for all AST constructions
- Implement formatters for all AST nodes (expressions, statements, types)
- Add typed literal constructors: IntLit, FloatLit, StringLit, CharLit
- Improve documentation and add deprecation notices

This makes the AST more intuitive and type-safe for developers.
This commit is contained in:
Luke Wilson 2025-06-05 21:07:02 -05:00
parent 258b3c8e9b
commit 2af696078d
4 changed files with 339 additions and 129 deletions

View File

@ -2,9 +2,10 @@
package zig
// Root is the top-level node of a Zig source file.
// It represents the entire compilation unit.
type Root struct {
ContainerDocComment DocComment // //! Doc Comment (optional)
ContainerMembers []*ContainerMember
ContainerDocComment DocComment // Module-level doc comment using //!
ContainerMembers []*ContainerMember // Top-level declarations and fields
}
type ContainerMember struct {
@ -111,11 +112,21 @@ type ParamDecl struct {
Type TypeExpr // 'anytype' if empty
}
// ContainerKind represents the kind of container (struct, enum, union, opaque).
type ContainerKind int
const (
ContainerStruct ContainerKind = iota
ContainerEnum
ContainerUnion
ContainerOpaque
)
// ContainerDecl represents a struct, enum, union, or opaque declaration.
type ContainerDecl struct {
Extern bool
Packed bool
Kind string // "struct", "enum", "union", "opaque"
Kind ContainerKind
TagType TypeExpr // Optional (for enum/union)
Fields []*ContainerMember
DocComment DocComment
@ -226,10 +237,18 @@ type ContinueStmt struct {
func (*ContinueStmt) isStmt() {}
// LoopKind represents the kind of loop (for or while).
type LoopKind int
const (
LoopFor LoopKind = iota
LoopWhile
)
// LoopStmt represents a for/while loop statement.
type LoopStmt struct {
Inline bool // True if 'inline' is present
Kind string // "for" or "while"
Kind LoopKind // For or While
Prefix LoopPrefix // ForPrefix or WhilePrefix
Body Stmt // Loop body
Else Stmt // Optional else branch
@ -331,9 +350,19 @@ type Identifier struct {
Name string // The identifier name
}
// LiteralKind represents the kind of literal.
type LiteralKind int
const (
LiteralInt LiteralKind = iota
LiteralFloat
LiteralString
LiteralChar
)
// Literal represents a literal value (int, float, string, char).
type Literal struct {
Kind string // "int", "float", "string", "char"
Kind LiteralKind
Value string // The literal value as a string
}
@ -391,15 +420,14 @@ type NosuspendExpr struct {
Expr Expr // The expression to evaluate with nosuspend
}
// ContinueExpr represents a 'continue' expression.
type ContinueExpr struct {
Label string // Optional label
}
// Expr is any expression.
// This is an empty interface to allow maximum flexibility.
// Consider using type switches when working with expressions.
type Expr interface{}
// TypeExpr is any type expression.
// This is an empty interface to allow maximum flexibility.
// Consider using type switches when working with type expressions.
type TypeExpr interface{}
// PrefixTypeExpr represents a type with a string prefix. Examples include optionals and pointers.
@ -500,43 +528,14 @@ type AwaitExpr struct {
// UnreachableExpr represents the 'unreachable' keyword.
type UnreachableExpr struct{}
// EmptyInitListExpr represents an empty initializer list '{}'.
type EmptyInitListExpr struct{}
// PositionalInitListExpr represents a positional initializer list '{expr, ...}'.
type PositionalInitListExpr struct {
Values []Expr // Expressions in order
}
// FieldInitListExpr represents a field initializer list '{.field = expr, ...}'.
type FieldInitListExpr struct {
Fields []*FieldInit // Field initializers
}
// SwitchProngPayload represents a switch prong payload (|*x, y|).
type SwitchProngPayload struct {
Pointer bool
Names []string
}
// SwitchProngCase represents a single case in a switch prong.
type SwitchProngCase struct {
Expr Expr // The case expression
To Expr // Optional, for ranges
}
// SwitchProngFull represents a full switch prong with cases and payload.
type SwitchProngFull struct {
Inline bool
Cases []*SwitchProngCase // One or more cases
Payload *SwitchProngPayload // Optional
Expr Expr // The result expression
}
// SwitchElseProng represents an 'else' prong in a switch.
type SwitchElseProng struct {
Expr Expr // The result expression
}
// Note: The following types were removed as they were duplicates:
// - EmptyInitListExpr (use InitListExpr with Empty=true)
// - PositionalInitListExpr (use InitListExpr with Values)
// - FieldInitListExpr (use InitListExpr with Fields)
// - SwitchProngPayload (use Payload)
// - SwitchProngCase (use SwitchCase)
// - SwitchProngFull (use SwitchProng)
// - SwitchElseProng (use SwitchProng with SwitchCase.IsElse=true)
// VarPattern represents a variable pattern for destructuring or multiple variable declarations.
type VarPattern struct {

View File

@ -5,9 +5,24 @@ func Id(name string) *Identifier {
return &Identifier{Name: name}
}
// Lit creates a literal expression.
func Lit(kind, value string) *Literal {
return &Literal{Kind: kind, Value: value}
// IntLit creates an integer literal.
func IntLit(value string) *Literal {
return &Literal{Kind: LiteralInt, Value: value}
}
// FloatLit creates a float literal.
func FloatLit(value string) *Literal {
return &Literal{Kind: LiteralFloat, Value: value}
}
// StringLit creates a string literal.
func StringLit(value string) *Literal {
return &Literal{Kind: LiteralString, Value: value}
}
// CharLit creates a character literal.
func CharLit(value string) *Literal {
return &Literal{Kind: LiteralChar, Value: value}
}
// FieldAccess creates a field access expression.
@ -108,7 +123,7 @@ func Param(name string, typ TypeExpr) *ParamDecl {
// StructDecl creates a struct declaration with the given fields/members.
func StructDecl(fields ...*ContainerMember) *ContainerDecl {
return &ContainerDecl{
Kind: "struct",
Kind: ContainerStruct,
Fields: fields,
}
}
@ -163,13 +178,27 @@ func Defer(stmt Stmt) *DeferStmt {
// ForLoop creates a for loop statement.
func ForLoop(args []ForArg, payload *Payload, body Stmt, els Stmt) *LoopStmt {
return &LoopStmt{
Kind: "for",
Kind: LoopFor,
Prefix: &ForPrefix{Args: args, Payload: payload},
Body: body,
Else: els,
}
}
// WhileLoop creates a while loop statement.
func WhileLoop(cond Expr, continueExpr Expr, payload *Payload, body Stmt, els Stmt) *LoopStmt {
return &LoopStmt{
Kind: LoopWhile,
Prefix: &WhilePrefix{
Cond: cond,
Continue: continueExpr,
Payload: payload,
},
Body: body,
Else: els,
}
}
// ForArg creates a for loop argument.
func ForArgExpr(expr Expr) ForArg {
return ForArg{Expr: expr}
@ -229,3 +258,113 @@ func Unreachable() *UnreachableExpr {
func Test(name string, block *Block) *TestDecl {
return &TestDecl{Name: name, Block: block}
}
// GroupedExpr creates a grouped (parenthesized) expression.
func Grouped(expr Expr) *GroupedExpr {
return &GroupedExpr{Expr: expr}
}
// ComptimeExpr creates a comptime expression.
func Comptime(expr Expr) *ComptimeExpr {
return &ComptimeExpr{Expr: expr}
}
// NosuspendExpr creates a nosuspend expression.
func Nosuspend(expr Expr) *NosuspendExpr {
return &NosuspendExpr{Expr: expr}
}
// SliceType creates a slice type ([]T).
func SliceType(elem TypeExpr) *SliceTypeExpr {
return &SliceTypeExpr{Elem: elem}
}
// IfExpression creates an if expression.
func IfExpression(cond Expr, then, els Expr) *IfExpr {
return &IfExpr{Cond: cond, Then: then, Else: els}
}
// IfExpressionWithPayload creates an if expression with a payload.
func IfExpressionWithPayload(cond Expr, payload *Payload, then, els Expr) *IfExpr {
return &IfExpr{Cond: cond, Payload: payload, Then: then, Else: els}
}
// SwitchExpression creates a switch expression.
func SwitchExpression(cond Expr, prongs ...*SwitchProng) *SwitchExpr {
return &SwitchExpr{Cond: cond, Prongs: prongs}
}
// EnumDecl creates an enum declaration.
func EnumDecl(tagType TypeExpr, fields ...*ContainerMember) *ContainerDecl {
return &ContainerDecl{
Kind: ContainerEnum,
TagType: tagType,
Fields: fields,
}
}
// UnionDecl creates a union declaration.
func UnionDecl(tagType TypeExpr, fields ...*ContainerMember) *ContainerDecl {
return &ContainerDecl{
Kind: ContainerUnion,
TagType: tagType,
Fields: fields,
}
}
// OpaqueDecl creates an opaque declaration.
func OpaqueDecl() *ContainerDecl {
return &ContainerDecl{
Kind: ContainerOpaque,
}
}
// AsyncExpr creates an async expression.
func Async(expr Expr) *AsyncExpr {
return &AsyncExpr{Expr: expr}
}
// AwaitExpr creates an await expression.
func Await(expr Expr) *AwaitExpr {
return &AwaitExpr{Expr: expr}
}
// ResumeExpr creates a resume expression.
func Resume(expr Expr) *ResumeExpr {
return &ResumeExpr{Expr: expr}
}
// DotAsterisk creates a .* expression (pointer dereference).
func DotAsterisk(receiver Expr) *DotAsteriskExpr {
return &DotAsteriskExpr{Receiver: receiver}
}
// DotQuestion creates a .? expression (optional unwrap).
func DotQuestion(receiver Expr) *DotQuestionExpr {
return &DotQuestionExpr{Receiver: receiver}
}
// ErrDefer creates an errdefer statement.
func ErrDefer(payload *Payload, stmt Stmt) *DeferStmt {
return &DeferStmt{ErrDefer: true, Payload: payload, Stmt: stmt}
}
// UsingNamespace creates a usingnamespace declaration.
func UsingNamespace(expr Expr) *UsingNamespaceDecl {
return &UsingNamespaceDecl{Expr: expr}
}
// ComptimeBlock creates a comptime block declaration.
func ComptimeBlock(block *Block) *ComptimeDecl {
return &ComptimeDecl{Block: block}
}
// InitListFields creates an init list with field initializers.
func InitListFields(fields ...*FieldInit) *InitListExpr {
return &InitListExpr{Fields: fields}
}
// FieldInitPair creates a field initializer.
func FieldInitPair(name string, value Expr) *FieldInit {
return &FieldInit{Name: name, Value: value}
}

View File

@ -119,6 +119,14 @@ func writeDecl(f *formatter, decl Decl) {
}
writeBlock(f, d.Block)
f.writef("\n")
case *UsingNamespaceDecl:
f.writef("usingnamespace ")
writeExpr(f, d.Expr)
f.writef(";\n")
case *ComptimeDecl:
f.writef("comptime")
writeBlock(f, d.Block)
f.writef("\n")
}
}
@ -148,6 +156,9 @@ func writeTypeExpr(f *formatter, typ TypeExpr) {
writeExpr(f, t.Size)
f.writef("]")
writeTypeExpr(f, t.Elem)
case *SliceTypeExpr:
f.writef("[]")
writeTypeExpr(f, t.Elem)
case *ErrorUnionTypeExpr:
writeTypeExpr(f, t.ErrSet)
f.writef("!")
@ -241,7 +252,7 @@ func writeStmt(f *formatter, stmt Stmt) {
}
}
case *LoopStmt:
if s.Kind == "while" {
if s.Kind == LoopWhile {
f.writef("while (")
if wp, ok := s.Prefix.(*WhilePrefix); ok {
writeExpr(f, wp.Cond)
@ -258,7 +269,7 @@ func writeStmt(f *formatter, stmt Stmt) {
f.writef(" ")
writeStmt(f, s.Body)
}
} else if s.Kind == "for" {
} else if s.Kind == LoopFor {
f.writef("for (")
if fp, ok := s.Prefix.(*ForPrefix); ok {
for i, arg := range fp.Args {
@ -327,7 +338,7 @@ func writeStmt(f *formatter, stmt Stmt) {
f.writeIndent()
f.writef("}")
}
}
}
// writePayload emits a payload (|x|, |*x|, |*x, y|, etc).
func writePayload(f *formatter, payload *Payload) {
@ -431,7 +442,7 @@ func writeExpr(f *formatter, expr Expr) {
f.writef(".%s", e.Field)
case *Literal:
switch e.Kind {
case "string":
case LiteralString:
f.writef(`"%v"`, e.Value)
default:
f.writef("%v", e.Value)
@ -456,11 +467,21 @@ func writeExpr(f *formatter, expr Expr) {
}
}
case *ContainerDecl:
if e.Kind == "struct" {
switch e.Kind {
case ContainerStruct:
f.writef("struct ")
writeStructBody(f, e)
} else {
panic("not implemented: " + e.Kind)
case ContainerEnum:
f.writef("enum ")
writeStructBody(f, e)
case ContainerUnion:
f.writef("union ")
writeStructBody(f, e)
case ContainerOpaque:
f.writef("opaque ")
writeStructBody(f, e)
default:
panic("unknown container kind")
}
case *TryExpr:
f.writef("try ")
@ -489,6 +510,58 @@ func writeExpr(f *formatter, expr Expr) {
f.indent--
f.writeIndent()
f.writef("}")
case *GroupedExpr:
f.writef("(")
writeExpr(f, e.Expr)
f.writef(")")
case *ComptimeExpr:
f.writef("comptime ")
writeExpr(f, e.Expr)
case *NosuspendExpr:
f.writef("nosuspend ")
writeExpr(f, e.Expr)
case *AsyncExpr:
f.writef("async ")
writeExpr(f, e.Expr)
case *AwaitExpr:
f.writef("await ")
writeExpr(f, e.Expr)
case *ResumeExpr:
f.writef("resume ")
writeExpr(f, e.Expr)
case *DotAsteriskExpr:
writeExpr(f, e.Receiver)
f.writef(".*")
case *DotQuestionExpr:
writeExpr(f, e.Receiver)
f.writef(".?")
case *IfExpr:
f.writef("if (")
writeExpr(f, e.Cond)
f.writef(")")
if e.Payload != nil {
f.writef(" ")
writePayload(f, e.Payload)
}
f.writef(" ")
writeExpr(f, e.Then)
if e.Else != nil {
f.writef(" else ")
writeExpr(f, e.Else)
}
case *SwitchExpr:
f.writef("switch (")
writeExpr(f, e.Cond)
f.writef(") {\n")
f.indent++
for _, prong := range e.Prongs {
f.writeIndent()
writeSwitchProng(f, prong)
f.writef("\n")
}
f.indent--
f.writeIndent()
f.writef("}")
}
}

View File

@ -63,7 +63,7 @@ fn main() void {
{
Decl: zig.DeclareGlobalVar("std", zig.Call(
zig.Id("@import"),
zig.Lit("string", "std"),
zig.StringLit("std"),
), zig.GlobalVarConst),
},
{
@ -77,7 +77,7 @@ fn main() void {
zig.FieldAccess(zig.Id("std"), "debug"),
"print",
),
zig.Lit("string", `Hello, world!\n`),
zig.StringLit(`Hello, world!\n`),
zig.InitList(),
),
),
@ -144,7 +144,7 @@ test "processData" {
zig.ErrorUnionType(zig.Id("ProcessError"), zig.Id("i32")),
zig.NewBlock(
// var result: i32 = 0;
zig.DeclareVarStmt(false, []string{"result"}, zig.Id("i32"), zig.Lit("int", "0")),
zig.DeclareVarStmt(false, []string{"result"}, zig.Id("i32"), zig.IntLit("0")),
// defer std.debug.print("Cleanup\n", .{});
zig.Defer(
zig.NewExprStmt(
@ -153,7 +153,7 @@ test "processData" {
zig.FieldAccess(zig.Id("std"), "debug"),
"print",
),
zig.Lit("string", `Cleanup\n`),
zig.StringLit(`Cleanup\n`),
zig.InitList(),
),
),
@ -162,7 +162,7 @@ test "processData" {
zig.ForLoop(
[]zig.ForArg{
zig.ForArgExpr(zig.Id("values")),
{Expr: zig.Lit("int", "0"), From: zig.Lit("", "")},
{Expr: zig.IntLit("0"), From: zig.IntLit("")},
},
zig.PayloadNames([]string{"opt_val", "idx"}, []bool{false, false}),
zig.NewBlockStmt(
@ -178,8 +178,8 @@ test "processData" {
zig.Prong(
[]*zig.SwitchCase{
zig.Case(
zig.Unary("-", zig.Lit("int", "10")),
zig.Unary("-", zig.Lit("int", "1")),
zig.Unary("-", zig.IntLit("10")),
zig.Unary("-", zig.IntLit("1")),
),
},
nil,
@ -193,7 +193,7 @@ test "processData" {
// 0 => continue,
zig.Prong(
[]*zig.SwitchCase{
zig.Case(zig.Lit("int", "0"), nil),
zig.Case(zig.IntLit("0"), nil),
},
nil,
zig.Continue(""),
@ -201,12 +201,12 @@ test "processData" {
// 1...10 => { ... }
zig.Prong(
[]*zig.SwitchCase{
zig.Case(zig.Lit("int", "1"), zig.Lit("int", "10")),
zig.Case(zig.IntLit("1"), zig.IntLit("10")),
},
nil,
zig.NewBlockStmt(
zig.If(
zig.Binary(">", zig.Id("idx"), zig.Lit("int", "3")),
zig.Binary(">", zig.Id("idx"), zig.IntLit("3")),
zig.Break("", nil),
nil,
),
@ -228,7 +228,7 @@ test "processData" {
// else block
zig.NewBlockStmt(
zig.If(
zig.Binary("==", zig.Id("idx"), zig.Lit("int", "0")),
zig.Binary("==", zig.Id("idx"), zig.IntLit("0")),
zig.Return(zig.FieldAccess(zig.Id("ProcessError"), "InvalidInput")),
nil,
),
@ -241,7 +241,7 @@ test "processData" {
// if (values[0] == null) { unreachable; }
zig.If(
zig.Binary("==",
zig.Index(zig.Id("values"), zig.Lit("int", "0")),
zig.Index(zig.Id("values"), zig.IntLit("0")),
zig.Id("null"),
),
zig.NewBlockStmt(
@ -253,7 +253,7 @@ test "processData" {
zig.Return(zig.Id("result")),
),
[]*zig.ParamDecl{
zig.Param("values", zig.ArrayType(zig.Lit("int", "5"), zig.OptionalType(zig.Id("i32")))),
zig.Param("values", zig.ArrayType(zig.IntLit("5"), zig.OptionalType(zig.Id("i32")))),
},
0,
),
@ -265,11 +265,11 @@ test "processData" {
zig.DeclareVarStmt(true, []string{"data"}, nil,
&zig.InitListExpr{
Values: []zig.Expr{
zig.Lit("int", "5"),
zig.Unary("-", zig.Lit("int", "3")),
zig.Lit("int", "0"),
zig.IntLit("5"),
zig.Unary("-", zig.IntLit("3")),
zig.IntLit("0"),
zig.Id("null"),
zig.Lit("int", "10"),
zig.IntLit("10"),
},
},
),
@ -288,7 +288,7 @@ test "processData" {
zig.Call(
zig.Id("@as"),
zig.Id("i32"),
zig.Lit("int", "8"),
zig.IntLit("8"),
),
zig.Id("result"),
),
@ -328,7 +328,7 @@ pub fn main() !void {
Decl: zig.DeclareGlobalVar("std",
zig.Call(
zig.Id("@import"),
zig.Lit("string", "std"),
zig.StringLit("std"),
),
zig.GlobalVarConst,
),
@ -353,26 +353,24 @@ pub fn main() !void {
),
),
// var i: i32 = 1;
zig.DeclareVarStmt(false, []string{"i"}, zig.Id("i32"), zig.Lit("int", "1")),
zig.DeclareVarStmt(false, []string{"i"}, zig.Id("i32"), zig.IntLit("1")),
// while (i <= 5) : (i += 1) { ... }
&zig.LoopStmt{
Kind: "while",
Prefix: &zig.WhilePrefix{
Cond: zig.Binary("<=", zig.Id("i"), zig.Lit("int", "5")),
Continue: zig.Binary("+=", zig.Id("i"), zig.Lit("int", "1")),
},
Body: zig.NewBlockStmt(
zig.WhileLoop(
zig.Binary("<=", zig.Id("i"), zig.IntLit("5")),
zig.Binary("+=", zig.Id("i"), zig.IntLit("1")),
nil,
zig.NewBlockStmt(
&zig.IfStmt{
Cond: zig.Binary("==",
zig.Binary("%", zig.Id("i"), zig.Lit("int", "2")),
zig.Lit("int", "0"),
zig.Binary("%", zig.Id("i"), zig.IntLit("2")),
zig.IntLit("0"),
),
Then: zig.NewBlockStmt(
zig.NewExprStmt(
zig.Try(
zig.Call(
zig.FieldAccess(zig.Id("stdout"), "writeAll"),
zig.Lit("string", `even: {d}\n`),
zig.StringLit(`even: {d}\n`),
zig.InitList(zig.Id("i")),
),
),
@ -383,7 +381,7 @@ pub fn main() !void {
zig.Try(
zig.Call(
zig.FieldAccess(zig.Id("stdout"), "writeAll"),
zig.Lit("string", `odd: {d}\n`),
zig.StringLit(`odd: {d}\n`),
zig.InitList(zig.Id("i")),
),
),
@ -391,7 +389,8 @@ pub fn main() !void {
),
},
),
},
nil,
),
),
nil,
zig.FnExport,
@ -430,7 +429,7 @@ func TestStructWithFieldsAndMethod(t *testing.T) {
zig.NewExprStmt(
zig.Binary("=",
zig.FieldAccess(zig.Id("self"), "count"),
zig.Lit("int", "0"),
zig.IntLit("0"),
),
),
),