diff --git a/internal/zig/ast.go b/internal/zig/ast.go index 4edfd0e..d68cfc3 100644 --- a/internal/zig/ast.go +++ b/internal/zig/ast.go @@ -27,14 +27,21 @@ type Decl interface { isDecl() } +// FnFlags is a bitfield for function declaration options. +type FnFlags uint8 + +const ( + FnExport FnFlags = 1 << iota + FnExtern + FnInline + FnNoInline + FnThreadLocal +) + // FnDecl represents a function declaration. type FnDecl struct { - Export bool - Extern bool + Flags FnFlags ExternName string // Optional string for extern - Inline bool - NoInline bool - ThreadLocal bool Name string // May be empty (anonymous) Params []*ParamDecl ByteAlign *Expr @@ -47,13 +54,22 @@ type FnDecl struct { func (*FnDecl) isDecl() {} -// GlobalVarDecl represents a global variable declaration. +// GlobalVarFlags is a bitfield for global variable declaration options. +type GlobalVarFlags uint8 + +const ( + GlobalVarConst GlobalVarFlags = 1 << iota + GlobalVarExport + GlobalVarExtern + GlobalVarThreadLocal +) + +// GlobalVarDecl represents a top-level (global) variable or constant declaration. +// These are only allowed at the container/module scope and use a restricted syntax: +// no destructuring or multi-var declarations, just a single name and optional type/initializer. type GlobalVarDecl struct { - Export bool - Extern bool + Flags GlobalVarFlags ExternName string // Optional string for extern - ThreadLocal bool - Const bool Name string Type TypeExpr // Optional ByteAlign *Expr @@ -95,6 +111,25 @@ type ParamDecl struct { Type TypeExpr // 'anytype' if empty } +// ContainerDecl represents a struct, enum, union, or opaque declaration. +type ContainerDecl struct { + Extern bool + Packed bool + Kind string // "struct", "enum", "union", "opaque" + TagType TypeExpr // Optional (for enum/union) + Fields []*ContainerMember + DocComment DocComment +} + +func (*ContainerDecl) isDecl() {} + +// ErrorSetDecl represents an error set declaration. +type ErrorSetDecl struct { + Names []string +} + +func (*ErrorSetDecl) isDecl() {} + // Block represents a block of statements. type Block struct { Label string // Optional @@ -113,7 +148,8 @@ type ExprStmt struct { func (*ExprStmt) isStmt() {} -// VarDeclStmt represents a variable or const declaration at statement level, supporting destructuring and multi-var declarations. +// VarDeclStmt represents a local variable or constant declaration statement inside a function or block. +// These support destructuring and multi-var declarations, and are only valid at statement/block scope. type VarDeclStmt struct { Const bool Pattern VarPattern // Destructuring or multiple variable names @@ -276,21 +312,6 @@ type AsmInputItem struct { Expr Expr } -// ContainerDecl represents a struct, enum, union, or opaque declaration. -type ContainerDecl struct { - Extern bool - Packed bool - Kind string // "struct", "enum", "union", "opaque" - TagType TypeExpr // Optional (for enum/union) - Fields []*ContainerMember - DocComment DocComment -} - -// ErrorSetDecl represents an error set declaration. -type ErrorSetDecl struct { - Names []string -} - // InitListExpr represents an initializer list. // Exactly one of Fields, Values, or Empty must be set (non-nil/non-empty or true). type InitListExpr struct { @@ -381,6 +402,12 @@ type Expr interface{} // TypeExpr is any type expression. type TypeExpr interface{} +// PrefixTypeExpr represents a type with a string prefix. Examples include optionals and pointers. +type PrefixTypeExpr struct { + Op string + Base TypeExpr +} + // DocComment represents a doc comment (/// or //! lines). // Newlines in the string automatically add more comments in the output. type DocComment string diff --git a/internal/zig/asthelpers.go b/internal/zig/asthelpers.go new file mode 100644 index 0000000..2bea368 --- /dev/null +++ b/internal/zig/asthelpers.go @@ -0,0 +1,136 @@ +package zig + +// Id creates an identifier expression. +func Id(name string) *Identifier { + return &Identifier{Name: name} +} + +// Lit creates a literal expression. +func Lit(kind, value string) *Literal { + return &Literal{Kind: kind, Value: value} +} + +// FieldAccess creates a field access expression. +func FieldAccess(recv Expr, field string) *FieldAccessExpr { + return &FieldAccessExpr{Receiver: recv, Field: field} +} + +// Call creates a function call expression. +func Call(fun Expr, args ...Expr) *CallExpr { + return &CallExpr{Fun: fun, Args: args} +} + +// InitList creates an initializer list expression. +func InitList(exprs ...Expr) *InitListExpr { + if len(exprs) == 0 { + return &InitListExpr{Empty: true} + } + return &InitListExpr{Values: exprs} +} + +// DeclareFn creates a function declaration with flags. +func DeclareFn(name string, retType TypeExpr, body *Block, params []*ParamDecl, flags FnFlags) *FnDecl { + return &FnDecl{ + Name: name, + ReturnType: retType, + Body: body, + Flags: flags, + Params: params, + } +} + +// NewExprStmt creates an expression statement. +func NewExprStmt(expr Expr) *ExprStmt { + return &ExprStmt{Expr: expr} +} + +// DeclareVarStmt creates a variable or const declaration statement. +func DeclareVarStmt(constant bool, names []string, typ TypeExpr, value Expr) *VarDeclStmt { + return &VarDeclStmt{ + Const: constant, + Pattern: VarPattern{Names: names}, + Type: typ, + Value: value, + } +} + +// DeclareGlobalVar creates a global variable declaration with flags. +func DeclareGlobalVar(name string, value Expr, flags GlobalVarFlags) *GlobalVarDecl { + return &GlobalVarDecl{ + Flags: flags, + Name: name, + Value: value, + } +} + +// Return creates a return statement. +func Return(value Expr) *ReturnStmt { + return &ReturnStmt{Value: value} +} + +// If creates an if statement. +func If(cond Expr, then, els Stmt) *IfStmt { + return &IfStmt{Cond: cond, Then: then, Else: els} +} + +// NewBlock creates a block of statements. +func NewBlock(stmts ...Stmt) *Block { + return &Block{Stmts: stmts} +} + +// NewBlock creates a block statement containing a block of statements. +func NewBlockStmt(stmts ...Stmt) *BlockStmt { + return &BlockStmt{ + Block: &Block{Stmts: stmts}, + } +} + +// Try creates a try expression. +func Try(expr Expr) *TryExpr { + return &TryExpr{Expr: expr} +} + +// Binary creates a binary expression. +func Binary(op string, left, right Expr) *BinaryExpr { + return &BinaryExpr{Op: op, Left: left, Right: right} +} + +// Param creates a function parameter declaration. +func Param(name string, typ TypeExpr) *ParamDecl { + return &ParamDecl{Name: name, Type: typ} +} + +// StructDecl creates a struct declaration with the given fields/members. +func StructDecl(fields ...*ContainerMember) *ContainerDecl { + return &ContainerDecl{ + Kind: "struct", + Fields: fields, + } +} + +// Field creates a struct field (optionally with initializer). +func Field(name string, typ TypeExpr, byteAlign *Expr, value Expr) *ContainerMember { + return &ContainerMember{ + Field: &ContainerField{ + Name: name, + Type: typ, + ByteAlign: byteAlign, + Value: value, + }, + } +} + +// Method creates a method (function declaration) as a struct member. +func Method(fn *FnDecl) *ContainerMember { + return &ContainerMember{Decl: fn} +} + +// OptionalType creates an optional type (?T). +func OptionalType(base TypeExpr) *PrefixTypeExpr { + return &PrefixTypeExpr{Op: "?", Base: base} +} + +// PointerType creates a pointer type (*T). +func PointerType(base TypeExpr) *PrefixTypeExpr { + return &PrefixTypeExpr{Op: "*", Base: base} +} diff --git a/internal/zig/zig.go b/internal/zig/zig.go index 88b8f47..ee5052d 100644 --- a/internal/zig/zig.go +++ b/internal/zig/zig.go @@ -3,6 +3,7 @@ package zig import ( "fmt" "io" + "strings" ) type formatter struct { @@ -15,9 +16,9 @@ type formatter struct { // indentStr defines the string used for each indentation level (4 spaces). const indentStr = " " -// Writef writes formatted text to the underlying writer and updates line/col counters. +// writef writes formatted text to the underlying writer and updates line/col counters. // It also handles indentation after newlines when appropriate. -func (f *formatter) Writef(format string, a ...any) { +func (f *formatter) writef(format string, a ...any) { s := fmt.Sprintf(format, a...) for i, r := range s { if r == '\n' { @@ -62,37 +63,52 @@ func Write(w io.Writer, root *Root) (err error) { } } }() - f := &formatter{w: w, line: 1, col: 1, indent: 0} + sb := &strings.Builder{} + f := &formatter{w: sb, line: 1, col: 1, indent: 0} if root.ContainerDocComment != "" { - f.Writef("//! %s\n\n", root.ContainerDocComment) + f.writef("//! %s\n\n", root.ContainerDocComment) } - for _, member := range root.ContainerMembers { - // Only handle Decl for now (fields not needed for hello world) + for i, member := range root.ContainerMembers { if member.Decl != nil { + // Only emit a leading newline before a function/global after the first declaration + if i > 0 { + f.writef("\n") + } writeDecl(f, member.Decl) } } - return nil + out := sb.String() + if len(out) == 0 || out[len(out)-1] != '\n' { + out += "\n" + } + _, err = w.Write([]byte(out)) + return err } // writeDecl emits a top-level declaration. func writeDecl(f *formatter, decl Decl) { switch d := decl.(type) { case *FnDecl: - f.Writef("\nfn %s(", d.Name) + if d.Flags&FnExport != 0 { + f.writef("pub ") + } + f.writef("fn %s(", d.Name) writeParams(f, d.Params) - f.Writef(") ") + f.writef(") ") writeTypeExpr(f, d.ReturnType) writeBlock(f, d.Body) case *GlobalVarDecl: - if d.Const { - f.Writef("const %s = ", d.Name) + if d.Flags&GlobalVarConst != 0 { + f.writef("const %s = ", d.Name) } else { - f.Writef("var %s = ", d.Name) + f.writef("var %s = ", d.Name) } writeExpr(f, d.Value) - f.Writef(";\n") + f.writef(";\n") + case *ContainerDecl: + f.writef("struct ") + writeStructBody(f, d) } } @@ -100,10 +116,10 @@ func writeDecl(f *formatter, decl Decl) { func writeParams(f *formatter, params []*ParamDecl) { for i, param := range params { if i > 0 { - f.Writef(", ") + f.writef(", ") } if param.Name != "" { - f.Writef("%s: ", param.Name) + f.writef("%s: ", param.Name) } writeTypeExpr(f, param.Type) } @@ -113,48 +129,111 @@ func writeParams(f *formatter, params []*ParamDecl) { func writeTypeExpr(f *formatter, typ TypeExpr) { switch t := typ.(type) { case *Identifier: - f.Writef("%s", t.Name) + f.writef("%s", t.Name) + case *PrefixTypeExpr: + f.writef("%s", t.Op) + writeTypeExpr(f, t.Base) case nil: // nothing default: - f.Writef("%v", t) + f.writef("%v", t) } } // writeBlock emits a block, handling indentation for statements and the closing brace. func writeBlock(f *formatter, block *Block) { if block == nil { - f.Writef(";") + f.writef(";") return } - f.Writef(" {\n") - f.indent++ // Increase indentation for block contents. - for i, stmt := range block.Stmts { - f.writeIndent() // Indent each statement. + f.writef(" {\n") + f.indent++ + for _, stmt := range block.Stmts { + f.writeIndent() writeStmt(f, stmt) - if i < len(block.Stmts)-1 { - f.Writef("\n") - } + f.writef("\n") } - f.indent-- // Decrease indentation before closing brace. - f.Writef("\n") - f.writeIndent() // Indent the closing brace. - f.Writef("}\n") + f.indent-- + f.writeIndent() + f.writef("}") } // writeStmt emits a statement. Indentation is handled by the caller (writeBlock). func writeStmt(f *formatter, stmt Stmt) { switch s := stmt.(type) { case *ReturnStmt: - f.Writef("return") + f.writef("return") if s.Value != nil { - f.Writef(" ") + f.writef(" ") writeExpr(f, s.Value) } - f.Writef(";") + f.writef(";") case *ExprStmt: writeExpr(f, s.Expr) - f.Writef(";") + f.writef(";") + case *VarDeclStmt: + if s.Const { + f.writef("const ") + } else { + f.writef("var ") + } + for i, name := range s.Pattern.Names { + if i > 0 { + f.writef(", ") + } + f.writef("%s", name) + } + if s.Type != nil { + f.writef(": ") + writeTypeExpr(f, s.Type) + } + if s.Value != nil { + f.writef(" = ") + writeExpr(f, s.Value) + } + f.writef(";") + case *BlockStmt: + writeBlock(f, s.Block) + case *IfStmt: + f.writef("if (") + writeExpr(f, s.Cond) + f.writef(")") + + // Always write the then branch as a block + if block, ok := s.Then.(*BlockStmt); ok { + writeBlock(f, block.Block) + } else { + f.writef(" ") + writeStmt(f, s.Then) + } + if s.Else != nil { + f.writef(" else") + if block, ok := s.Else.(*BlockStmt); ok { + writeBlock(f, block.Block) + } else { + f.writef(" ") + writeStmt(f, s.Else) + } + } + case *LoopStmt: + if s.Kind == "while" { + f.writef("while (") + if wp, ok := s.Prefix.(*WhilePrefix); ok { + writeExpr(f, wp.Cond) + if wp.Continue != nil { + f.writef(") : (") + writeExpr(f, wp.Continue) + } + } + f.writef(")") + // Always write the body as a block + if block, ok := s.Body.(*BlockStmt); ok { + writeBlock(f, block.Block) + } else { + f.writef(" ") + writeStmt(f, s.Body) + } + } } } @@ -162,34 +241,90 @@ func writeStmt(f *formatter, stmt Stmt) { func writeExpr(f *formatter, expr Expr) { switch e := expr.(type) { case *Identifier: - f.Writef("%s", e.Name) + f.writef("%s", e.Name) case *CallExpr: writeExpr(f, e.Fun) - f.Writef("(") + f.writef("(") for i, arg := range e.Args { if i > 0 { - f.Writef(", ") + f.writef(", ") } writeExpr(f, arg) } - f.Writef(")") + f.writef(")") case *FieldAccessExpr: writeExpr(f, e.Receiver) - f.Writef(".%s", e.Field) + f.writef(".%s", e.Field) case *Literal: switch e.Kind { case "string": - f.Writef("%q", e.Value) + f.writef(`"%v"`, e.Value) default: - f.Writef("%v", e.Value) + f.writef("%v", e.Value) } case *InitListExpr: if e.Empty { - f.Writef(".{}") - } else { - f.Writef(".{") - // TODO - f.Writef("}") + f.writef(".{}") + } else if len(e.Values) > 0 { + f.writef(".{") + for i, v := range e.Values { + if i > 0 { + f.writef(", ") + } + writeExpr(f, v) + } + f.writef("}") } + case *ContainerDecl: + if e.Kind == "struct" { + f.writef("struct ") + writeStructBody(f, e) + } else { + panic("not implemented: " + e.Kind) + } + case *TryExpr: + f.writef("try ") + writeExpr(f, e.Expr) + case *BinaryExpr: + writeExpr(f, e.Left) + f.writef(" %s ", e.Op) + writeExpr(f, e.Right) } } + +// writeStructBody emits the body of a struct/union/enum/opaque declaration. +func writeStructBody(f *formatter, decl *ContainerDecl) { + f.writef("{\n") + f.indent++ + for _, member := range decl.Fields { + if member.Field != nil { + // Field or const + if member.Field.Type == nil && member.Field.Value != nil { + // const field + f.writeIndent() + f.writef("const %s = ", member.Field.Name) + writeExpr(f, member.Field.Value) + f.writef(";\n") + } else { + // regular field + f.writeIndent() + f.writef("%s: ", member.Field.Name) + writeTypeExpr(f, member.Field.Type) + if member.Field.Value != nil { + f.writef(" = ") + writeExpr(f, member.Field.Value) + } + f.writef(",\n") + } + } else if member.Decl != nil { + // Method or nested decl + f.writef("\n") + f.writeIndent() + writeDecl(f, member.Decl) + f.writef("\n") + } + } + f.indent-- + f.writeIndent() + f.writef("}") +} diff --git a/internal/zig/zig_test.go b/internal/zig/zig_test.go index d4fe6c4..a9f2d38 100644 --- a/internal/zig/zig_test.go +++ b/internal/zig/zig_test.go @@ -1,7 +1,6 @@ package zig_test import ( - "cmp" "fmt" "strings" "testing" @@ -9,13 +8,45 @@ import ( "git.frop.prof/luke/go-zig-compiler/internal/zig" ) -func Expect[T cmp.Ordered](expected, actual T) error { +func Expect(expected, actual string) error { if expected != actual { - return fmt.Errorf("\nExpected: %v\nActual: %v", expected, actual) + message := fmt.Sprintf("\nExpected:\n%v\nActual:\n%v", expected, actual) + message += fmt.Sprintf("\n%q", actual) + + // Find the first difference (handle rune/byte mismatch) + minLen := len(expected) + if len(actual) < minLen { + minLen = len(actual) + } + for i := 0; i < minLen; i++ { + if expected[i] != actual[i] { + message += fmt.Sprintf("\nDifference at index %d: expected '%c', got '%c'", i, expected[i], actual[i]) + break + } + } + if len(expected) != len(actual) { + message += fmt.Sprintf("\nLength mismatch: expected %d bytes, got %d bytes", len(expected), len(actual)) + } + + return fmt.Errorf("%s", message) } return nil } +// runZigASTTest is a helper to check if the Zig AST renders as expected. +func runZigASTTest(t *testing.T, expected string, root *zig.Root) { + t.Helper() + sb := new(strings.Builder) + err := zig.Write(sb, root) + if err != nil { + t.FailNow() + } + actual := sb.String() + if err = Expect(expected, actual); err != nil { + t.Error(err) + } +} + func TestHelloWorld(t *testing.T) { expected := `//! Hello, world! @@ -30,52 +61,179 @@ fn main() void { ContainerDocComment: "Hello, world!", ContainerMembers: []*zig.ContainerMember{ { - Decl: &zig.GlobalVarDecl{ - Const: true, - Name: "std", - Value: &zig.CallExpr{ - Fun: &zig.Identifier{Name: "@import"}, - Args: []zig.Expr{ - &zig.Literal{Kind: "string", Value: "std"}, - }, - }, - }, + Decl: zig.DeclareGlobalVar("std", zig.Call( + zig.Id("@import"), + zig.Lit("string", "std"), + ), zig.GlobalVarConst), }, { - Decl: &zig.FnDecl{ - Name: "main", - ReturnType: &zig.Identifier{Name: "void"}, - Body: &zig.Block{ - Stmts: []zig.Stmt{ - &zig.ExprStmt{ - Expr: &zig.CallExpr{ - Fun: &zig.FieldAccessExpr{ - Receiver: &zig.FieldAccessExpr{ - Receiver: &zig.Identifier{Name: "std"}, - Field: "debug", - }, - Field: "print", - }, - Args: []zig.Expr{ - &zig.Literal{Kind: "string", Value: "Hello, world!\n"}, - &zig.InitListExpr{Empty: true}, - }, - }, - }, - }, - }, - }, + Decl: zig.DeclareFn( + "main", + zig.Id("void"), + zig.NewBlock( + zig.NewExprStmt( + zig.Call( + zig.FieldAccess( + zig.FieldAccess(zig.Id("std"), "debug"), + "print", + ), + zig.Lit("string", `Hello, world!\n`), + zig.InitList(), + ), + ), + ), + nil, // params + 0, // flags + ), }, }, } - sb := new(strings.Builder) - err := zig.Write(sb, root) - if err != nil { - t.FailNow() - } - actual := sb.String() - if err = Expect(expected, actual); err != nil { - t.Error(err) - } + runZigASTTest(t, expected, root) +} + +func TestEvenOdd(t *testing.T) { + expected := `//! Abc + +const std = @import("std"); + +pub fn main() !void { + const stdout = std.io.getStdOut().writer(); + var i: i32 = 1; + while (i <= 5) : (i += 1) { + if (i % 2 == 0) { + try stdout.writeAll("even: {d}\n", .{i}); + } else { + try stdout.writeAll("odd: {d}\n", .{i}); + } + } +} +` + + root := &zig.Root{ + ContainerDocComment: "Abc", + ContainerMembers: []*zig.ContainerMember{ + { + Decl: zig.DeclareGlobalVar("std", + zig.Call( + zig.Id("@import"), + zig.Lit("string", "std"), + ), + zig.GlobalVarConst, + ), + }, + { + Decl: zig.DeclareFn( + "main", + zig.Id("!void"), + zig.NewBlock( + // const stdout = std.io.getStdOut().writer(); + zig.DeclareVarStmt(true, []string{"stdout"}, nil, + zig.Call( + zig.FieldAccess( + zig.Call( + zig.FieldAccess( + zig.FieldAccess(zig.Id("std"), "io"), + "getStdOut", + ), + ), + "writer", + ), + ), + ), + // var i: i32 = 1; + zig.DeclareVarStmt(false, []string{"i"}, zig.Id("i32"), zig.Lit("int", "1")), + // while (i <= 5) : (i += 1) { ... } + &zig.LoopStmt{ + Kind: "while", + Prefix: &zig.WhilePrefix{ + Cond: zig.Binary("<=", zig.Id("i"), zig.Lit("int", "5")), + Continue: zig.Binary("+=", zig.Id("i"), zig.Lit("int", "1")), + }, + Body: zig.NewBlockStmt( + &zig.IfStmt{ + Cond: zig.Binary("==", + zig.Binary("%", zig.Id("i"), zig.Lit("int", "2")), + zig.Lit("int", "0"), + ), + Then: zig.NewBlockStmt( + zig.NewExprStmt( + zig.Try( + zig.Call( + zig.FieldAccess(zig.Id("stdout"), "writeAll"), + zig.Lit("string", `even: {d}\n`), + zig.InitList(zig.Id("i")), + ), + ), + ), + ), + Else: zig.NewBlockStmt( + zig.NewExprStmt( + zig.Try( + zig.Call( + zig.FieldAccess(zig.Id("stdout"), "writeAll"), + zig.Lit("string", `odd: {d}\n`), + zig.InitList(zig.Id("i")), + ), + ), + ), + ), + }, + ), + }, + ), + nil, + zig.FnExport, + ), + }, + }, + } + + runZigASTTest(t, expected, root) +} + +func TestStructWithFieldsAndMethod(t *testing.T) { + expected := `const MyStruct = struct { + const Self = @This(); + data: ?*u8, + count: u32, + + pub fn reset(self: Self) void { + self.count = 0; + } +}; +` + + root := &zig.Root{ + ContainerMembers: []*zig.ContainerMember{ + { + Decl: zig.DeclareGlobalVar("MyStruct", + zig.StructDecl( + zig.Field("Self", nil, nil, zig.Call(zig.Id("@This"))), + zig.Field("data", zig.OptionalType(zig.PointerType(zig.Id("u8"))), nil, nil), + zig.Field("count", zig.Id("u32"), nil, nil), + zig.Method(zig.DeclareFn( + "reset", + zig.Id("void"), + zig.NewBlock( + zig.NewExprStmt( + zig.Binary("=", + zig.FieldAccess(zig.Id("self"), "count"), + zig.Lit("int", "0"), + ), + ), + ), + []*zig.ParamDecl{ + zig.Param("self", zig.Id("Self")), + }, + zig.FnExport, + )), + ), + zig.GlobalVarConst, + ), + }, + }, + } + + runZigASTTest(t, expected, root) }