Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- `zb store object delete` is no longer flaky
([#135](https://github.com/256lights/zb/issues/135))
([#135](https://github.com/256lights/zb/issues/135)).
- Lua operator metamethods now receive their arguments in the correct order
when one of the operands is a constant
([#152](https://github.com/256lights/zb/issues/152)).

## [0.1.0][] - 2025-06-15

Expand Down
22 changes: 10 additions & 12 deletions internal/lua/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,12 +692,11 @@ func (l *State) exec(ctx context.Context) (err error) {
if err != nil {
return err
}
result, err := l.callArithmeticMetamethod(
ctx,
prevOperator.TagMethod(),
*ra,
integerValue(luacode.SignedArg(i.ArgB())),
)
var arg1, arg2 value = *ra, integerValue(luacode.SignedArg(i.ArgB()))
if i.K() {
arg1, arg2 = arg2, arg1
}
result, err := l.callArithmeticMetamethod(ctx, prevOperator.TagMethod(), arg1, arg2)
if err != nil {
return err
}
Expand Down Expand Up @@ -725,12 +724,11 @@ func (l *State) exec(ctx context.Context) (err error) {
if err != nil {
return err
}
result, err := l.callArithmeticMetamethod(
ctx,
prevOperator.TagMethod(),
*ra,
importConstant(kb),
)
arg1, arg2 := *ra, importConstant(kb)
if i.K() {
arg1, arg2 = arg2, arg1
}
result, err := l.callArithmeticMetamethod(ctx, prevOperator.TagMethod(), arg1, arg2)
if err != nil {
return err
}
Expand Down
153 changes: 153 additions & 0 deletions internal/lua/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -785,4 +785,157 @@ func TestVM(t *testing.T) {
t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, wantResult)
}
})

t.Run("FlippedAddImmediateMetamethod", func(t *testing.T) {
ctx := context.Background()
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

// Set a metatable for nil which treats the nil like 0 when adding.
state.PushNil()
var isNil [2]bool
NewPureLib(state, map[string]Function{
luacode.TagMethodAdd.String(): func(ctx context.Context, l *State) (int, error) {
for i := range 2 {
isNil[i] = l.IsNil(1 + i)
if isNil[i] {
l.PushInteger(0)
l.Replace(1 + i)
}
}
if err := l.Arithmetic(ctx, luacode.Add); err != nil {
return 0, err
}
return 1, nil
},
})
if err := state.SetMetatable(-2); err != nil {
t.Fatal(err)
}

const source = `return 1 + nil`
if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil {
t.Fatal(err)
}
if err := state.Call(ctx, 0, 1); err != nil {
t.Fatal(err)
}

if isNil[0] || !isNil[1] {
t.Errorf("when calling __add, (x == nil), (y == nil) = %t, %t; want false, true", isNil[0], isNil[1])
}

const wantResult = 1
if got, want := state.Type(-1), TypeNumber; got != want {
t.Fatalf("state.Type(-1) = %v; want %v", got, want)
} else if got, ok := state.ToInteger(-1); got != wantResult || !ok {
t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, wantResult)
}
})

t.Run("FlippedBOrImmediateMetamethod", func(t *testing.T) {
ctx := context.Background()
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

// Set a metatable for nil which treats the nil like 0 when adding.
state.PushNil()
var isNil [2]bool
NewPureLib(state, map[string]Function{
luacode.TagMethodBOr.String(): func(ctx context.Context, l *State) (int, error) {
for i := range 2 {
isNil[i] = l.IsNil(1 + i)
if isNil[i] {
l.PushInteger(0)
l.Replace(1 + i)
}
}
if err := l.Arithmetic(ctx, luacode.BitwiseOr); err != nil {
return 0, err
}
return 1, nil
},
})
if err := state.SetMetatable(-2); err != nil {
t.Fatal(err)
}

const source = `return 1 | nil`
if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil {
t.Fatal(err)
}
if err := state.Call(ctx, 0, 1); err != nil {
t.Fatal(err)
}

if isNil[0] || !isNil[1] {
t.Errorf("when calling __bor, (x == nil), (y == nil) = %t, %t; want false, true", isNil[0], isNil[1])
}

const wantResult = 1
if got, want := state.Type(-1), TypeNumber; got != want {
t.Fatalf("state.Type(-1) = %v; want %v", got, want)
} else if got, ok := state.ToInteger(-1); got != wantResult || !ok {
t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, wantResult)
}
})

t.Run("FlippedAddConstantMetamethod", func(t *testing.T) {
ctx := context.Background()
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

// Set a metatable for nil which treats the nil like 0 when adding.
state.PushNil()
var isNil [2]bool
NewPureLib(state, map[string]Function{
luacode.TagMethodAdd.String(): func(ctx context.Context, l *State) (int, error) {
for i := range 2 {
isNil[i] = l.IsNil(1 + i)
if isNil[i] {
l.PushInteger(0)
l.Replace(1 + i)
}
}
if err := l.Arithmetic(ctx, luacode.Add); err != nil {
return 0, err
}
return 1, nil
},
})
if err := state.SetMetatable(-2); err != nil {
t.Fatal(err)
}

const source = `return 3.14 + nil`
if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil {
t.Fatal(err)
}
if err := state.Call(ctx, 0, 1); err != nil {
t.Fatal(err)
}

if isNil[0] || !isNil[1] {
t.Errorf("when calling __add, (x == nil), (y == nil) = %t, %t; want false, true", isNil[0], isNil[1])
}

const wantResult = 3.14
if got, want := state.Type(-1), TypeNumber; got != want {
t.Fatalf("state.Type(-1) = %v; want %v", got, want)
} else if got, ok := state.ToNumber(-1); got != wantResult || !ok {
t.Errorf("state.ToNumber(-1) = %g, %t; want %g, true", got, ok, wantResult)
}
})
}
33 changes: 12 additions & 21 deletions internal/luacode/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,6 @@ func (p *parser) codeCommutative(fs *funcState, operator binaryOperator, e1, e2
flip := e1.isNumeral()
if flip {
e1, e2 = e2, e1
flip = true
}
if i, isInt := e2.intConstant(); isInt && fitsSignedArg(i) && operator == binaryOperatorAdd {
return p.codeBinaryExpImmediate(fs, OpAddI, e1, e2, flip, line, TagMethodAdd)
Expand All @@ -570,22 +569,24 @@ func (p *parser) codeCommutative(fs *funcState, operator binaryOperator, e1, e2
}

// codeBitwise appends instructions for bitwise operators
// to fs.Code.
//
// Equivalent to `codebitwise` in upstream Lua.
func (p *parser) codeBitwise(fs *funcState, operator binaryOperator, e1, e2 expressionDescriptor, line int) (expressionDescriptor, error) {
// All operations are commutative,
// so if first operand is a numeric constant,
// change order of operands to try to use an immediate or K operator.
flip := e1.kind == expressionKindIntConstant
if flip {
e1, e2 = e2, e1
}
if e2.kind == expressionKindIntConstant {
switch {
case e1.kind == expressionKindIntConstant:
// All operations are commutative,
// so if first operand is a numeric constant,
// change order of operands to try to use an immediate or K operator.
if e1, _, ok := p.toConstantTable(fs, e1); ok {
return p.codeBinaryExpConstant(fs, operator, e2, e1, true, line)
}
case e2.kind == expressionKindIntConstant:
if e2, _, ok := p.toConstantTable(fs, e2); ok {
return p.codeBinaryExpConstant(fs, operator, e1, e2, flip, line)
return p.codeBinaryExpConstant(fs, operator, e1, e2, false, line)
}
}
return p.codeBinaryExpNoConstants(fs, operator, e1, e2, flip, line)
return p.codeBinaryExp(fs, operator, e1, e2, line)
}

// codeArithmetic appends instructions for an arithmetic binary operator
Expand All @@ -598,17 +599,7 @@ func (p *parser) codeArithmetic(fs *funcState, operator binaryOperator, e1, e2 e
return p.codeBinaryExpConstant(fs, operator, e1, e2, flip, line)
}
}
return p.codeBinaryExpNoConstants(fs, operator, e1, e2, flip, line)
}

// codeBinaryExpNoConstants appends the instructions
// for a binary expression without constant operands
// to fs.Code.
//
// Equivalent to `codebinNoK` in upstream Lua.
func (p *parser) codeBinaryExpNoConstants(fs *funcState, operator binaryOperator, e1, e2 expressionDescriptor, flip bool, line int) (expressionDescriptor, error) {
if flip {
// Back to original order.
e1, e2 = e2, e1
}
return p.codeBinaryExp(fs, operator, e1, e2, line)
Expand Down