From e9f29a9a76291d436f65063ca01c1d79eefe5f36 Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:05:58 +0100 Subject: [PATCH 01/10] warp message: addMessage (#64) * warp message: addMessage * introduce codec, codec manager; add unit tests; add all necessary utilities in MVP * implement bls signature methods; add test cases * get rid of avalanchego dependency; use cometbft instead * delete unused variable * lint go imports * lint go imports --------- Co-authored-by: Ivan Sukach --- go.mod | 9 +- go.sum | 10 +- utils/bimap/bimap.go | 154 +++++ utils/bimap/bimap_test.go | 366 ++++++++++ utils/codec/codec.go | 28 + utils/codec/linearcodec/codec.go | 116 ++++ utils/codec/manager.go | 110 +++ utils/codec/reflectcodec/struct_fielder.go | 96 +++ utils/codec/reflectcodec/type_codec.go | 744 +++++++++++++++++++++ utils/codec/registry.go | 13 + utils/crypto/bls/public.go | 14 + utils/crypto/bls/secret.go | 62 ++ utils/crypto/bls/secret_test.go | 49 ++ utils/crypto/bls/signature.go | 57 ++ utils/crypto/bls/signature_test.go | 51 ++ utils/set/set.go | 195 ++++++ utils/set/set_benchmark_test.go | 39 ++ utils/set/set_test.go | 228 +++++++ utils/warp/codec_manager.go | 19 + utils/warp/predefined.go | 3 + utils/warp/signature.go | 15 + utils/warp/signer.go | 52 ++ utils/warp/signer_test.go | 90 +++ utils/warp/unsigned_message.go | 72 ++ utils/warp/unsigned_message_test.go | 36 + warp/backend.go | 54 ++ 26 files changed, 2674 insertions(+), 8 deletions(-) create mode 100644 utils/bimap/bimap.go create mode 100644 utils/bimap/bimap_test.go create mode 100644 utils/codec/codec.go create mode 100644 utils/codec/linearcodec/codec.go create mode 100644 utils/codec/manager.go create mode 100644 utils/codec/reflectcodec/struct_fielder.go create mode 100644 utils/codec/reflectcodec/type_codec.go create mode 100644 utils/codec/registry.go create mode 100644 utils/crypto/bls/public.go create mode 100644 utils/crypto/bls/secret.go create mode 100644 utils/crypto/bls/secret_test.go create mode 100644 utils/crypto/bls/signature.go create mode 100644 utils/crypto/bls/signature_test.go create mode 100644 utils/set/set.go create mode 100644 utils/set/set_benchmark_test.go create mode 100644 utils/set/set_test.go create mode 100644 utils/warp/codec_manager.go create mode 100644 utils/warp/predefined.go create mode 100644 utils/warp/signature.go create mode 100644 utils/warp/signer.go create mode 100644 utils/warp/signer_test.go create mode 100644 utils/warp/unsigned_message.go create mode 100644 utils/warp/unsigned_message_test.go create mode 100644 warp/backend.go diff --git a/go.mod b/go.mod index 4df39a6b..5b78e8ca 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/landslidenetwork/slide-sdk -go 1.22.7 +go 1.22.8 require ( cosmossdk.io/log v1.3.1 @@ -14,7 +14,9 @@ require ( github.com/prometheus/client_golang v1.19.0 github.com/prometheus/client_model v0.6.1 github.com/stretchr/testify v1.9.0 + github.com/supranational/blst v0.3.13 go.uber.org/mock v0.4.0 + golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 golang.org/x/sync v0.7.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 google.golang.org/grpc v1.64.1 @@ -51,9 +53,9 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect github.com/bgentry/speakeasy v0.1.1-0.20220910012023-760eaf8b6816 // indirect - github.com/bits-and-blooms/bitset v1.8.0 // indirect + 
github.com/bits-and-blooms/bitset v1.10.0 // indirect github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect - github.com/cenkalti/backoff/v4 v4.1.3 // indirect + github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v1.5.1 // indirect @@ -187,7 +189,6 @@ require ( go.opentelemetry.io/otel/trace v1.24.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.25.0 // indirect - golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/net v0.27.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/sys v0.22.0 // indirect diff --git a/go.sum b/go.sum index f67378b7..5ca38e26 100644 --- a/go.sum +++ b/go.sum @@ -276,8 +276,8 @@ github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bgentry/speakeasy v0.1.1-0.20220910012023-760eaf8b6816 h1:41iFGWnSlI2gVpmOtVTJZNodLdLQLn/KsJqFvXwnd/s= github.com/bgentry/speakeasy v0.1.1-0.20220910012023-760eaf8b6816/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bits-and-blooms/bitset v1.8.0 h1:FD+XqgOZDUxxZ8hzoBFuV9+cGWY9CslN6d5MS5JVb4c= -github.com/bits-and-blooms/bitset v1.8.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bitset v1.10.0 h1:ePXTeiPEazB5+opbv5fr8umg2R/1NlzgDsyepwsSr88= +github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/btcsuite/btcd/btcec/v2 v2.3.2 h1:5n0X6hX0Zk+6omWcihdYvdAlGf2DfasC0GMf7DClJ3U= github.com/btcsuite/btcd/btcec/v2 v2.3.2/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/btcsuite/btcd/btcutil v1.1.3 h1:xfbtw8lwpp0G6NwSHb+UE67ryTFHJAiNuipusjXSohQ= @@ -290,8 +290,8 @@ github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= -github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4= -github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= +github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= +github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= @@ -1011,6 +1011,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/supranational/blst v0.3.13 h1:AYeSxdOMacwu7FBmpfloBz5pbFXDmJL33RuwnKtmTjk= +github.com/supranational/blst v0.3.13/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= 
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tendermint/go-amino v0.16.0 h1:GyhmgQKvqF82e2oZeuMSp9JTN0N09emoSZlb2lyGa2E= diff --git a/utils/bimap/bimap.go b/utils/bimap/bimap.go new file mode 100644 index 00000000..525766a1 --- /dev/null +++ b/utils/bimap/bimap.go @@ -0,0 +1,154 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bimap + +import ( + "bytes" + "encoding/json" + "errors" + + "golang.org/x/exp/maps" + + "github.com/landslidenetwork/slide-sdk/utils" +) + +var ( + _ json.Marshaler = (*BiMap[int, int])(nil) + _ json.Unmarshaler = (*BiMap[int, int])(nil) + + nullBytes = []byte("null") + errNotBijective = errors.New("map not bijective") +) + +type Entry[K, V any] struct { + Key K + Value V +} + +// BiMap is a bi-directional map. +type BiMap[K, V comparable] struct { + keyToValue map[K]V + valueToKey map[V]K +} + +// New creates a new empty bimap. +func New[K, V comparable]() *BiMap[K, V] { + return &BiMap[K, V]{ + keyToValue: make(map[K]V), + valueToKey: make(map[V]K), + } +} + +// Put the key value pair into the map. If either [key] or [val] was previously +// in the map, the previous entries will be removed and returned. +// +// Note: Unlike normal maps, it's possible that Put removes 0, 1, or 2 existing +// entries to ensure that mappings are one-to-one. +func (m *BiMap[K, V]) Put(key K, val V) []Entry[K, V] { + var removed []Entry[K, V] + oldVal, oldValDeleted := m.DeleteKey(key) + if oldValDeleted { + removed = append(removed, Entry[K, V]{ + Key: key, + Value: oldVal, + }) + } + oldKey, oldKeyDeleted := m.DeleteValue(val) + if oldKeyDeleted { + removed = append(removed, Entry[K, V]{ + Key: oldKey, + Value: val, + }) + } + m.keyToValue[key] = val + m.valueToKey[val] = key + return removed +} + +// GetKey that maps to the provided value. +func (m *BiMap[K, V]) GetKey(val V) (K, bool) { + key, ok := m.valueToKey[val] + return key, ok +} + +// GetValue that is mapped to the provided key. +func (m *BiMap[K, V]) GetValue(key K) (V, bool) { + val, ok := m.keyToValue[key] + return val, ok +} + +// HasKey returns true if [key] is in the map. +func (m *BiMap[K, _]) HasKey(key K) bool { + _, ok := m.keyToValue[key] + return ok +} + +// HasValue returns true if [val] is in the map. +func (m *BiMap[_, V]) HasValue(val V) bool { + _, ok := m.valueToKey[val] + return ok +} + +// DeleteKey removes [key] from the map and returns the value it mapped to. +func (m *BiMap[K, V]) DeleteKey(key K) (V, bool) { + val, ok := m.keyToValue[key] + if !ok { + return utils.Zero[V](), false + } + delete(m.keyToValue, key) + delete(m.valueToKey, val) + return val, true +} + +// DeleteValue removes [val] from the map and returns the key that mapped to it. +func (m *BiMap[K, V]) DeleteValue(val V) (K, bool) { + key, ok := m.valueToKey[val] + if !ok { + return utils.Zero[K](), false + } + delete(m.keyToValue, key) + delete(m.valueToKey, val) + return key, true +} + +// Keys returns the keys of the map. The keys will be in an indeterminate order. +func (m *BiMap[K, _]) Keys() []K { + return maps.Keys(m.keyToValue) +} + +// Values returns the values of the map. The values will be in an indeterminate +// order. +func (m *BiMap[_, V]) Values() []V { + return maps.Values(m.keyToValue) +} + +// Len return the number of entries in this map. 
+func (m *BiMap[K, V]) Len() int { + return len(m.keyToValue) +} + +func (m *BiMap[K, V]) MarshalJSON() ([]byte, error) { + return json.Marshal(m.keyToValue) +} + +func (m *BiMap[K, V]) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, nullBytes) { + return nil + } + var keyToValue map[K]V + if err := json.Unmarshal(b, &keyToValue); err != nil { + return err + } + valueToKey := make(map[V]K, len(keyToValue)) + for k, v := range keyToValue { + valueToKey[v] = k + } + if len(keyToValue) != len(valueToKey) { + return errNotBijective + } + + m.keyToValue = keyToValue + m.valueToKey = valueToKey + return nil +} diff --git a/utils/bimap/bimap_test.go b/utils/bimap/bimap_test.go new file mode 100644 index 00000000..1792bec9 --- /dev/null +++ b/utils/bimap/bimap_test.go @@ -0,0 +1,366 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bimap + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBiMapPut(t *testing.T) { + tests := []struct { + name string + state *BiMap[int, int] + key int + value int + expectedRemoved []Entry[int, int] + expectedState *BiMap[int, int] + }{ + { + name: "none removed", + state: New[int, int](), + key: 1, + value: 2, + expectedRemoved: nil, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + }, + { + name: "key removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + key: 1, + value: 3, + expectedRemoved: []Entry[int, int]{ + { + Key: 1, + Value: 2, + }, + }, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 3, + }, + valueToKey: map[int]int{ + 3: 1, + }, + }, + }, + { + name: "value removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + key: 3, + value: 2, + expectedRemoved: []Entry[int, int]{ + { + Key: 1, + Value: 2, + }, + }, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 3: 2, + }, + valueToKey: map[int]int{ + 2: 3, + }, + }, + }, + { + name: "key and value removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + 3: 4, + }, + valueToKey: map[int]int{ + 2: 1, + 4: 3, + }, + }, + key: 1, + value: 4, + expectedRemoved: []Entry[int, int]{ + { + Key: 1, + Value: 2, + }, + { + Key: 3, + Value: 4, + }, + }, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 4, + }, + valueToKey: map[int]int{ + 4: 1, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + removed := test.state.Put(test.key, test.value) + require.Equal(test.expectedRemoved, removed) + require.Equal(test.expectedState, test.state) + }) + } +} + +func TestBiMapHasValueAndGetKey(t *testing.T) { + m := New[int, int]() + require.Empty(t, m.Put(1, 2)) + + tests := []struct { + name string + value int + expectedKey int + expectedExists bool + }{ + { + name: "fetch unknown", + value: 3, + expectedKey: 0, + expectedExists: false, + }, + { + name: "fetch known value", + value: 2, + expectedKey: 1, + expectedExists: true, + }, + { + name: "fetch known key", + value: 1, + expectedKey: 0, + expectedExists: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + exists := m.HasValue(test.value) + require.Equal(test.expectedExists, exists) + + key, exists := m.GetKey(test.value) + 
require.Equal(test.expectedKey, key) + require.Equal(test.expectedExists, exists) + }) + } +} + +func TestBiMapHasKeyAndGetValue(t *testing.T) { + m := New[int, int]() + require.Empty(t, m.Put(1, 2)) + + tests := []struct { + name string + key int + expectedValue int + expectedExists bool + }{ + { + name: "fetch unknown", + key: 3, + expectedValue: 0, + expectedExists: false, + }, + { + name: "fetch known key", + key: 1, + expectedValue: 2, + expectedExists: true, + }, + { + name: "fetch known value", + key: 2, + expectedValue: 0, + expectedExists: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + exists := m.HasKey(test.key) + require.Equal(test.expectedExists, exists) + + value, exists := m.GetValue(test.key) + require.Equal(test.expectedValue, value) + require.Equal(test.expectedExists, exists) + }) + } +} + +func TestBiMapDeleteKey(t *testing.T) { + tests := []struct { + name string + state *BiMap[int, int] + key int + expectedValue int + expectedRemoved bool + expectedState *BiMap[int, int] + }{ + { + name: "none removed", + state: New[int, int](), + key: 1, + expectedValue: 0, + expectedRemoved: false, + expectedState: New[int, int](), + }, + { + name: "key removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + key: 1, + expectedValue: 2, + expectedRemoved: true, + expectedState: New[int, int](), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + value, removed := test.state.DeleteKey(test.key) + require.Equal(test.expectedValue, value) + require.Equal(test.expectedRemoved, removed) + require.Equal(test.expectedState, test.state) + }) + } +} + +func TestBiMapDeleteValue(t *testing.T) { + tests := []struct { + name string + state *BiMap[int, int] + value int + expectedKey int + expectedRemoved bool + expectedState *BiMap[int, int] + }{ + { + name: "none removed", + state: New[int, int](), + value: 1, + expectedKey: 0, + expectedRemoved: false, + expectedState: New[int, int](), + }, + { + name: "key removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + value: 2, + expectedKey: 1, + expectedRemoved: true, + expectedState: New[int, int](), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + key, removed := test.state.DeleteValue(test.value) + require.Equal(test.expectedKey, key) + require.Equal(test.expectedRemoved, removed) + require.Equal(test.expectedState, test.state) + }) + } +} + +func TestBiMapLenAndLists(t *testing.T) { + require := require.New(t) + + m := New[int, int]() + require.Zero(m.Len()) + require.Empty(m.Keys()) + require.Empty(m.Values()) + + m.Put(1, 2) + require.Equal(1, m.Len()) + require.ElementsMatch([]int{1}, m.Keys()) + require.ElementsMatch([]int{2}, m.Values()) + + m.Put(2, 3) + require.Equal(2, m.Len()) + require.ElementsMatch([]int{1, 2}, m.Keys()) + require.ElementsMatch([]int{2, 3}, m.Values()) + + m.Put(1, 3) + require.Equal(1, m.Len()) + require.ElementsMatch([]int{1}, m.Keys()) + require.ElementsMatch([]int{3}, m.Values()) + + m.DeleteKey(1) + require.Zero(m.Len()) + require.Empty(m.Keys()) + require.Empty(m.Values()) +} + +func TestBiMapJSON(t *testing.T) { + require := require.New(t) + + expectedMap := New[int, int]() + expectedMap.Put(1, 2) + expectedMap.Put(2, 3) + + jsonBytes, err := json.Marshal(expectedMap) + 
require.NoError(err) + + expectedJSONBytes := []byte(`{"1":2,"2":3}`) + require.Equal(expectedJSONBytes, jsonBytes) + + var unmarshalledMap BiMap[int, int] + require.NoError(json.Unmarshal(jsonBytes, &unmarshalledMap)) + require.Equal(expectedMap, &unmarshalledMap) +} + +func TestBiMapInvalidJSON(t *testing.T) { + require := require.New(t) + + invalidJSONBytes := []byte(`{"1":2,"2":2}`) + var unmarshalledMap BiMap[int, int] + err := json.Unmarshal(invalidJSONBytes, &unmarshalledMap) + require.ErrorIs(err, errNotBijective) +} diff --git a/utils/codec/codec.go b/utils/codec/codec.go new file mode 100644 index 00000000..1ffc0213 --- /dev/null +++ b/utils/codec/codec.go @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package codec + +import ( + "errors" + + "github.com/landslidenetwork/slide-sdk/utils/wrappers" +) + +var ( + ErrUnsupportedType = errors.New("unsupported type") + ErrMaxSliceLenExceeded = errors.New("max slice length exceeded") + ErrDoesNotImplementInterface = errors.New("does not implement interface") + ErrUnexportedField = errors.New("unexported field") + ErrMarshalZeroLength = errors.New("can't marshal zero length value") + ErrUnmarshalZeroLength = errors.New("can't unmarshal zero length value") +) + +// Codec marshals and unmarshals +type Codec interface { + MarshalInto(interface{}, *wrappers.Packer) error + UnmarshalFrom(*wrappers.Packer, interface{}) error + + // Returns the size, in bytes, of [value] when it's marshaled + Size(value interface{}) (int, error) +} diff --git a/utils/codec/linearcodec/codec.go b/utils/codec/linearcodec/codec.go new file mode 100644 index 00000000..10bc02b1 --- /dev/null +++ b/utils/codec/linearcodec/codec.go @@ -0,0 +1,116 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package linearcodec + +import ( + "fmt" + "reflect" + "sync" + + "github.com/landslidenetwork/slide-sdk/utils/bimap" + "github.com/landslidenetwork/slide-sdk/utils/codec" + "github.com/landslidenetwork/slide-sdk/utils/codec/reflectcodec" + "github.com/landslidenetwork/slide-sdk/utils/wrappers" +) + +var ( + _ Codec = (*linearCodec)(nil) + _ codec.Codec = (*linearCodec)(nil) +) + +// Codec marshals and unmarshals +type Codec interface { + codec.Registry + codec.Codec + SkipRegistrations(int) +} + +// Codec handles marshaling and unmarshaling of structs +type linearCodec struct { + codec.Codec + + lock sync.RWMutex + nextTypeID uint32 + registeredTypes *bimap.BiMap[uint32, reflect.Type] +} + +// New returns a new, concurrency-safe codec; it allow to specify tagNames. +func New(tagNames []string) Codec { + hCodec := &linearCodec{ + nextTypeID: 0, + registeredTypes: bimap.New[uint32, reflect.Type](), + } + hCodec.Codec = reflectcodec.New(hCodec, tagNames) + return hCodec +} + +// NewDefault is a convenience constructor; it returns a new codec with default +// tagNames. 
+func NewDefault() Codec { + return New([]string{reflectcodec.DefaultTagName}) +} + +// Skip some number of type IDs +func (c *linearCodec) SkipRegistrations(num int) { + c.lock.Lock() + c.nextTypeID += uint32(num) + c.lock.Unlock() +} + +// RegisterType is used to register types that may be unmarshaled into an interface +// [val] is a value of the type being registered +func (c *linearCodec) RegisterType(val interface{}) error { + c.lock.Lock() + defer c.lock.Unlock() + + valType := reflect.TypeOf(val) + if c.registeredTypes.HasValue(valType) { + return fmt.Errorf("%w: %v", codec.ErrDuplicateType, valType) + } + + c.registeredTypes.Put(c.nextTypeID, valType) + c.nextTypeID++ + return nil +} + +func (*linearCodec) PrefixSize(reflect.Type) int { + // see PackPrefix implementation + return wrappers.IntLen +} + +func (c *linearCodec) PackPrefix(p *wrappers.Packer, valueType reflect.Type) error { + c.lock.RLock() + defer c.lock.RUnlock() + + typeID, ok := c.registeredTypes.GetKey(valueType) // Get the type ID of the value being marshaled + if !ok { + return fmt.Errorf("can't marshal unregistered type %q", valueType) + } + p.PackInt(typeID) // Pack type ID so we know what to unmarshal this into + return p.Err +} + +func (c *linearCodec) UnpackPrefix(p *wrappers.Packer, valueType reflect.Type) (reflect.Value, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + typeID := p.UnpackInt() // Get the type ID + if p.Err != nil { + return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: %w", p.Err) + } + // Get a type that implements the interface + implementingType, ok := c.registeredTypes.GetValue(typeID) + if !ok { + return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: unknown type ID %d", typeID) + } + // Ensure type actually does implement the interface + if !implementingType.Implements(valueType) { + return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: %s %w %s", + implementingType, + codec.ErrDoesNotImplementInterface, + valueType, + ) + } + return reflect.New(implementingType).Elem(), nil // instance of the proper type +} diff --git a/utils/codec/manager.go b/utils/codec/manager.go new file mode 100644 index 00000000..25273700 --- /dev/null +++ b/utils/codec/manager.go @@ -0,0 +1,110 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package codec + +import ( + "errors" + "fmt" + + "github.com/landslidenetwork/slide-sdk/utils/wrappers" +) + +const ( + // initial capacity of byte slice that values are marshaled into. + // Larger value --> need less memory allocations but possibly have allocated but unused memory + // Smaller value --> need more memory allocations but more efficient use of allocated memory + initialSliceCap = 128 +) + +var ( + ErrMarshalNil = errors.New("can't marshal nil pointer or interface") + ErrUnmarshalNil = errors.New("can't unmarshal nil") + ErrUnmarshalTooBig = errors.New("byte array exceeds maximum length") + ErrExtraSpace = errors.New("trailing buffer space") +) + +var _ Manager = (*manager)(nil) + +// Manager describes the functionality for managing codec versions. +type Manager interface { + // Size returns the size, in bytes, of [value] when it's marshaled + // using the codec with the given version. + // RegisterCodec must have been called with that version. + // If [value] is nil, returns [ErrMarshalNil] + Size(value interface{}) (int, error) + + // Marshal the given value using the codec with the given version. 
+ // RegisterCodec must have been called with that version. + Marshal(source interface{}) (destination []byte, err error) + + // Unmarshal the given bytes into the given destination. [destination] must + // be a pointer or an interface. Returns the version of the codec that + // produces the given bytes. + Unmarshal(source []byte, destination interface{}) error +} + +// NewManager returns a new codec manager. +func NewManager(maxSize int, codec Codec) Manager { + return &manager{ + maxSize: maxSize, + codec: codec, + } +} + +type manager struct { + maxSize int + codec Codec +} + +func (m *manager) Size(value interface{}) (int, error) { + if value == nil { + return 0, ErrMarshalNil // can't marshal nil + } + + res, err := m.codec.Size(value) + + // Add [wrappers.ShortLen] for the codec version + return wrappers.ShortLen + res, err +} + +// To marshal an interface, [value] must be a pointer to the interface. +func (m *manager) Marshal(value interface{}) ([]byte, error) { + if value == nil { + return nil, ErrMarshalNil // can't marshal nil + } + + p := wrappers.Packer{ + MaxSize: m.maxSize, + Bytes: make([]byte, 0, initialSliceCap), + } + return p.Bytes, m.codec.MarshalInto(value, &p) +} + +// Unmarshal unmarshals [bytes] into [dest], where [dest] must be a pointer or +// interface. +func (m *manager) Unmarshal(bytes []byte, dest interface{}) error { + if dest == nil { + return ErrUnmarshalNil + } + + if byteLen := len(bytes); byteLen > m.maxSize { + return fmt.Errorf("%w: %d > %d", ErrUnmarshalTooBig, byteLen, m.maxSize) + } + + p := wrappers.Packer{ + Bytes: bytes, + } + if err := m.codec.UnmarshalFrom(&p, dest); err != nil { + return err + } + if p.Offset != len(bytes) { + return fmt.Errorf("%w: read %d provided %d", + ErrExtraSpace, + p.Offset, + len(bytes), + ) + } + + return nil +} diff --git a/utils/codec/reflectcodec/struct_fielder.go b/utils/codec/reflectcodec/struct_fielder.go new file mode 100644 index 00000000..9e90f493 --- /dev/null +++ b/utils/codec/reflectcodec/struct_fielder.go @@ -0,0 +1,96 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package reflectcodec + +import ( + "fmt" + "reflect" + "sync" + + "github.com/landslidenetwork/slide-sdk/utils/codec" +) + +// TagValue is the value the tag must have to be serialized. +const TagValue = "true" + +var _ StructFielder = (*structFielder)(nil) + +// StructFielder handles discovery of serializable fields in a struct. +type StructFielder interface { + // Returns the fields that have been marked as serializable in [t], which is + // a struct type. + // Returns an error if a field has tag "[tagName]: [TagValue]" but the field + // is un-exported. + // GetSerializedField(Foo) --> [1,5,8] means Foo.Field(1), Foo.Field(5), + // Foo.Field(8) are to be serialized/deserialized. + GetSerializedFields(t reflect.Type) ([]int, error) +} + +func NewStructFielder(tagNames []string) StructFielder { + return &structFielder{ + tags: tagNames, + serializedFieldIndices: make(map[reflect.Type][]int), + } +} + +type structFielder struct { + lock sync.RWMutex + + // multiple tags per field can be specified. A field is serialized/deserialized + // if it has at least one of the specified tags. + tags []string + + // Key: a struct type + // Value: Slice where each element is index in the struct type of a field + // that is serialized/deserialized e.g. Foo --> [1,5,8] means Foo.Field(1), + // etc. are to be serialized/deserialized. 
We assume this cache is pretty + // small (a few hundred keys at most) and doesn't take up much memory. + serializedFieldIndices map[reflect.Type][]int +} + +func (s *structFielder) GetSerializedFields(t reflect.Type) ([]int, error) { + if serializedFields, ok := s.getCachedSerializedFields(t); ok { // use pre-computed result + return serializedFields, nil + } + + s.lock.Lock() + defer s.lock.Unlock() + + numFields := t.NumField() + serializedFields := make([]int, 0, numFields) + for i := 0; i < numFields; i++ { // Go through all fields of this struct + field := t.Field(i) + + // Multiple tags per fields can be specified. + // Serialize/Deserialize field if it has + // any tag with the right value + var captureField bool + for _, tag := range s.tags { + if field.Tag.Get(tag) == TagValue { + captureField = true + break + } + } + if !captureField { + continue + } + if !field.IsExported() { // Can only marshal exported fields + return nil, fmt.Errorf("can not marshal %w: %s", + codec.ErrUnexportedField, + field.Name, + ) + } + serializedFields = append(serializedFields, i) + } + s.serializedFieldIndices[t] = serializedFields // cache result + return serializedFields, nil +} + +func (s *structFielder) getCachedSerializedFields(t reflect.Type) ([]int, bool) { + s.lock.RLock() + defer s.lock.RUnlock() + + cachedFields, ok := s.serializedFieldIndices[t] + return cachedFields, ok +} diff --git a/utils/codec/reflectcodec/type_codec.go b/utils/codec/reflectcodec/type_codec.go new file mode 100644 index 00000000..d8515c5e --- /dev/null +++ b/utils/codec/reflectcodec/type_codec.go @@ -0,0 +1,744 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package reflectcodec + +import ( + "bytes" + "errors" + "fmt" + "math" + "reflect" + "slices" + + "github.com/landslidenetwork/slide-sdk/utils/codec" + "github.com/landslidenetwork/slide-sdk/utils/set" + "github.com/landslidenetwork/slide-sdk/utils/wrappers" +) + +const ( + // DefaultTagName that enables serialization. + DefaultTagName = "serialize" + initialSliceLen = 16 +) + +var ( + _ codec.Codec = (*genericCodec)(nil) + + errNeedPointer = errors.New("argument to unmarshal must be a pointer") + errRecursiveInterfaceTypes = errors.New("recursive interface types") +) + +type TypeCodec interface { + // UnpackPrefix unpacks the prefix of an interface from the given packer. + // The prefix specifies the concrete type that the interface should be + // deserialized into. This function returns a new instance of that concrete + // type. The concrete type must implement the given type. + UnpackPrefix(*wrappers.Packer, reflect.Type) (reflect.Value, error) + + // PackPrefix packs the prefix for the given type into the given packer. + // This identifies the bytes that follow, which are the byte representation + // of an interface, as having the given concrete type. + // When deserializing the bytes, the prefix specifies which concrete type + // to deserialize into. + PackPrefix(*wrappers.Packer, reflect.Type) error + + // PrefixSize returns prefix length for the given type into the given + // packer. + PrefixSize(reflect.Type) int +} + +// genericCodec handles marshaling and unmarshaling of structs with a generic +// implementation for interface encoding. +// +// A few notes: +// +// 1. We use "marshal" and "serialize" interchangeably, and "unmarshal" and +// "deserialize" interchangeably +// 2. To include a field of a struct in the serialized form, add the tag +// `{tagName}:"true"` to it. 
`{tagName}` defaults to `serialize`. +// 3. These typed members of a struct may be serialized: +// bool, string, uint[8,16,32,64], int[8,16,32,64], +// structs, slices, arrays, maps, interface. +// structs, slices, maps and arrays can only be serialized if their constituent +// values can be. +// 4. To marshal an interface, you must pass a pointer to the value +// 5. To unmarshal an interface, you must call +// codec.RegisterType([instance of the type that fulfills the interface]). +// 6. Serialized fields must be exported +// 7. nil slices are marshaled as empty slices +type genericCodec struct { + typer TypeCodec + fielder StructFielder +} + +// New returns a new, concurrency-safe codec +func New(typer TypeCodec, tagNames []string) codec.Codec { + return &genericCodec{ + typer: typer, + fielder: NewStructFielder(tagNames), + } +} + +func (c *genericCodec) Size(value interface{}) (int, error) { + if value == nil { + return 0, codec.ErrMarshalNil + } + + size, _, err := c.size(reflect.ValueOf(value), nil /*=typeStack*/) + return size, err +} + +// size returns the size of the value along with whether the value is constant +// sized. +func (c *genericCodec) size( + value reflect.Value, + typeStack set.Set[reflect.Type], +) (int, bool, error) { + switch valueKind := value.Kind(); valueKind { + case reflect.Uint8: + return wrappers.ByteLen, true, nil + case reflect.Int8: + return wrappers.ByteLen, true, nil + case reflect.Uint16: + return wrappers.ShortLen, true, nil + case reflect.Int16: + return wrappers.ShortLen, true, nil + case reflect.Uint32: + return wrappers.IntLen, true, nil + case reflect.Int32: + return wrappers.IntLen, true, nil + case reflect.Uint64: + return wrappers.LongLen, true, nil + case reflect.Int64: + return wrappers.LongLen, true, nil + case reflect.Bool: + return wrappers.BoolLen, true, nil + case reflect.String: + return wrappers.StringLen(value.String()), false, nil + case reflect.Ptr: + if value.IsNil() { + return 0, false, codec.ErrMarshalNil + } + + return c.size(value.Elem(), typeStack) + + case reflect.Interface: + if value.IsNil() { + return 0, false, codec.ErrMarshalNil + } + + underlyingValue := value.Interface() + underlyingType := reflect.TypeOf(underlyingValue) + if typeStack.Contains(underlyingType) { + return 0, false, fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, underlyingType) + } + typeStack.Add(underlyingType) + + prefixSize := c.typer.PrefixSize(underlyingType) + valueSize, _, err := c.size(value.Elem(), typeStack) + + typeStack.Remove(underlyingType) + return prefixSize + valueSize, false, err + + case reflect.Slice: + numElts := value.Len() + if numElts == 0 { + return wrappers.IntLen, false, nil + } + + size, constSize, err := c.size(value.Index(0), typeStack) + if err != nil { + return 0, false, err + } + + if size == 0 { + return 0, false, fmt.Errorf("can't marshal slice of zero length values: %w", codec.ErrMarshalZeroLength) + } + + // For fixed-size types we manually calculate lengths rather than + // processing each element separately to improve performance. 
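+ // The result is the 4-byte length prefix plus numElts copies of the per-element size.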
+ if constSize { + return wrappers.IntLen + numElts*size, false, nil + } + + for i := 1; i < numElts; i++ { + innerSize, _, err := c.size(value.Index(i), typeStack) + if err != nil { + return 0, false, err + } + size += innerSize + } + return wrappers.IntLen + size, false, nil + + case reflect.Array: + numElts := value.Len() + if numElts == 0 { + return 0, true, nil + } + + size, constSize, err := c.size(value.Index(0), typeStack) + if err != nil { + return 0, false, err + } + + // For fixed-size types we manually calculate lengths rather than + // processing each element separately to improve performance. + if constSize { + return numElts * size, true, nil + } + + for i := 1; i < numElts; i++ { + innerSize, _, err := c.size(value.Index(i), typeStack) + if err != nil { + return 0, false, err + } + size += innerSize + } + return size, false, nil + + case reflect.Struct: + serializedFields, err := c.fielder.GetSerializedFields(value.Type()) + if err != nil { + return 0, false, err + } + + var ( + size int + constSize = true + ) + for _, fieldIndex := range serializedFields { + innerSize, innerConstSize, err := c.size(value.Field(fieldIndex), typeStack) + if err != nil { + return 0, false, err + } + size += innerSize + constSize = constSize && innerConstSize + } + return size, constSize, nil + + case reflect.Map: + iter := value.MapRange() + if !iter.Next() { + return wrappers.IntLen, false, nil + } + + keySize, keyConstSize, err := c.size(iter.Key(), typeStack) + if err != nil { + return 0, false, err + } + valueSize, valueConstSize, err := c.size(iter.Value(), typeStack) + if err != nil { + return 0, false, err + } + + if keySize == 0 && valueSize == 0 { + return 0, false, fmt.Errorf("can't marshal map with zero length entries: %w", codec.ErrMarshalZeroLength) + } + + switch { + case keyConstSize && valueConstSize: + numElts := value.Len() + return wrappers.IntLen + numElts*(keySize+valueSize), false, nil + case keyConstSize: + var ( + numElts = 1 + totalValueSize = valueSize + ) + for iter.Next() { + valueSize, _, err := c.size(iter.Value(), typeStack) + if err != nil { + return 0, false, err + } + totalValueSize += valueSize + numElts++ + } + return wrappers.IntLen + numElts*keySize + totalValueSize, false, nil + case valueConstSize: + var ( + numElts = 1 + totalKeySize = keySize + ) + for iter.Next() { + keySize, _, err := c.size(iter.Key(), typeStack) + if err != nil { + return 0, false, err + } + totalKeySize += keySize + numElts++ + } + return wrappers.IntLen + totalKeySize + numElts*valueSize, false, nil + default: + totalSize := wrappers.IntLen + keySize + valueSize + for iter.Next() { + keySize, _, err := c.size(iter.Key(), typeStack) + if err != nil { + return 0, false, err + } + valueSize, _, err := c.size(iter.Value(), typeStack) + if err != nil { + return 0, false, err + } + totalSize += keySize + valueSize + } + return totalSize, false, nil + } + + default: + return 0, false, fmt.Errorf("can't evaluate marshal length of unknown kind %s", valueKind) + } +} + +// To marshal an interface, [value] must be a pointer to the interface +func (c *genericCodec) MarshalInto(value interface{}, p *wrappers.Packer) error { + if value == nil { + return codec.ErrMarshalNil + } + + return c.marshal(reflect.ValueOf(value), p, nil /*=typeStack*/) +} + +// marshal writes the byte representation of [value] to [p] +// +// c.lock should be held for the duration of this function +func (c *genericCodec) marshal( + value reflect.Value, + p *wrappers.Packer, + typeStack set.Set[reflect.Type], +) error { 
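+ // Dispatch on the reflected kind: fixed-width integers are packed big-endian,
+ // strings carry a 16-bit length prefix, slices and maps a 32-bit element count,
+ // and interface values are preceded by the type ID registered with the TypeCodec.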
+ switch valueKind := value.Kind(); valueKind { + case reflect.Uint8: + p.PackByte(uint8(value.Uint())) + return p.Err + case reflect.Int8: + p.PackByte(uint8(value.Int())) + return p.Err + case reflect.Uint16: + p.PackShort(uint16(value.Uint())) + return p.Err + case reflect.Int16: + p.PackShort(uint16(value.Int())) + return p.Err + case reflect.Uint32: + p.PackInt(uint32(value.Uint())) + return p.Err + case reflect.Int32: + p.PackInt(uint32(value.Int())) + return p.Err + case reflect.Uint64: + p.PackLong(value.Uint()) + return p.Err + case reflect.Int64: + p.PackLong(uint64(value.Int())) + return p.Err + case reflect.String: + p.PackStr(value.String()) + return p.Err + case reflect.Bool: + p.PackBool(value.Bool()) + return p.Err + case reflect.Ptr: + if value.IsNil() { + return codec.ErrMarshalNil + } + + return c.marshal(value.Elem(), p, typeStack) + case reflect.Interface: + if value.IsNil() { + return codec.ErrMarshalNil + } + + underlyingValue := value.Interface() + underlyingType := reflect.TypeOf(underlyingValue) + if typeStack.Contains(underlyingType) { + return fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, underlyingType) + } + typeStack.Add(underlyingType) + if err := c.typer.PackPrefix(p, underlyingType); err != nil { + return err + } + if err := c.marshal(value.Elem(), p, typeStack); err != nil { + return err + } + typeStack.Remove(underlyingType) + return p.Err + case reflect.Slice: + numElts := value.Len() // # elements in the slice/array. 0 if this slice is nil. + if numElts > math.MaxInt32 { + return fmt.Errorf("%w; slice length, %d, exceeds maximum length, %d", + codec.ErrMaxSliceLenExceeded, + numElts, + math.MaxInt32, + ) + } + p.PackInt(uint32(numElts)) // pack # elements + if p.Err != nil { + return p.Err + } + if numElts == 0 { + // Returning here prevents execution of the (expensive) reflect + // calls below which check if the slice is []byte and, if it is, + // the call of value.Bytes() + return nil + } + // If this is a slice of bytes, manually pack the bytes rather + // than calling marshal on each byte. This improves performance. 
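+ // value.Bytes() would panic on anything but a byte slice, hence the element-kind check.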
+ if elemKind := value.Type().Elem().Kind(); elemKind == reflect.Uint8 { + p.PackFixedBytes(value.Bytes()) + return p.Err + } + for i := 0; i < numElts; i++ { // Process each element in the slice + startOffset := p.Offset + if err := c.marshal(value.Index(i), p, typeStack); err != nil { + return err + } + if startOffset == p.Offset { + return fmt.Errorf("couldn't marshal slice of zero length values: %w", codec.ErrMarshalZeroLength) + } + } + return nil + case reflect.Array: + if elemKind := value.Type().Kind(); elemKind == reflect.Uint8 { + sliceVal := value.Convert(reflect.TypeOf([]byte{})) + p.PackFixedBytes(sliceVal.Bytes()) + return p.Err + } + numElts := value.Len() + for i := 0; i < numElts; i++ { // Process each element in the array + if err := c.marshal(value.Index(i), p, typeStack); err != nil { + return err + } + } + return nil + case reflect.Struct: + serializedFields, err := c.fielder.GetSerializedFields(value.Type()) + if err != nil { + return err + } + for _, fieldIndex := range serializedFields { // Go through all fields of this struct that are serialized + if err := c.marshal(value.Field(fieldIndex), p, typeStack); err != nil { // Serialize the field and write to byte array + return err + } + } + return nil + case reflect.Map: + keys := value.MapKeys() + numElts := len(keys) + if numElts > math.MaxInt32 { + return fmt.Errorf("%w; slice length, %d, exceeds maximum length, %d", + codec.ErrMaxSliceLenExceeded, + numElts, + math.MaxInt32, + ) + } + p.PackInt(uint32(numElts)) // pack # elements + if p.Err != nil { + return p.Err + } + + // pack key-value pairs sorted by increasing key + type keyTuple struct { + key reflect.Value + startIndex int + endIndex int + } + + sortedKeys := make([]keyTuple, len(keys)) + startOffset := p.Offset + endOffset := p.Offset + for i, key := range keys { + if err := c.marshal(key, p, typeStack); err != nil { + return err + } + if p.Err != nil { + return fmt.Errorf("couldn't marshal map key %+v: %w ", key, p.Err) + } + sortedKeys[i] = keyTuple{ + key: key, + startIndex: endOffset, + endIndex: p.Offset, + } + endOffset = p.Offset + } + + slices.SortFunc(sortedKeys, func(a, b keyTuple) int { + aBytes := p.Bytes[a.startIndex:a.endIndex] + bBytes := p.Bytes[b.startIndex:b.endIndex] + return bytes.Compare(aBytes, bBytes) + }) + + allKeyBytes := slices.Clone(p.Bytes[startOffset:p.Offset]) + p.Offset = startOffset + for _, key := range sortedKeys { + keyStartOffset := p.Offset + + // pack key + startIndex := key.startIndex - startOffset + endIndex := key.endIndex - startOffset + keyBytes := allKeyBytes[startIndex:endIndex] + p.PackFixedBytes(keyBytes) + if p.Err != nil { + return p.Err + } + + // serialize and pack value + if err := c.marshal(value.MapIndex(key.key), p, typeStack); err != nil { + return err + } + if keyStartOffset == p.Offset { + return fmt.Errorf("couldn't marshal map with zero length entries: %w", codec.ErrMarshalZeroLength) + } + } + + return nil + default: + return fmt.Errorf("%w: %s", codec.ErrUnsupportedType, valueKind) + } +} + +// UnmarshalFrom unmarshals [p.Bytes] into [dest], where [dest] must be a pointer or +// interface +func (c *genericCodec) UnmarshalFrom(p *wrappers.Packer, dest interface{}) error { + if dest == nil { + return codec.ErrUnmarshalNil + } + + destPtr := reflect.ValueOf(dest) + if destPtr.Kind() != reflect.Ptr { + return errNeedPointer + } + return c.unmarshal(p, destPtr.Elem(), nil /*=typeStack*/) +} + +// Unmarshal from p.Bytes into [value]. [value] must be addressable. 
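+ // Trailing bytes are not treated as an error here; the codec manager compares p.Offset
+ // against the input length afterwards.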
+// +// c.lock should be held for the duration of this function +func (c *genericCodec) unmarshal( + p *wrappers.Packer, + value reflect.Value, + typeStack set.Set[reflect.Type], +) error { + switch value.Kind() { + case reflect.Uint8: + value.SetUint(uint64(p.UnpackByte())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint8: %w", p.Err) + } + return nil + case reflect.Int8: + value.SetInt(int64(p.UnpackByte())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int8: %w", p.Err) + } + return nil + case reflect.Uint16: + value.SetUint(uint64(p.UnpackShort())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint16: %w", p.Err) + } + return nil + case reflect.Int16: + value.SetInt(int64(p.UnpackShort())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int16: %w", p.Err) + } + return nil + case reflect.Uint32: + value.SetUint(uint64(p.UnpackInt())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint32: %w", p.Err) + } + return nil + case reflect.Int32: + value.SetInt(int64(p.UnpackInt())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int32: %w", p.Err) + } + return nil + case reflect.Uint64: + value.SetUint(p.UnpackLong()) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint64: %w", p.Err) + } + return nil + case reflect.Int64: + value.SetInt(int64(p.UnpackLong())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int64: %w", p.Err) + } + return nil + case reflect.Bool: + value.SetBool(p.UnpackBool()) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal bool: %w", p.Err) + } + return nil + case reflect.Slice: + numElts32 := p.UnpackInt() + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal slice: %w", p.Err) + } + if numElts32 > math.MaxInt32 { + return fmt.Errorf("%w; array length, %d, exceeds maximum length, %d", + codec.ErrMaxSliceLenExceeded, + numElts32, + math.MaxInt32, + ) + } + numElts := int(numElts32) + + sliceType := value.Type() + innerType := sliceType.Elem() + + // If this is a slice of bytes, manually unpack the bytes rather + // than calling unmarshal on each byte. This improves performance. + if elemKind := innerType.Kind(); elemKind == reflect.Uint8 { + value.SetBytes(p.UnpackFixedBytes(numElts)) + return p.Err + } + // Unmarshal each element and append it into the slice. 
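+ // Each iteration appends a zero value and then unmarshals directly into the new element,
+ // so the slice grows in place.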
+ value.Set(reflect.MakeSlice(sliceType, 0, initialSliceLen)) + zeroValue := reflect.Zero(innerType) + for i := 0; i < numElts; i++ { + value.Set(reflect.Append(value, zeroValue)) + + startOffset := p.Offset + if err := c.unmarshal(p, value.Index(i), typeStack); err != nil { + return err + } + if startOffset == p.Offset { + return fmt.Errorf("couldn't unmarshal slice of zero length values: %w", codec.ErrUnmarshalZeroLength) + } + } + return nil + case reflect.Array: + numElts := value.Len() + if elemKind := value.Type().Elem().Kind(); elemKind == reflect.Uint8 { + unpackedBytes := p.UnpackFixedBytes(numElts) + if p.Errored() { + return p.Err + } + // Get a slice to the underlying array value + underlyingSlice := value.Slice(0, numElts).Interface().([]byte) + copy(underlyingSlice, unpackedBytes) + return nil + } + for i := 0; i < numElts; i++ { + if err := c.unmarshal(p, value.Index(i), typeStack); err != nil { + return err + } + } + return nil + case reflect.String: + value.SetString(p.UnpackStr()) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal string: %w", p.Err) + } + return nil + case reflect.Interface: + intfImplementor, err := c.typer.UnpackPrefix(p, value.Type()) + if err != nil { + return err + } + intfImplementorType := intfImplementor.Type() + if typeStack.Contains(intfImplementorType) { + return fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, intfImplementorType) + } + typeStack.Add(intfImplementorType) + + // Unmarshal into the struct + if err := c.unmarshal(p, intfImplementor, typeStack); err != nil { + return err + } + + typeStack.Remove(intfImplementorType) + value.Set(intfImplementor) + return nil + case reflect.Struct: + // Get indices of fields that will be unmarshaled into + serializedFieldIndices, err := c.fielder.GetSerializedFields(value.Type()) + if err != nil { + return fmt.Errorf("couldn't unmarshal struct: %w", err) + } + // Go through the fields and unmarshal into them + for _, fieldIndex := range serializedFieldIndices { + if err := c.unmarshal(p, value.Field(fieldIndex), typeStack); err != nil { + return err + } + } + return nil + case reflect.Ptr: + // Get the type this pointer points to + t := value.Type().Elem() + // Create a new pointer to a new value of the underlying type + v := reflect.New(t) + // Fill the value + if err := c.unmarshal(p, v.Elem(), typeStack); err != nil { + return err + } + // Assign to the top-level struct's member + value.Set(v) + return nil + case reflect.Map: + numElts32 := p.UnpackInt() + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal map: %w", p.Err) + } + if numElts32 > math.MaxInt32 { + return fmt.Errorf("%w; map length, %d, exceeds maximum length, %d", + codec.ErrMaxSliceLenExceeded, + numElts32, + math.MaxInt32, + ) + } + + var ( + numElts = int(numElts32) + mapType = value.Type() + mapKeyType = mapType.Key() + mapValueType = mapType.Elem() + prevKey []byte + ) + + // Set [value] to be a new map of the appropriate type. + value.Set(reflect.MakeMap(mapType)) + + for i := 0; i < numElts; i++ { + mapKey := reflect.New(mapKeyType).Elem() + + keyStartOffset := p.Offset + + if err := c.unmarshal(p, mapKey, typeStack); err != nil { + return err + } + + // Get the key's byte representation and check that the new key is + // actually bigger (according to bytes.Compare) than the previous + // key. + // + // We do this to enforce that key-value pairs are sorted by + // increasing key. 
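+ // Rejecting unsorted keys keeps the map encoding canonical, so a given map
+ // always round-trips to identical bytes.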
+ keyBytes := p.Bytes[keyStartOffset:p.Offset] + if i != 0 && bytes.Compare(keyBytes, prevKey) <= 0 { + return fmt.Errorf("keys aren't sorted: (%s, %s)", prevKey, mapKey) + } + prevKey = keyBytes + + // Get the value + mapValue := reflect.New(mapValueType).Elem() + if err := c.unmarshal(p, mapValue, typeStack); err != nil { + return err + } + if keyStartOffset == p.Offset { + return fmt.Errorf("couldn't unmarshal map with zero length entries: %w", codec.ErrUnmarshalZeroLength) + } + + // Assign the key-value pair in the map + value.SetMapIndex(mapKey, mapValue) + } + + return nil + default: + return fmt.Errorf("can't unmarshal unknown type %s", value.Kind().String()) + } +} diff --git a/utils/codec/registry.go b/utils/codec/registry.go new file mode 100644 index 00000000..de87e1a9 --- /dev/null +++ b/utils/codec/registry.go @@ -0,0 +1,13 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package codec + +import "errors" + +var ErrDuplicateType = errors.New("duplicate type registration") + +// Registry registers new types that can be marshaled into +type Registry interface { + RegisterType(interface{}) error +} diff --git a/utils/crypto/bls/public.go b/utils/crypto/bls/public.go new file mode 100644 index 00000000..acc3649c --- /dev/null +++ b/utils/crypto/bls/public.go @@ -0,0 +1,14 @@ +package bls + +import blst "github.com/supranational/blst/bindings/go" + +var ciphersuiteSignature = []byte("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_") + +type PublicKey = blst.P1Affine + +// Verify the [sig] of [msg] against the [pk]. +// The [sig] and [pk] may have been an aggregation of other signatures and keys. +// Invariant: [pk] and [sig] have both been validated. +func Verify(pk *PublicKey, sig *Signature, msg []byte) bool { + return sig.Verify(false, pk, false, msg, ciphersuiteSignature) +} diff --git a/utils/crypto/bls/secret.go b/utils/crypto/bls/secret.go new file mode 100644 index 00000000..ead6ceaf --- /dev/null +++ b/utils/crypto/bls/secret.go @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bls + +import ( + "crypto/rand" + "errors" + "runtime" + + blst "github.com/supranational/blst/bindings/go" +) + +const SecretKeyLen = blst.BLST_SCALAR_BYTES + +var ( + errFailedSecretKeyDeserialize = errors.New("couldn't deserialize secret key") +) + +type SecretKey = blst.SecretKey + +// NewSecretKey generates a new secret key from the local source of +// cryptographically secure randomness. +func NewSecretKey() (*SecretKey, error) { + var ikm [32]byte + _, err := rand.Read(ikm[:]) + if err != nil { + return nil, err + } + sk := blst.KeyGen(ikm[:]) + ikm = [32]byte{} // zero out the ikm + return sk, nil +} + +// SecretKeyToBytes returns the big-endian format of the secret key. +func SecretKeyToBytes(sk *SecretKey) []byte { + return sk.Serialize() +} + +// SecretKeyFromBytes parses the big-endian format of the secret key into a +// secret key. +func SecretKeyFromBytes(skBytes []byte) (*SecretKey, error) { + sk := new(SecretKey).Deserialize(skBytes) + if sk == nil { + return nil, errFailedSecretKeyDeserialize + } + runtime.SetFinalizer(sk, func(sk *SecretKey) { + sk.Zeroize() + }) + return sk, nil +} + +// PublicFromSecretKey returns the public key that corresponds to this secret +// key. +func PublicFromSecretKey(sk *SecretKey) *PublicKey { + return new(PublicKey).From(sk) +} + +// Sign [msg] to authorize this message from this [sk]. 
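+ // The signature is produced in G2 using the proof-of-possession ciphersuite
+ // named by ciphersuiteSignature.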
+func Sign(sk *SecretKey, msg []byte) *Signature { + return new(Signature).Sign(sk, msg, ciphersuiteSignature) +} diff --git a/utils/crypto/bls/secret_test.go b/utils/crypto/bls/secret_test.go new file mode 100644 index 00000000..6f56a73b --- /dev/null +++ b/utils/crypto/bls/secret_test.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bls + +import ( + "testing" + + "github.com/cometbft/cometbft/crypto" + + "github.com/stretchr/testify/require" +) + +func TestSecretKeyFromBytesZero(t *testing.T) { + require := require.New(t) + + var skArr [SecretKeyLen]byte + skBytes := skArr[:] + _, err := SecretKeyFromBytes(skBytes) + require.ErrorIs(err, errFailedSecretKeyDeserialize) +} + +func TestSecretKeyFromBytesWrongSize(t *testing.T) { + require := require.New(t) + + skBytes := crypto.CRandBytes(SecretKeyLen + 1) + _, err := SecretKeyFromBytes(skBytes) + require.ErrorIs(err, errFailedSecretKeyDeserialize) +} + +func TestSecretKeyBytes(t *testing.T) { + require := require.New(t) + + msg := crypto.CRandBytes(1234) + + sk, err := NewSecretKey() + require.NoError(err) + sig := Sign(sk, msg) + skBytes := SecretKeyToBytes(sk) + + sk2, err := SecretKeyFromBytes(skBytes) + require.NoError(err) + sig2 := Sign(sk2, msg) + sk2Bytes := SecretKeyToBytes(sk2) + + require.Equal(sk, sk2) + require.Equal(skBytes, sk2Bytes) + require.Equal(sig, sig2) +} diff --git a/utils/crypto/bls/signature.go b/utils/crypto/bls/signature.go new file mode 100644 index 00000000..0d0d029b --- /dev/null +++ b/utils/crypto/bls/signature.go @@ -0,0 +1,57 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bls + +import ( + "errors" + + blst "github.com/supranational/blst/bindings/go" +) + +const SignatureLen = blst.BLST_P2_COMPRESS_BYTES + +var ( + ErrFailedSignatureDecompress = errors.New("couldn't decompress signature") + errInvalidSignature = errors.New("invalid signature") + errNoSignatures = errors.New("no signatures") + errFailedSignatureAggregation = errors.New("couldn't aggregate signatures") +) + +type ( + Signature = blst.P2Affine + AggregateSignature = blst.P2Aggregate +) + +// SignatureToBytes returns the compressed big-endian format of the signature. +func SignatureToBytes(sig *Signature) []byte { + return sig.Compress() +} + +// SignatureFromBytes parses the compressed big-endian format of the signature +// into a signature. +func SignatureFromBytes(sigBytes []byte) (*Signature, error) { + sig := new(Signature).Uncompress(sigBytes) + if sig == nil { + return nil, ErrFailedSignatureDecompress + } + if !sig.SigValidate(false) { + return nil, errInvalidSignature + } + return sig, nil +} + +// AggregateSignatures aggregates a non-zero number of signatures into a single +// aggregated signature. +// Invariant: all [sigs] have been validated. +func AggregateSignatures(sigs []*Signature) (*Signature, error) { + if len(sigs) == 0 { + return nil, errNoSignatures + } + + var agg AggregateSignature + if !agg.Aggregate(sigs, false) { + return nil, errFailedSignatureAggregation + } + return agg.ToAffine(), nil +} diff --git a/utils/crypto/bls/signature_test.go b/utils/crypto/bls/signature_test.go new file mode 100644 index 00000000..6f2f21e2 --- /dev/null +++ b/utils/crypto/bls/signature_test.go @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package bls + +import ( + "testing" + + "github.com/cometbft/cometbft/crypto" + + "github.com/stretchr/testify/require" +) + +func TestSignatureBytes(t *testing.T) { + require := require.New(t) + + msg := crypto.CRandBytes(1234) + + sk, err := NewSecretKey() + require.NoError(err) + sig := Sign(sk, msg) + sigBytes := SignatureToBytes(sig) + + sig2, err := SignatureFromBytes(sigBytes) + require.NoError(err) + sig2Bytes := SignatureToBytes(sig2) + + require.Equal(sig, sig2) + require.Equal(sigBytes, sig2Bytes) +} + +func TestAggregateSignaturesNoop(t *testing.T) { + require := require.New(t) + + msg := crypto.CRandBytes(1234) + + sk, err := NewSecretKey() + require.NoError(err) + + sig := Sign(sk, msg) + sigBytes := SignatureToBytes(sig) + + aggSig, err := AggregateSignatures([]*Signature{sig}) + require.NoError(err) + + aggSigBytes := SignatureToBytes(aggSig) + require.NoError(err) + + require.Equal(sig, aggSig) + require.Equal(sigBytes, aggSigBytes) +} diff --git a/utils/set/set.go b/utils/set/set.go new file mode 100644 index 00000000..350d7946 --- /dev/null +++ b/utils/set/set.go @@ -0,0 +1,195 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package set + +import ( + "bytes" + "encoding/json" + "slices" + + "golang.org/x/exp/maps" + + "github.com/landslidenetwork/slide-sdk/utils" + "github.com/landslidenetwork/slide-sdk/utils/wrappers" +) + +// The minimum capacity of a set +const minSetSize = 16 + +// Null is the string representation of a null value +const Null = "null" + +var _ json.Marshaler = (*Set[int])(nil) + +// Set is a set of elements. +type Set[T comparable] map[T]struct{} + +// Of returns a Set initialized with [elts] +func Of[T comparable](elts ...T) Set[T] { + s := NewSet[T](len(elts)) + s.Add(elts...) + return s +} + +// Return a new set with initial capacity [size]. +// More or less than [size] elements can be added to this set. +// Using NewSet() rather than Set[T]{} is just an optimization that can +// be used if you know how many elements will be put in this set. +func NewSet[T comparable](size int) Set[T] { + if size < 0 { + return Set[T]{} + } + return make(map[T]struct{}, size) +} + +func (s *Set[T]) resize(size int) { + if *s == nil { + if minSetSize > size { + size = minSetSize + } + *s = make(map[T]struct{}, size) + } +} + +// Add all the elements to this set. +// If the element is already in the set, nothing happens. +func (s *Set[T]) Add(elts ...T) { + s.resize(2 * len(elts)) + for _, elt := range elts { + (*s)[elt] = struct{}{} + } +} + +// Union adds all the elements from the provided set to this set. +func (s *Set[T]) Union(set Set[T]) { + s.resize(2 * set.Len()) + for elt := range set { + (*s)[elt] = struct{}{} + } +} + +// Difference removes all the elements in [set] from [s]. +func (s *Set[T]) Difference(set Set[T]) { + for elt := range set { + delete(*s, elt) + } +} + +// Contains returns true iff the set contains this element. +func (s *Set[T]) Contains(elt T) bool { + _, contains := (*s)[elt] + return contains +} + +// Overlaps returns true if the intersection of the set is non-empty +func (s *Set[T]) Overlaps(big Set[T]) bool { + small := *s + if small.Len() > big.Len() { + small, big = big, small + } + + for elt := range small { + if _, ok := big[elt]; ok { + return true + } + } + return false +} + +// Len returns the number of elements in this set. +func (s Set[_]) Len() int { + return len(s) +} + +// Remove all the given elements from this set. 
+// If an element isn't in the set, it's ignored. +func (s *Set[T]) Remove(elts ...T) { + for _, elt := range elts { + delete(*s, elt) + } +} + +// Clear empties this set +func (s *Set[_]) Clear() { + clear(*s) +} + +// List converts this set into a list +func (s Set[T]) List() []T { + return maps.Keys(s) +} + +// Equals returns true if the sets contain the same elements +func (s Set[T]) Equals(other Set[T]) bool { + return maps.Equal(s, other) +} + +// Removes and returns an element. +// If the set is empty, does nothing and returns false. +func (s *Set[T]) Pop() (T, bool) { + for elt := range *s { + delete(*s, elt) + return elt, true + } + return utils.Zero[T](), false +} + +func (s *Set[T]) UnmarshalJSON(b []byte) error { + str := string(b) + if str == Null { + return nil + } + var elts []T + if err := json.Unmarshal(b, &elts); err != nil { + return err + } + s.Clear() + s.Add(elts...) + return nil +} + +func (s Set[_]) MarshalJSON() ([]byte, error) { + var ( + eltBytes = make([][]byte, len(s)) + i int + err error + ) + for elt := range s { + eltBytes[i], err = json.Marshal(elt) + if err != nil { + return nil, err + } + i++ + } + // Sort for determinism + slices.SortFunc(eltBytes, bytes.Compare) + + // Build the JSON + var ( + jsonBuf = bytes.Buffer{} + errs = wrappers.Errs{} + ) + _, err = jsonBuf.WriteString("[") + errs.Add(err) + for i, elt := range eltBytes { + _, err := jsonBuf.Write(elt) + errs.Add(err) + if i != len(eltBytes)-1 { + _, err := jsonBuf.WriteString(",") + errs.Add(err) + } + } + _, err = jsonBuf.WriteString("]") + errs.Add(err) + + return jsonBuf.Bytes(), errs.Err +} + +// Returns a random element. If the set is empty, returns false +func (s *Set[T]) Peek() (T, bool) { + for elt := range *s { + return elt, true + } + return utils.Zero[T](), false +} diff --git a/utils/set/set_benchmark_test.go b/utils/set/set_benchmark_test.go new file mode 100644 index 00000000..300b8c8c --- /dev/null +++ b/utils/set/set_benchmark_test.go @@ -0,0 +1,39 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package set + +import ( + "strconv" + "testing" +) + +func BenchmarkSetList(b *testing.B) { + sizes := []int{5, 25, 100, 100_000} // Test with various sizes + for size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + set := Set[int]{} + for i := 0; i < size; i++ { + set.Add(i) + } + b.ResetTimer() + for n := 0; n < b.N; n++ { + set.List() + } + }) + } +} + +func BenchmarkSetClear(b *testing.B) { + for _, numElts := range []int{10, 25, 50, 100, 250, 500, 1000} { + b.Run(strconv.Itoa(numElts), func(b *testing.B) { + set := NewSet[int](numElts) + for n := 0; n < b.N; n++ { + for i := 0; i < numElts; i++ { + set.Add(i) + } + set.Clear() + } + }) + } +} diff --git a/utils/set/set_test.go b/utils/set/set_test.go new file mode 100644 index 00000000..3b0a7e18 --- /dev/null +++ b/utils/set/set_test.go @@ -0,0 +1,228 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package set + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSet(t *testing.T) { + require := require.New(t) + id1 := 1 + + s := Set[int]{id1: struct{}{}} + + s.Add(id1) + require.True(s.Contains(id1)) + + s.Remove(id1) + require.False(s.Contains(id1)) + + s.Add(id1) + require.True(s.Contains(id1)) + require.Len(s.List(), 1) + require.Equal(id1, s.List()[0]) + + s.Clear() + require.False(s.Contains(id1)) + + s.Add(id1) + + s2 := Set[int]{} + + require.False(s.Overlaps(s2)) + + s2.Union(s) + require.True(s2.Contains(id1)) + require.True(s.Overlaps(s2)) + + s2.Difference(s) + require.False(s2.Contains(id1)) + require.False(s.Overlaps(s2)) +} + +func TestOf(t *testing.T) { + tests := []struct { + name string + elements []int + expected []int + }{ + { + name: "nil", + elements: nil, + expected: []int{}, + }, + { + name: "empty", + elements: []int{}, + expected: []int{}, + }, + { + name: "unique elements", + elements: []int{1, 2, 3}, + expected: []int{1, 2, 3}, + }, + { + name: "duplicate elements", + elements: []int{1, 2, 3, 1, 2, 3}, + expected: []int{1, 2, 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + s := Of(tt.elements...) + + require.Len(s, len(tt.expected)) + for _, expected := range tt.expected { + require.True(s.Contains(expected)) + } + }) + } +} + +func TestSetClear(t *testing.T) { + require := require.New(t) + + set := Set[int]{} + for i := 0; i < 25; i++ { + set.Add(i) + } + set.Clear() + require.Empty(set) + set.Add(1337) + require.Len(set, 1) +} + +func TestSetPop(t *testing.T) { + require := require.New(t) + + var s Set[int] + _, ok := s.Pop() + require.False(ok) + + s = make(Set[int]) + _, ok = s.Pop() + require.False(ok) + + id1, id2 := 0, 1 + s.Add(id1, id2) + + got, ok := s.Pop() + require.True(ok) + require.True(got == id1 || got == id2) + require.Equal(1, s.Len()) + + got, ok = s.Pop() + require.True(ok) + require.True(got == id1 || got == id2) + require.Zero(s.Len()) + + _, ok = s.Pop() + require.False(ok) +} + +func TestSetMarshalJSON(t *testing.T) { + require := require.New(t) + set := Set[int]{} + { + asJSON, err := set.MarshalJSON() + require.NoError(err) + require.Equal("[]", string(asJSON)) + } + id1, id2 := 1, 2 + id1JSON, err := json.Marshal(id1) + require.NoError(err) + id2JSON, err := json.Marshal(id2) + require.NoError(err) + set.Add(id1) + { + asJSON, err := set.MarshalJSON() + require.NoError(err) + require.Equal(fmt.Sprintf("[%s]", string(id1JSON)), string(asJSON)) + } + set.Add(id2) + { + asJSON, err := set.MarshalJSON() + require.NoError(err) + require.Equal(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON)), string(asJSON)) + } +} + +func TestSetUnmarshalJSON(t *testing.T) { + require := require.New(t) + set := Set[int]{} + { + require.NoError(set.UnmarshalJSON([]byte("[]"))) + require.Empty(set) + } + id1, id2 := 1, 2 + id1JSON, err := json.Marshal(id1) + require.NoError(err) + id2JSON, err := json.Marshal(id2) + require.NoError(err) + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%s]", string(id1JSON))))) + require.Len(set, 1) + require.Contains(set, id1) + } + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON))))) + require.Len(set, 2) + require.Contains(set, id1) + require.Contains(set, id2) + } + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%d,%d,%d]", 3, 4, 5)))) + require.Len(set, 3) + require.Contains(set, 3) + require.Contains(set, 4) + 
require.Contains(set, 5) + } + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%d,%d,%d, %d]", 3, 4, 5, 3)))) + require.Len(set, 3) + require.Contains(set, 3) + require.Contains(set, 4) + require.Contains(set, 5) + } + { + set1 := Set[int]{} + set2 := Set[int]{} + require.NoError(set1.UnmarshalJSON([]byte(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON))))) + require.NoError(set2.UnmarshalJSON([]byte(fmt.Sprintf("[%s,%s]", string(id2JSON), string(id1JSON))))) + require.Equal(set1, set2) + } +} + +func TestSetReflectJSONMarshal(t *testing.T) { + require := require.New(t) + set := Set[int]{} + { + asJSON, err := json.Marshal(set) + require.NoError(err) + require.Equal("[]", string(asJSON)) + } + id1JSON, err := json.Marshal(1) + require.NoError(err) + id2JSON, err := json.Marshal(2) + require.NoError(err) + set.Add(1) + { + asJSON, err := json.Marshal(set) + require.NoError(err) + require.Equal(fmt.Sprintf("[%s]", string(id1JSON)), string(asJSON)) + } + set.Add(2) + { + asJSON, err := json.Marshal(set) + require.NoError(err) + require.Equal(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON)), string(asJSON)) + } +} diff --git a/utils/warp/codec_manager.go b/utils/warp/codec_manager.go new file mode 100644 index 00000000..e1e455c2 --- /dev/null +++ b/utils/warp/codec_manager.go @@ -0,0 +1,19 @@ +package warp + +import ( + "math" + + "github.com/landslidenetwork/slide-sdk/utils/codec" + "github.com/landslidenetwork/slide-sdk/utils/codec/linearcodec" +) + +var Codec codec.Manager + +func init() { + lc := linearcodec.NewDefault() + err := lc.RegisterType(&BitSetSignature{}) + if err != nil { + panic(err) + } + Codec = codec.NewManager(math.MaxInt, lc) +} diff --git a/utils/warp/predefined.go b/utils/warp/predefined.go new file mode 100644 index 00000000..679798ea --- /dev/null +++ b/utils/warp/predefined.go @@ -0,0 +1,3 @@ +package warp + +const UnitTestID uint32 = 10 diff --git a/utils/warp/signature.go b/utils/warp/signature.go new file mode 100644 index 00000000..9d35d6cd --- /dev/null +++ b/utils/warp/signature.go @@ -0,0 +1,15 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + blst "github.com/supranational/blst/bindings/go" +) + +type BitSetSignature struct { + // Signers is a big-endian byte slice encoding which validators signed this + // message. + Signers []byte `serialize:"true"` + Signature [blst.BLST_P2_COMPRESS_BYTES]byte `serialize:"true"` +} diff --git a/utils/warp/signer.go b/utils/warp/signer.go new file mode 100644 index 00000000..6cacbe70 --- /dev/null +++ b/utils/warp/signer.go @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + "errors" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + blst "github.com/supranational/blst/bindings/go" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +var ( + _ Signer = (*signer)(nil) + + ErrWrongSourceChainID = errors.New("wrong SourceChainID") + ErrWrongNetworkID = errors.New("wrong networkID") +) + +type Signer interface { + // Assumes the unsigned message is correctly initialized. 
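+	// Sign returns the compressed BLS signature over msg.Bytes(). The signer
+	// implementation in this package returns ErrWrongSourceChainID or
+	// ErrWrongNetworkID if the message was built for a different chain or network.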
+ Sign(msg *UnsignedMessage) ([]byte, error) +} + +func NewSigner(sk *blst.SecretKey, networkID uint32, chainID ids.ID) Signer { + return &signer{ + sk: sk, + networkID: networkID, + chainID: chainID, + } +} + +type signer struct { + sk *blst.SecretKey + networkID uint32 + chainID ids.ID +} + +func (s *signer) Sign(msg *UnsignedMessage) ([]byte, error) { + if msg.SourceChainID != s.chainID { + return nil, ErrWrongSourceChainID + } + if msg.NetworkID != s.networkID { + return nil, ErrWrongNetworkID + } + + msgBytes := msg.Bytes() + signature := bls.Sign(s.sk, msgBytes) + return signature.Compress(), nil +} diff --git a/utils/warp/signer_test.go b/utils/warp/signer_test.go new file mode 100644 index 00000000..61c08b44 --- /dev/null +++ b/utils/warp/signer_test.go @@ -0,0 +1,90 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + "testing" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + + "github.com/stretchr/testify/require" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +func TestSigner(t *testing.T) { + for name, test := range SignerTests { + t.Run(name, func(t *testing.T) { + sk, err := bls.NewSecretKey() + require.NoError(t, err) + + chainID := ids.GenerateTestID() + s := NewSigner(sk, UnitTestID, chainID) + + test(t, s, sk, UnitTestID, chainID) + }) + } +} + +// SignerTests is a list of all signer tests +var SignerTests = map[string]func(t *testing.T, s Signer, sk *bls.SecretKey, networkID uint32, chainID ids.ID){ + "WrongChainID": testWrongChainID, + "WrongNetworkID": testWrongNetworkID, + "Verifies": testVerifies, +} + +// Test that using a random SourceChainID results in an error +func testWrongChainID(t *testing.T, s Signer, _ *bls.SecretKey, _ uint32, _ ids.ID) { + require := require.New(t) + + msg, err := NewUnsignedMessage( + UnitTestID, + ids.GenerateTestID(), + []byte("payload"), + ) + require.NoError(err) + + _, err = s.Sign(msg) + // TODO: require error to be ErrWrongSourceChainID + require.Error(err) //nolint:forbidigo // currently returns grpc errors too +} + +// Test that using a different networkID results in an error +func testWrongNetworkID(t *testing.T, s Signer, _ *bls.SecretKey, networkID uint32, blockchainID ids.ID) { + require := require.New(t) + + msg, err := NewUnsignedMessage( + networkID+1, + blockchainID, + []byte("payload"), + ) + require.NoError(err) + + _, err = s.Sign(msg) + // TODO: require error to be ErrWrongNetworkID + require.Error(err) //nolint:forbidigo // currently returns grpc errors too +} + +// Test that a signature generated with the signer verifies correctly +func testVerifies(t *testing.T, s Signer, sk *bls.SecretKey, networkID uint32, chainID ids.ID) { + require := require.New(t) + + msg, err := NewUnsignedMessage( + networkID, + chainID, + []byte("payload"), + ) + require.NoError(err) + + sigBytes, err := s.Sign(msg) + require.NoError(err) + + t.Log(sigBytes) + sig, err := bls.SignatureFromBytes(sigBytes) + require.NoError(err) + + pk := bls.PublicFromSecretKey(sk) + msgBytes := msg.Bytes() + require.True(bls.Verify(pk, sig, msgBytes)) +} diff --git a/utils/warp/unsigned_message.go b/utils/warp/unsigned_message.go new file mode 100644 index 00000000..f4ff3bf2 --- /dev/null +++ b/utils/warp/unsigned_message.go @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package warp + +import ( + "fmt" + + "github.com/landslidenetwork/slide-sdk/utils/hashing" + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +// UnsignedMessage defines the standard format for an unsigned Warp message. +type UnsignedMessage struct { + NetworkID uint32 `serialize:"true"` + SourceChainID ids.ID `serialize:"true"` + Payload []byte `serialize:"true"` + + bytes []byte + id ids.ID +} + +// NewUnsignedMessage creates a new *UnsignedMessage and initializes it. +func NewUnsignedMessage( + networkID uint32, + sourceChainID ids.ID, + payload []byte, +) (*UnsignedMessage, error) { + msg := &UnsignedMessage{ + NetworkID: networkID, + SourceChainID: sourceChainID, + Payload: payload, + } + err := msg.Initialize() + return msg, err +} + +// ParseUnsignedMessage converts a slice of bytes into an initialized +// *UnsignedMessage. +func ParseUnsignedMessage(b []byte) (*UnsignedMessage, error) { + msg := &UnsignedMessage{ + bytes: b, + id: hashing.ComputeHash256Array(b), + } + err := Codec.Unmarshal(b, msg) + return msg, err +} + +// Initialize recalculates the result of Bytes(). +func (m *UnsignedMessage) Initialize() error { + bytes, err := Codec.Marshal(m) + if err != nil { + return fmt.Errorf("couldn't marshal warp unsigned message: %w", err) + } + m.bytes = bytes + m.id = hashing.ComputeHash256Array(m.bytes) + return nil +} + +// Bytes returns the binary representation of this message. It assumes that the +// message is initialized from either New, Parse, or an explicit call to +// Initialize. +func (m *UnsignedMessage) Bytes() []byte { + return m.bytes +} + +// ID returns an identifier for this message. It assumes that the +// message is initialized from either New, Parse, or an explicit call to +// Initialize. +func (m *UnsignedMessage) ID() ids.ID { + return m.id +} diff --git a/utils/warp/unsigned_message_test.go b/utils/warp/unsigned_message_test.go new file mode 100644 index 00000000..900717fc --- /dev/null +++ b/utils/warp/unsigned_message_test.go @@ -0,0 +1,36 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +func TestUnsignedMessage(t *testing.T) { + require := require.New(t) + + msg, err := NewUnsignedMessage( + UnitTestID, + ids.GenerateTestID(), + []byte("payload"), + ) + require.NoError(err) + + msgBytes := msg.Bytes() + msg2, err := ParseUnsignedMessage(msgBytes) + require.NoError(err) + require.Equal(msg, msg2) +} + +func TestParseUnsignedMessageJunk(t *testing.T) { + require := require.New(t) + + bytes := []byte{0, 1, 2, 3, 4, 5, 6, 7} + _, err := ParseUnsignedMessage(bytes) + require.Error(err) +} diff --git a/warp/backend.go b/warp/backend.go new file mode 100644 index 00000000..e9b7d072 --- /dev/null +++ b/warp/backend.go @@ -0,0 +1,54 @@ +package warp + +import ( + "fmt" + + dbm "github.com/cometbft/cometbft-db" + "github.com/cometbft/cometbft/libs/log" + "github.com/landslidenetwork/slide-sdk/utils/ids" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" +) + +// Backend tracks signature-eligible warp messages and provides an interface to fetch them. +// The backend is also used to query for warp message signatures by the signature request handler. 
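+// Construct it with NewBackend; AddMessage persists each unsigned message in the
+// backend database keyed by its ID and signs it with the configured warputils.Signer.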
+type Backend interface { + // AddMessage signs [unsignedMessage] and adds it to the warp backend database + AddMessage(unsignedMessage *warputils.UnsignedMessage) error +} + +// backend implements Backend, keeps track of warp messages, and generates message signatures. +type backend struct { + logger log.Logger + networkID uint32 + sourceChainID ids.ID + db dbm.DB + warpSigner warputils.Signer +} + +// NewBackend creates a new Backend, and initializes the signature cache and message tracking database. +func NewBackend(networkID uint32, sourceChainID ids.ID, warpSigner warputils.Signer, db dbm.DB) Backend { + return &backend{ + networkID: networkID, + sourceChainID: sourceChainID, + db: db, + warpSigner: warpSigner, + } +} + +func (b *backend) AddMessage(unsignedMessage *warputils.UnsignedMessage) error { + messageID := unsignedMessage.ID() + + // In the case when a node restarts, and possibly changes its bls key, the cache gets emptied but the database does not. + // So to avoid having incorrect signatures saved in the database after a bls key change, we save the full message in the database. + // Whereas for the cache, after the node restart, the cache would be emptied so we can directly save the signatures. + if err := b.db.Set(messageID[:], unsignedMessage.Bytes()); err != nil { + return fmt.Errorf("failed to put warp signature in db: %w", err) + } + + _, err := b.warpSigner.Sign(unsignedMessage) + if err != nil { + return fmt.Errorf("failed to sign warp message: %w", err) + } + b.logger.Debug("Adding warp message to backend", "messageID", messageID) + return nil +} From 83fb1bfe9b11c11a2990308e3bccb8a2b13d0d0a Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Wed, 20 Nov 2024 07:44:35 +0100 Subject: [PATCH 02/10] warp backend: get message (#67) Co-authored-by: Ivan Sukach --- warp/backend.go | 23 ++++++++++++-- warp/backend_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 warp/backend_test.go diff --git a/warp/backend.go b/warp/backend.go index e9b7d072..41a95fb9 100644 --- a/warp/backend.go +++ b/warp/backend.go @@ -14,6 +14,10 @@ import ( type Backend interface { // AddMessage signs [unsignedMessage] and adds it to the warp backend database AddMessage(unsignedMessage *warputils.UnsignedMessage) error + // GetMessage retrieves the [unsignedMessage] from the warp backend database if available + // TODO: After E-Upgrade, the backend no longer needs to store the mapping from messageHash + // to unsignedMessage (and this method can be removed). + GetMessage(messageHash ids.ID) (*warputils.UnsignedMessage, error) } // backend implements Backend, keeps track of warp messages, and generates message signatures. @@ -26,12 +30,13 @@ type backend struct { } // NewBackend creates a new Backend, and initializes the signature cache and message tracking database. 
-func NewBackend(networkID uint32, sourceChainID ids.ID, warpSigner warputils.Signer, db dbm.DB) Backend { +func NewBackend(networkID uint32, sourceChainID ids.ID, warpSigner warputils.Signer, logger log.Logger, db dbm.DB) Backend { return &backend{ networkID: networkID, sourceChainID: sourceChainID, - db: db, warpSigner: warpSigner, + logger: logger, + db: db, } } @@ -52,3 +57,17 @@ func (b *backend) AddMessage(unsignedMessage *warputils.UnsignedMessage) error { b.logger.Debug("Adding warp message to backend", "messageID", messageID) return nil } + +func (b *backend) GetMessage(messageID ids.ID) (*warputils.UnsignedMessage, error) { + unsignedMessageBytes, err := b.db.Get(messageID[:]) + if err != nil { + return nil, fmt.Errorf("failed to get warp message %s from db: %w", messageID.String(), err) + } + + unsignedMessage, err := warputils.ParseUnsignedMessage(unsignedMessageBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse unsigned message %s: %w", messageID.String(), err) + } + + return unsignedMessage, nil +} diff --git a/warp/backend_test.go b/warp/backend_test.go new file mode 100644 index 00000000..3c29fb58 --- /dev/null +++ b/warp/backend_test.go @@ -0,0 +1,74 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + "os" + "testing" + + dbm "github.com/cometbft/cometbft-db" + "github.com/cometbft/cometbft/libs/log" + "github.com/cometbft/cometbft/libs/rand" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + "github.com/stretchr/testify/require" +) + +var ( + err error + networkID uint32 = 54321 + sourceChainID = ids.GenerateTestID() + testUnsignedMessage *warputils.UnsignedMessage +) + +func init() { + testUnsignedMessage, err = warputils.NewUnsignedMessage(networkID, sourceChainID, []byte(rand.Str(30))) + if err != nil { + panic(err) + } +} + +func TestAddAndGetValidMessage(t *testing.T) { + logger := log.NewTMLogger(os.Stdout) + db := dbm.NewMemDB() + + sk, err := bls.NewSecretKey() + require.NoError(t, err) + warpSigner := warputils.NewSigner(sk, networkID, sourceChainID) + backend := NewBackend(networkID, sourceChainID, warpSigner, logger, db) + require.NoError(t, err) + + // Add testUnsignedMessage to the warp backend + err = backend.AddMessage(testUnsignedMessage) + require.NoError(t, err) + + // Verify that a signature is returned successfully, and compare to expected signature. + msg, err := backend.GetMessage(testUnsignedMessage.ID()) + require.NoError(t, err) + require.NotNil(t, msg) + + expectedSig, err := warpSigner.Sign(testUnsignedMessage) + require.NoError(t, err) + require.NotNil(t, expectedSig) + // TODO: get signature and compare +} + +func TestAddAndGetUnknownMessage(t *testing.T) { + logger := log.NewTMLogger(os.Stdout) + db := dbm.NewMemDB() + + sk, err := bls.NewSecretKey() + require.NoError(t, err) + warpSigner := warputils.NewSigner(sk, networkID, sourceChainID) + backend := NewBackend(networkID, sourceChainID, warpSigner, logger, db) + require.NoError(t, err) + + // Try getting a signature for a message that was not added. 
+ msg, err := backend.GetMessage(testUnsignedMessage.ID()) + require.Error(t, err) + require.Nil(t, msg) + //TODO: add business logic to check signature +} From 4fb95b0f35037f8edcfd8fb67553380a90912d9f Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Wed, 27 Nov 2024 22:04:21 +0100 Subject: [PATCH 03/10] Warp get message signature (#68) * preliminary * prelimiry: get message signature * add codec version to packer of codec * process error of codec type registration * add exception for static check * ignore U1000 linter unused function * get rid of unnecesary slicing, set space in comment * linter: sort imports --------- Co-authored-by: Ivan Sukach --- utils/codec/codec.go | 2 + utils/codec/manager.go | 10 ++++ utils/warp/messages/codec_manager.go | 15 ++++++ utils/warp/messages/payload.go | 43 ++++++++++++++++ utils/warp/payload/README.md | 47 +++++++++++++++++ utils/warp/payload/addressed_call.go | 53 +++++++++++++++++++ utils/warp/payload/addressed_call_test.go | 45 ++++++++++++++++ utils/warp/payload/codec_manager.go | 23 +++++++++ utils/warp/payload/hash.go | 49 ++++++++++++++++++ utils/warp/payload/hash_test.go | 38 ++++++++++++++ utils/warp/payload/payload.go | 39 ++++++++++++++ utils/warp/payload/payload_test.go | 63 +++++++++++++++++++++++ warp/backend.go | 56 ++++++++++++++++++++ warp/backend_test.go | 11 +++- 14 files changed, 492 insertions(+), 2 deletions(-) create mode 100644 utils/warp/messages/codec_manager.go create mode 100644 utils/warp/messages/payload.go create mode 100644 utils/warp/payload/README.md create mode 100644 utils/warp/payload/addressed_call.go create mode 100644 utils/warp/payload/addressed_call_test.go create mode 100644 utils/warp/payload/codec_manager.go create mode 100644 utils/warp/payload/hash.go create mode 100644 utils/warp/payload/hash_test.go create mode 100644 utils/warp/payload/payload.go create mode 100644 utils/warp/payload/payload_test.go diff --git a/utils/codec/codec.go b/utils/codec/codec.go index 1ffc0213..0eaca4d5 100644 --- a/utils/codec/codec.go +++ b/utils/codec/codec.go @@ -16,6 +16,8 @@ var ( ErrUnexportedField = errors.New("unexported field") ErrMarshalZeroLength = errors.New("can't marshal zero length value") ErrUnmarshalZeroLength = errors.New("can't unmarshal zero length value") + ErrCantPackVersion = errors.New("couldn't pack codec version") + ErrCantUnpackVersion = errors.New("couldn't unpack codec version") ) // Codec marshals and unmarshals diff --git a/utils/codec/manager.go b/utils/codec/manager.go index 25273700..82122457 100644 --- a/utils/codec/manager.go +++ b/utils/codec/manager.go @@ -15,6 +15,8 @@ const ( // Larger value --> need less memory allocations but possibly have allocated but unused memory // Smaller value --> need more memory allocations but more efficient use of allocated memory initialSliceCap = 128 + // default version of codec + defaultVersion = 0 ) var ( @@ -78,6 +80,10 @@ func (m *manager) Marshal(value interface{}) ([]byte, error) { MaxSize: m.maxSize, Bytes: make([]byte, 0, initialSliceCap), } + p.PackShort(defaultVersion) + if p.Errored() { + return nil, ErrCantPackVersion // Should never happen + } return p.Bytes, m.codec.MarshalInto(value, &p) } @@ -95,6 +101,10 @@ func (m *manager) Unmarshal(bytes []byte, dest interface{}) error { p := wrappers.Packer{ Bytes: bytes, } + p.UnpackShort() + if p.Errored() { // Make sure the codec version is correct + return ErrCantUnpackVersion + } if err := m.codec.UnmarshalFrom(&p, dest); err != nil { return err } 
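The two hunks above make every payload produced by codec.Manager.Marshal start with a 2-byte codec version prefix (defaultVersion = 0), and Unmarshal now rejects inputs too short to contain that prefix (it does not yet compare the version value itself). Below is a minimal framing sketch of that layout; it uses encoding/binary directly instead of the package's wrappers.Packer, writes the two version bytes big-endian to match the 0x0000 codecID shown in the payload README, and the frame/unframe helper names are illustrative only.

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// codecVersion mirrors defaultVersion in utils/codec/manager.go.
const codecVersion uint16 = 0

// frame prepends the 2-byte version, as Marshal now does before handing the
// packer to the underlying codec.
func frame(payload []byte) []byte {
	out := make([]byte, 2, 2+len(payload))
	binary.BigEndian.PutUint16(out, codecVersion)
	return append(out, payload...)
}

// unframe strips the version prefix, as Unmarshal now does. As in the patch,
// only the presence of the two version bytes is checked here, not their value.
func unframe(b []byte) ([]byte, error) {
	if len(b) < 2 {
		return nil, errors.New("couldn't unpack codec version")
	}
	return b[2:], nil
}

func main() {
	wire := frame([]byte{0xAA, 0xBB})
	fmt.Printf("% x\n", wire) // 00 00 aa bb

	payload, err := unframe(wire)
	fmt.Printf("% x %v\n", payload, err) // aa bb <nil>
}

This is the same layout documented in utils/warp/payload/README.md, where every serialized payload begins with the hardcoded codecID 0x0000.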
diff --git a/utils/warp/messages/codec_manager.go b/utils/warp/messages/codec_manager.go new file mode 100644 index 00000000..9a790415 --- /dev/null +++ b/utils/warp/messages/codec_manager.go @@ -0,0 +1,15 @@ +package messages + +import ( + "github.com/landslidenetwork/slide-sdk/utils/codec" + "github.com/landslidenetwork/slide-sdk/utils/codec/linearcodec" +) + +const MaxMessageSize = 24 * 1024 + +var Codec codec.Manager + +func init() { + lc := linearcodec.NewDefault() + Codec = codec.NewManager(MaxMessageSize, lc) +} diff --git a/utils/warp/messages/payload.go b/utils/warp/messages/payload.go new file mode 100644 index 00000000..85aba96e --- /dev/null +++ b/utils/warp/messages/payload.go @@ -0,0 +1,43 @@ +// (c) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package messages + +import ( + "fmt" +) + +// Payload provides a common interface for all payloads implemented by this +// package. +type Payload interface { + // Bytes returns the binary representation of this payload. + Bytes() []byte + + // initialize the payload with the provided binary representation. + initialize(b []byte) +} + +// Signable is an optional interface that payloads can implement to allow +// on-the-fly signing of incoming messages by the warp backend. +type Signable interface { + VerifyMesssage(sourceAddress []byte) error +} + +func Parse(bytes []byte) (Payload, error) { + var payload Payload + if err := Codec.Unmarshal(bytes, &payload); err != nil { + return nil, err + } + payload.initialize(bytes) + return payload, nil +} + +//lint:ignore U1000 will be implemented and used for ValidatorUptime struct +func initialize(p Payload) error { + bytes, err := Codec.Marshal(&p) + if err != nil { + return fmt.Errorf("couldn't marshal %T payload: %w", p, err) + } + p.initialize(bytes) + return nil +} diff --git a/utils/warp/payload/README.md b/utils/warp/payload/README.md new file mode 100644 index 00000000..2da32ee5 --- /dev/null +++ b/utils/warp/payload/README.md @@ -0,0 +1,47 @@ +# Payload + +An Avalanche Unsigned Warp Message already includes a `networkID`, `sourceChainID`, and `payload` field. The `payload` field can be parsed into one of the types included in this package to be further handled by the VM. + +## Hash + +Hash: +``` ++-----------------+----------+-----------+ +| codecID : uint16 | 2 bytes | ++-----------------+----------+-----------+ +| typeID : uint32 | 4 bytes | ++-----------------+----------+-----------+ +| hash : [32]byte | 32 bytes | ++-----------------+----------+-----------+ + | 38 bytes | + +-----------+ +``` + +- `codecID` is the codec version used to serialize the payload and is hardcoded to `0x0000` +- `typeID` is the payload type identifier and is `0x00000000` for `Hash` +- `hash` is a hash from the `sourceChainID`. The format of the expected preimage is chain specific. 
Some examples for valid hash values are: + - root of a merkle tree + - accepted block hash on the source chain + - accepted transaction hash on the source chain + +## AddressedCall + +AddressedCall: +``` ++---------------------+--------+----------------------------------+ +| codecID : uint16 | 2 bytes | ++---------------------+--------+----------------------------------+ +| typeID : uint32 | 4 bytes | ++---------------------+--------+----------------------------------+ +| sourceAddress : []byte | 4 + len(address) | ++---------------------+--------+----------------------------------+ +| payload : []byte | 4 + len(payload) | ++---------------------+--------+----------------------------------+ + | 14 + len(payload) + len(address) | + +----------------------------------+ +``` + +- `codecID` is the codec version used to serialize the payload and is hardcoded to `0x0000` +- `typeID` is the payload type identifier and is `0x00000001` for `AddressedCall` +- `sourceAddress` is the address that sent this message from the source chain +- `payload` is an arbitrary byte array payload diff --git a/utils/warp/payload/addressed_call.go b/utils/warp/payload/addressed_call.go new file mode 100644 index 00000000..b3617ce4 --- /dev/null +++ b/utils/warp/payload/addressed_call.go @@ -0,0 +1,53 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package payload + +import "fmt" + +var _ Payload = (*AddressedCall)(nil) + +// AddressedCall defines the format for delivering a call across VMs including a +// source address and a payload. +// +// Note: If a destination address is expected, it should be encoded in the +// payload. +type AddressedCall struct { + SourceAddress []byte `serialize:"true"` + Payload []byte `serialize:"true"` + + bytes []byte +} + +// NewAddressedCall creates a new *AddressedCall and initializes it. +func NewAddressedCall(sourceAddress []byte, payload []byte) (*AddressedCall, error) { + ap := &AddressedCall{ + SourceAddress: sourceAddress, + Payload: payload, + } + return ap, initialize(ap) +} + +// ParseAddressedCall converts a slice of bytes into an initialized +// AddressedCall. +func ParseAddressedCall(b []byte) (*AddressedCall, error) { + payloadIntf, err := Parse(b) + if err != nil { + return nil, err + } + payload, ok := payloadIntf.(*AddressedCall) + if !ok { + return nil, fmt.Errorf("%w: %T", errWrongType, payloadIntf) + } + return payload, nil +} + +// Bytes returns the binary representation of this payload. It assumes that the +// payload is initialized from either NewAddressedCall or Parse. +func (a *AddressedCall) Bytes() []byte { + return a.bytes +} + +func (a *AddressedCall) initialize(bytes []byte) { + a.bytes = bytes +} diff --git a/utils/warp/payload/addressed_call_test.go b/utils/warp/payload/addressed_call_test.go new file mode 100644 index 00000000..fd18ee8d --- /dev/null +++ b/utils/warp/payload/addressed_call_test.go @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package payload + +import ( + "encoding/base64" + "testing" + + "github.com/cometbft/cometbft/libs/rand" + + "github.com/stretchr/testify/require" +) + +func TestAddressedCall(t *testing.T) { + require := require.New(t) + shortID := []byte(rand.Str(2)) + + addressedPayload, err := NewAddressedCall( + shortID, + []byte{1, 2, 3}, + ) + require.NoError(err) + + addressedPayloadBytes := addressedPayload.Bytes() + parsedAddressedPayload, err := ParseAddressedCall(addressedPayloadBytes) + require.NoError(err) + require.Equal(addressedPayload, parsedAddressedPayload) +} + +func TestParseAddressedCallJunk(t *testing.T) { + _, err := ParseAddressedCall(junkBytes) + require.Error(t, err) +} + +func TestAddressedCallBytes(t *testing.T) { + require := require.New(t) + base64Payload := "AAAAAAABAAAAEAECAwAAAAAAAAAAAAAAAAAAAAADCgsM" + addressedPayload, err := NewAddressedCall( + []byte{1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + []byte{10, 11, 12}, + ) + require.NoError(err) + require.Equal(base64Payload, base64.StdEncoding.EncodeToString(addressedPayload.Bytes())) +} diff --git a/utils/warp/payload/codec_manager.go b/utils/warp/payload/codec_manager.go new file mode 100644 index 00000000..9a4ff480 --- /dev/null +++ b/utils/warp/payload/codec_manager.go @@ -0,0 +1,23 @@ +package payload + +import ( + "github.com/landslidenetwork/slide-sdk/utils/codec" + "github.com/landslidenetwork/slide-sdk/utils/codec/linearcodec" +) + +const MaxMessageSize = 24 * 1024 + +var Codec codec.Manager + +func init() { + lc := linearcodec.NewDefault() + err := lc.RegisterType(&Hash{}) + if err != nil { + panic(err) + } + err = lc.RegisterType(&AddressedCall{}) + if err != nil { + panic(err) + } + Codec = codec.NewManager(MaxMessageSize, lc) +} diff --git a/utils/warp/payload/hash.go b/utils/warp/payload/hash.go new file mode 100644 index 00000000..aae7a17e --- /dev/null +++ b/utils/warp/payload/hash.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package payload + +import ( + "fmt" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +var _ Payload = (*Hash)(nil) + +type Hash struct { + Hash ids.ID `serialize:"true"` + + bytes []byte +} + +// NewHash creates a new *Hash and initializes it. +func NewHash(hash ids.ID) (*Hash, error) { + bhp := &Hash{ + Hash: hash, + } + return bhp, initialize(bhp) +} + +// ParseHash converts a slice of bytes into an initialized Hash. +func ParseHash(b []byte) (*Hash, error) { + payloadIntf, err := Parse(b) + if err != nil { + return nil, err + } + payload, ok := payloadIntf.(*Hash) + if !ok { + return nil, fmt.Errorf("%w: %T", errWrongType, payloadIntf) + } + return payload, nil +} + +// Bytes returns the binary representation of this payload. It assumes that the +// payload is initialized from either NewHash or Parse. +func (b *Hash) Bytes() []byte { + return b.bytes +} + +func (b *Hash) initialize(bytes []byte) { + b.bytes = bytes +} diff --git a/utils/warp/payload/hash_test.go b/utils/warp/payload/hash_test.go new file mode 100644 index 00000000..f20addf4 --- /dev/null +++ b/utils/warp/payload/hash_test.go @@ -0,0 +1,38 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package payload + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +func TestHash(t *testing.T) { + require := require.New(t) + + hashPayload, err := NewHash(ids.GenerateTestID()) + require.NoError(err) + + hashPayloadBytes := hashPayload.Bytes() + parsedHashPayload, err := ParseHash(hashPayloadBytes) + require.NoError(err) + require.Equal(hashPayload, parsedHashPayload) +} + +func TestParseHashJunk(t *testing.T) { + _, err := ParseHash(junkBytes) + require.Error(t, err) +} + +func TestHashBytes(t *testing.T) { + require := require.New(t) + base64Payload := "AAAAAAAABAUGAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + hashPayload, err := NewHash(ids.ID{4, 5, 6}) + require.NoError(err) + require.Equal(base64Payload, base64.StdEncoding.EncodeToString(hashPayload.Bytes())) +} diff --git a/utils/warp/payload/payload.go b/utils/warp/payload/payload.go new file mode 100644 index 00000000..554a68b7 --- /dev/null +++ b/utils/warp/payload/payload.go @@ -0,0 +1,39 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package payload + +import ( + "errors" + "fmt" +) + +var errWrongType = errors.New("wrong payload type") + +// Payload provides a common interface for all payloads implemented by this +// package. +type Payload interface { + // Bytes returns the binary representation of this payload. + Bytes() []byte + + // initialize the payload with the provided binary representation. + initialize(b []byte) +} + +func Parse(bytes []byte) (Payload, error) { + var payload Payload + if err := Codec.Unmarshal(bytes, &payload); err != nil { + return nil, err + } + payload.initialize(bytes) + return payload, nil +} + +func initialize(p Payload) error { + bytes, err := Codec.Marshal(&p) + if err != nil { + return fmt.Errorf("couldn't marshal %T payload: %w", p, err) + } + p.initialize(bytes) + return nil +} diff --git a/utils/warp/payload/payload_test.go b/utils/warp/payload/payload_test.go new file mode 100644 index 00000000..b4d120d6 --- /dev/null +++ b/utils/warp/payload/payload_test.go @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package payload + +import ( + "testing" + + "github.com/cometbft/cometbft/libs/rand" + + "github.com/stretchr/testify/require" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +var junkBytes = []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + +func TestParseJunk(t *testing.T) { + require := require.New(t) + _, err := Parse(junkBytes) + require.Error(err) +} + +func TestParseWrongPayloadType(t *testing.T) { + require := require.New(t) + hash, err := ids.ToID([]byte(rand.Str(32))) + require.NoError(err) + hashPayload, err := NewHash(hash) + require.NoError(err) + + shortID := []byte(rand.Str(20)) + addressedPayload, err := NewAddressedCall( + shortID, + []byte{1, 2, 3}, + ) + require.NoError(err) + + _, err = ParseAddressedCall(hashPayload.Bytes()) + require.ErrorIs(err, errWrongType) + + _, err = ParseHash(addressedPayload.Bytes()) + require.ErrorIs(err, errWrongType) +} + +func TestParse(t *testing.T) { + require := require.New(t) + hashPayload, err := NewHash(ids.ID{4, 5, 6}) + require.NoError(err) + + parsedHashPayload, err := Parse(hashPayload.Bytes()) + require.NoError(err) + require.Equal(hashPayload, parsedHashPayload) + + addressedPayload, err := NewAddressedCall( + []byte{1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + []byte{10, 11, 12}, + ) + require.NoError(err) + + parsedAddressedPayload, err := Parse(addressedPayload.Bytes()) + require.NoError(err) + require.Equal(addressedPayload, parsedAddressedPayload) +} diff --git a/warp/backend.go b/warp/backend.go index 41a95fb9..bcef5bda 100644 --- a/warp/backend.go +++ b/warp/backend.go @@ -3,10 +3,14 @@ package warp import ( "fmt" + "github.com/landslidenetwork/slide-sdk/utils/warp/messages" + dbm "github.com/cometbft/cometbft-db" "github.com/cometbft/cometbft/libs/log" + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" "github.com/landslidenetwork/slide-sdk/utils/ids" warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + "github.com/landslidenetwork/slide-sdk/utils/warp/payload" ) // Backend tracks signature-eligible warp messages and provides an interface to fetch them. @@ -14,6 +18,8 @@ import ( type Backend interface { // AddMessage signs [unsignedMessage] and adds it to the warp backend database AddMessage(unsignedMessage *warputils.UnsignedMessage) error + // GetMessageSignature returns the signature of the requested message. + GetMessageSignature(message *warputils.UnsignedMessage) ([bls.SignatureLen]byte, error) // GetMessage retrieves the [unsignedMessage] from the warp backend database if available // TODO: After E-Upgrade, the backend no longer needs to store the mapping from messageHash // to unsignedMessage (and this method can be removed). 
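GetMessageSignature, added to the interface above, signs unknown messages on demand only when their payload parses as an AddressedCall whose inner payload implements messages.Signable (see ValidateMessage later in this diff). No Signable type is registered with messages.Codec yet; the //lint:ignore note in utils/warp/messages/payload.go mentions a future ValidatorUptime struct. The sketch below shows what such a type could look like: the struct name is taken from that note, but its fields, verification logic, and registration site are assumptions, not part of this patch.

// A hypothetical Signable payload sketch for utils/warp/messages; fields and
// verification logic are illustrative only.
package messages

import "errors"

// ValidatorUptime is a sketch of the payload hinted at by the lint comment in
// payload.go. The serialize tags follow the pattern used by the other payloads.
type ValidatorUptime struct {
	ValidationID [32]byte `serialize:"true"`
	TotalUptime  uint64   `serialize:"true"`

	bytes []byte
}

// Bytes returns the binary representation of this payload (messages.Payload).
func (v *ValidatorUptime) Bytes() []byte { return v.bytes }

// initialize stores the binary representation (messages.Payload).
func (v *ValidatorUptime) initialize(b []byte) { v.bytes = b }

// VerifyMesssage (spelled as in the Signable interface) makes the type
// Signable, so backend.ValidateMessage will sign it on demand. Real checks,
// e.g. comparing the reported uptime against local state, would replace this
// placeholder logic.
func (v *ValidatorUptime) VerifyMesssage(sourceAddress []byte) error {
	if len(sourceAddress) == 0 {
		return errors.New("empty source address")
	}
	return nil
}

For messages.Parse to return such a type, it would also need to be registered in this package's init function, mirroring utils/warp/payload/codec_manager.go (err := lc.RegisterType(&ValidatorUptime{})).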
@@ -54,6 +60,7 @@ func (b *backend) AddMessage(unsignedMessage *warputils.UnsignedMessage) error { if err != nil { return fmt.Errorf("failed to sign warp message: %w", err) } + //TODO: save message signature to prefixdb b.logger.Debug("Adding warp message to backend", "messageID", messageID) return nil } @@ -71,3 +78,52 @@ func (b *backend) GetMessage(messageID ids.ID) (*warputils.UnsignedMessage, erro return unsignedMessage, nil } + +func (b *backend) GetMessageSignature(unsignedMessage *warputils.UnsignedMessage) ([bls.SignatureLen]byte, error) { + messageID := unsignedMessage.ID() + + b.logger.Debug("Getting warp message from backend", "messageID", messageID) + if err := b.ValidateMessage(unsignedMessage); err != nil { + return [bls.SignatureLen]byte{}, fmt.Errorf("failed to validate warp message: %w", err) + } + + var signature [bls.SignatureLen]byte + sig, err := b.warpSigner.Sign(unsignedMessage) + if err != nil { + return [bls.SignatureLen]byte{}, fmt.Errorf("failed to sign warp message: %w", err) + } + + copy(signature[:], sig) + return signature, nil +} + +func (b *backend) ValidateMessage(unsignedMessage *warputils.UnsignedMessage) error { + // Known on-chain messages should be signed + if _, err := b.GetMessage(unsignedMessage.ID()); err == nil { + return nil + } + + // Try to parse the payload as an AddressedCall + addressedCall, err := payload.ParseAddressedCall(unsignedMessage.Payload) + if err != nil { + return fmt.Errorf("failed to parse unknown message as AddressedCall: %w", err) + } + + // Further, parse the payload to see if it is a known type. + parsed, err := messages.Parse(addressedCall.Payload) + if err != nil { + return fmt.Errorf("failed to parse unknown message: %w", err) + } + + // Check if the message is a known type that can be signed on demand + signable, ok := parsed.(messages.Signable) + if !ok { + return fmt.Errorf("parsed message is not Signable: %T", signable) + } + + // Check if the message should be signed according to its type + if err := signable.VerifyMesssage(addressedCall.SourceAddress); err != nil { + return fmt.Errorf("failed to verify Signable message: %w", err) + } + return nil +} diff --git a/warp/backend_test.go b/warp/backend_test.go index 3c29fb58..3e886963 100644 --- a/warp/backend_test.go +++ b/warp/backend_test.go @@ -53,7 +53,12 @@ func TestAddAndGetValidMessage(t *testing.T) { expectedSig, err := warpSigner.Sign(testUnsignedMessage) require.NoError(t, err) require.NotNil(t, expectedSig) - // TODO: get signature and compare + + // Verify that a signature is returned successfully, and compare to expected signature. + signature, err := backend.GetMessageSignature(testUnsignedMessage) + require.NoError(t, err) + require.NoError(t, err) + require.Equal(t, expectedSig, signature[:]) } func TestAddAndGetUnknownMessage(t *testing.T) { @@ -70,5 +75,7 @@ func TestAddAndGetUnknownMessage(t *testing.T) { msg, err := backend.GetMessage(testUnsignedMessage.ID()) require.Error(t, err) require.Nil(t, msg) - //TODO: add business logic to check signature + // Try getting a signature for a message that was not added. 
+ _, err = backend.GetMessageSignature(testUnsignedMessage) + require.Error(t, err) } From 6621fb021c9f5ec51e096a449f64ad4e9c1baa19 Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:10:16 +0100 Subject: [PATCH 04/10] Warp signer (#69) * move warpSigner backend functionality to signMessage backend method * get rid of unnecessary slicing --------- Co-authored-by: DESKTOP-765JFGJ\Admin --- warp/backend.go | 28 ++++++++++++++-------------- warp/backend_test.go | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/warp/backend.go b/warp/backend.go index bcef5bda..74e1aedb 100644 --- a/warp/backend.go +++ b/warp/backend.go @@ -7,7 +7,6 @@ import ( dbm "github.com/cometbft/cometbft-db" "github.com/cometbft/cometbft/libs/log" - "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" "github.com/landslidenetwork/slide-sdk/utils/ids" warputils "github.com/landslidenetwork/slide-sdk/utils/warp" "github.com/landslidenetwork/slide-sdk/utils/warp/payload" @@ -19,7 +18,7 @@ type Backend interface { // AddMessage signs [unsignedMessage] and adds it to the warp backend database AddMessage(unsignedMessage *warputils.UnsignedMessage) error // GetMessageSignature returns the signature of the requested message. - GetMessageSignature(message *warputils.UnsignedMessage) ([bls.SignatureLen]byte, error) + GetMessageSignature(message *warputils.UnsignedMessage) ([]byte, error) // GetMessage retrieves the [unsignedMessage] from the warp backend database if available // TODO: After E-Upgrade, the backend no longer needs to store the mapping from messageHash // to unsignedMessage (and this method can be removed). @@ -56,8 +55,7 @@ func (b *backend) AddMessage(unsignedMessage *warputils.UnsignedMessage) error { return fmt.Errorf("failed to put warp signature in db: %w", err) } - _, err := b.warpSigner.Sign(unsignedMessage) - if err != nil { + if _, err := b.signMessage(unsignedMessage); err != nil { return fmt.Errorf("failed to sign warp message: %w", err) } //TODO: save message signature to prefixdb @@ -79,22 +77,15 @@ func (b *backend) GetMessage(messageID ids.ID) (*warputils.UnsignedMessage, erro return unsignedMessage, nil } -func (b *backend) GetMessageSignature(unsignedMessage *warputils.UnsignedMessage) ([bls.SignatureLen]byte, error) { +func (b *backend) GetMessageSignature(unsignedMessage *warputils.UnsignedMessage) ([]byte, error) { messageID := unsignedMessage.ID() b.logger.Debug("Getting warp message from backend", "messageID", messageID) if err := b.ValidateMessage(unsignedMessage); err != nil { - return [bls.SignatureLen]byte{}, fmt.Errorf("failed to validate warp message: %w", err) - } - - var signature [bls.SignatureLen]byte - sig, err := b.warpSigner.Sign(unsignedMessage) - if err != nil { - return [bls.SignatureLen]byte{}, fmt.Errorf("failed to sign warp message: %w", err) + return []byte{}, fmt.Errorf("failed to validate warp message: %w", err) } - copy(signature[:], sig) - return signature, nil + return b.signMessage(unsignedMessage) } func (b *backend) ValidateMessage(unsignedMessage *warputils.UnsignedMessage) error { @@ -127,3 +118,12 @@ func (b *backend) ValidateMessage(unsignedMessage *warputils.UnsignedMessage) er } return nil } + +func (b *backend) signMessage(unsignedMessage *warputils.UnsignedMessage) ([]byte, error) { + sig, err := b.warpSigner.Sign(unsignedMessage) + if err != nil { + return nil, fmt.Errorf("failed to sign warp message: %w", err) + } + + return sig, nil +} diff --git a/warp/backend_test.go 
b/warp/backend_test.go index 3e886963..d9bfb286 100644 --- a/warp/backend_test.go +++ b/warp/backend_test.go @@ -58,7 +58,7 @@ func TestAddAndGetValidMessage(t *testing.T) { signature, err := backend.GetMessageSignature(testUnsignedMessage) require.NoError(t, err) require.NoError(t, err) - require.Equal(t, expectedSig, signature[:]) + require.Equal(t, expectedSig, signature) } func TestAddAndGetUnknownMessage(t *testing.T) { From b7311820084d3b05722b89e0d709d0a7fc666bc2 Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Tue, 24 Dec 2024 11:19:04 +0100 Subject: [PATCH 05/10] Warp service rpc modification (#72) * WIP: prelimiry warp service * warp service vm integration * warp get message test * warp RPC: get message signature implementation * get rid of RPC layer for warp service * fix staticcheck warnings * goimport, add space between space and slash --------- Co-authored-by: DESKTOP-765JFGJ\Admin Co-authored-by: Ivan Sukach --- vm/rpc.go | 7 ++++ vm/rpc_test.go | 44 +++++++++++++++++++ vm/types/config.go | 1 + vm/vm.go | 42 +++++++++++++++++++ vm/vm_test.go | 21 ++++++++++ vm/warp_service.go | 102 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 217 insertions(+) create mode 100644 vm/warp_service.go diff --git a/vm/rpc.go b/vm/rpc.go index e493df3b..46db638d 100644 --- a/vm/rpc.go +++ b/vm/rpc.go @@ -34,6 +34,7 @@ func NewRPC(vm *LandslideVM) *RPC { } func (rpc *RPC) Routes() map[string]*jsonrpc.RPCFunc { + return map[string]*jsonrpc.RPCFunc{ // info AP @@ -63,6 +64,12 @@ func (rpc *RPC) Routes() map[string]*jsonrpc.RPCFunc { // abci API "abci_query": jsonrpc.NewRPCFunc(rpc.ABCIQuery, "path,data,height,prove"), "abci_info": jsonrpc.NewRPCFunc(rpc.ABCIInfo, "", jsonrpc.Cacheable()), + + // warp + "warp_get_message": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessage, "messageID"), + "warp_get_message_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessageSignature, "messageID"), + "warp_get_message_aggregate_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessageAggregateSignature, "messageID,quorumNum,subnetID"), + "warp_get_block_aggregate_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetBlockAggregateSignature, "blockID,quorumNum,subnetID"), } } diff --git a/vm/rpc_test.go b/vm/rpc_test.go index 64640f70..fa179521 100644 --- a/vm/rpc_test.go +++ b/vm/rpc_test.go @@ -6,6 +6,10 @@ import ( "testing" "time" + "github.com/cometbft/cometbft/libs/rand" + "github.com/landslidenetwork/slide-sdk/utils/ids" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + ctypes "github.com/cometbft/cometbft/rpc/core/types" "github.com/cometbft/cometbft/rpc/jsonrpc/client" rpctypes "github.com/cometbft/cometbft/rpc/jsonrpc/types" @@ -59,6 +63,46 @@ func TestStatus(t *testing.T) { t.Logf("Status result %+v", result) } +func TestWarpGetMessage(t *testing.T) { + server, vm, rpcClient := setupRPC(t) + defer server.Close() + + chainID, err := ids.ToID(vm.appOpts.ChainID) + require.NoError(t, err) + testUnsignedMessage, err := warputils.NewUnsignedMessage(vm.appOpts.NetworkID, chainID, []byte(rand.Str(30))) + require.NoError(t, err) + vm.warpBackend.AddMessage(testUnsignedMessage) + result := new(ResultGetMessage) + _, err = rpcClient.Call(context.Background(), "warp_get_message", map[string]interface{}{"messageID": testUnsignedMessage.ID().String()}, result) + require.NoError(t, err) + t.Log(result.Message) + t.Log(testUnsignedMessage.Bytes()) + require.Equal(t, result.Message, testUnsignedMessage.Bytes()) +} + +func 
TestWarpGetMessageSignature(t *testing.T) { + server, vm, rpcClient := setupRPC(t) + defer server.Close() + + chainID, err := ids.ToID(vm.appOpts.ChainID) + require.NoError(t, err) + testUnsignedMessage, err := warputils.NewUnsignedMessage(vm.appOpts.NetworkID, chainID, []byte(rand.Str(30))) + require.NoError(t, err) + vm.warpBackend.AddMessage(testUnsignedMessage) + result := new(ResultGetMessageSignature) + _, err = rpcClient.Call(context.Background(), "warp_get_message_signature", map[string]interface{}{"messageID": testUnsignedMessage.ID().String()}, result) + require.NoError(t, err) + expectedSig, err := vm.warpSigner.Sign(testUnsignedMessage) + require.NoError(t, err) + require.NotNil(t, expectedSig) + + t.Log(result.Signature) + t.Log(expectedSig) + + require.NoError(t, err) + require.Equal(t, expectedSig, result.Signature) +} + // TestRPC is a test RPC server for the LandslideVM. type TestRPC struct { vm *LandslideVM diff --git a/vm/types/config.go b/vm/types/config.go index a8151f62..5adaba1f 100644 --- a/vm/types/config.go +++ b/vm/types/config.go @@ -31,6 +31,7 @@ type ( ConsensusParams ConsensusParams `json:"consensus_params"` MaxSubscriptionClients int `json:"max_subscription_clients"` MaxSubscriptionsPerClient int `json:"max_subscriptions_per_client"` + BLSSecretKey []byte `json:"bls_secret_key"` } // ConsensusParams contains consensus critical parameters that determine the diff --git a/vm/vm.go b/vm/vm.go index d656b298..aeeefcf0 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -12,6 +12,10 @@ import ( "sync" "time" + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + "github.com/landslidenetwork/slide-sdk/warp" + dbm "github.com/cometbft/cometbft-db" abcitypes "github.com/cometbft/cometbft/abci/types" "github.com/cometbft/cometbft/config" @@ -64,6 +68,7 @@ var ( dbPrefixStateStore = []byte("state-store") dbPrefixTxIndexer = []byte("tx-indexer") dbPrefixBlockIndexer = []byte("block-indexer") + dbPrefixWarp = []byte("warp") // TODO: use internal app validators instead proposerAddress = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} @@ -135,6 +140,12 @@ type ( preferred [32]byte wrappedBlocks *vmstate.WrappedBlocksStorage + // Avalanche Warp Messaging backend + // Used to serve BLS signatures of warp messages over RPC + warpBackend warp.Backend + warpSigner warputils.Signer + warpService *API + clientConn grpc.ClientConnInterface optClientConn *grpc.ClientConn config vmtypes.VMConfig @@ -443,6 +454,37 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest parentHash := block.ParentHash(blk) + warpDB := dbm.NewPrefixDB(vm.database, dbPrefixWarp) + // TODO: implement bls secret key check + // if vm.config.BLSSecretKey == nil { + // if err != nil { + // return nil, err + // } + // } + chainID, err := ids.ToID(req.ChainId) + if err != nil { + return nil, err + } + fmt.Println(vm.config.BLSSecretKey) + secretKey, err := bls.SecretKeyFromBytes(vm.config.BLSSecretKey) + if err != nil { + return nil, err + } + vm.warpSigner = warputils.NewSigner(secretKey, req.NetworkId, chainID) + vm.warpBackend = warp.NewBackend( + req.NetworkId, + chainID, + vm.warpSigner, + vm.logger, + warpDB, + ) + + subnetID, err := ids.ToID(req.SubnetId) + if err != nil { + return nil, err + } + vm.warpService = NewAPI(vm, req.NetworkId, subnetID, chainID, vm.warpBackend) + return &vmpb.InitializeResponse{ LastAcceptedId: blk.Hash(), LastAcceptedParentId: parentHash[:], diff --git a/vm/vm_test.go 
b/vm/vm_test.go index a5001c2f..8365f883 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -3,11 +3,15 @@ package vm import ( "context" _ "embed" + "encoding/json" "net" "testing" dbm "github.com/cometbft/cometbft-db" "github.com/cometbft/cometbft/abci/example/kvstore" + "github.com/cometbft/cometbft/libs/rand" + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + vmtypes "github.com/landslidenetwork/slide-sdk/vm/types" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -55,9 +59,26 @@ func newKvApp(t *testing.T, vmdb, appdb dbm.DB) vmpb.VMServer { return kvstore.NewApplication(appdb), nil }, WithOptClientConn(mockConn)) require.NotNil(t, vm) + sk, err := bls.NewSecretKey() + if err != nil { + t.Fatalf("Failed to generate secret key: %v", err) + } + skBytes := bls.SecretKeyToBytes(sk) + vmCfg := vmtypes.Config{} + vmCfg.VMConfig.SetDefaults() + + vmCfg.VMConfig.BLSSecretKey = skBytes + + cfg, err := json.Marshal(vmCfg) + if err != nil { + t.Fatalf("Failed to marshal vm config to json: %v", err) + } initRes, err := vm.Initialize(context.TODO(), &vmpb.InitializeRequest{ DbServerAddr: "inmemory", GenesisBytes: kvstorevmGenesis, + ChainId: []byte(rand.Str(32)), + SubnetId: []byte(rand.Str(32)), + ConfigBytes: cfg, }) require.NoError(t, err) require.NotNil(t, initRes) diff --git a/vm/warp_service.go b/vm/warp_service.go new file mode 100644 index 00000000..4621bda0 --- /dev/null +++ b/vm/warp_service.go @@ -0,0 +1,102 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package vm + +import ( + "context" + "fmt" + + tmbytes "github.com/cometbft/cometbft/libs/bytes" + rpctypes "github.com/cometbft/cometbft/rpc/jsonrpc/types" + "github.com/landslidenetwork/slide-sdk/utils/ids" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + "github.com/landslidenetwork/slide-sdk/utils/warp/payload" + "github.com/landslidenetwork/slide-sdk/warp" +) + +const failedParseIDPattern = "failed to parse ID %s with error %w" + +type ResultGetMessage struct { + Message []byte `json:"message"` +} + +type ResultGetMessageSignature struct { + Signature []byte `json:"signature"` +} + +// API introduces snowman specific functionality to the evm +type API struct { + vm *LandslideVM + networkID uint32 + sourceSubnetID, sourceChainID ids.ID + backend warp.Backend +} + +func NewAPI(vm *LandslideVM, networkID uint32, sourceSubnetID ids.ID, sourceChainID ids.ID, backend warp.Backend) *API { + return &API{ + vm: vm, + networkID: networkID, + sourceSubnetID: sourceSubnetID, + sourceChainID: sourceChainID, + backend: backend, + } +} + +// GetMessage returns the Warp message associated with a messageID. +func (a *API) GetMessage(_ *rpctypes.Context, messageID string) (*ResultGetMessage, error) { + msgID, err := ids.FromString(messageID) + if err != nil { + return nil, fmt.Errorf(failedParseIDPattern, messageID, err) + } + message, err := a.backend.GetMessage(msgID) + if err != nil { + return nil, fmt.Errorf("failed to get message %s with error %w", messageID, err) + } + return &ResultGetMessage{Message: message.Bytes()}, nil +} + +// GetMessageSignature returns the BLS signature associated with a messageID. 
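+// It backs the "warp_get_message_signature" JSON-RPC route registered in
+// vm/rpc.go and takes the messageID as a string.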
+func (a *API) GetMessageSignature(_ *rpctypes.Context, messageID string) (*ResultGetMessageSignature, error) { + msgID, err := ids.FromString(messageID) + if err != nil { + return nil, fmt.Errorf(failedParseIDPattern, messageID, err) + } + unsignedMessage, err := a.backend.GetMessage(msgID) + if err != nil { + return nil, fmt.Errorf("failed to get message %s with error %w", messageID, err) + } + signature, err := a.backend.GetMessageSignature(unsignedMessage) + if err != nil { + return nil, fmt.Errorf("failed to get signature for message %s with error %w", messageID, err) + } + return &ResultGetMessageSignature{Signature: signature}, nil +} + +// GetMessageAggregateSignature fetches the aggregate signature for the requested [messageID] +func (a *API) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) (signedMessageBytes tmbytes.HexBytes, err error) { + unsignedMessage, err := a.backend.GetMessage(messageID) + if err != nil { + return nil, err + } + return a.aggregateSignatures(ctx, unsignedMessage, quorumNum, subnetIDStr) +} + +// GetBlockAggregateSignature fetches the aggregate signature for the requested [blockID] +func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) (signedMessageBytes tmbytes.HexBytes, err error) { + blockHashPayload, err := payload.NewHash(blockID) + if err != nil { + return nil, err + } + unsignedMessage, err := warputils.NewUnsignedMessage(a.networkID, a.sourceChainID, blockHashPayload.Bytes()) + if err != nil { + return nil, err + } + + return a.aggregateSignatures(ctx, unsignedMessage, quorumNum, subnetIDStr) +} + +func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warputils.UnsignedMessage, quorumNum uint64, subnetIDStr string) (tmbytes.HexBytes, error) { + // TODO: implement aggregateSignatures + return nil, nil +} From 4cbcc9885cd3fa48e6ee17398e7be639beba8cfb Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:41:56 +0100 Subject: [PATCH 06/10] Simplification of warp signature aggregation (#76) * preliminary version of warp aggregate signature functionality * aggregate signature: signature getter, aggregator, message * panic implementation of peer network client * simplify validator set receiving * get rid of client field, set signature getter instead * signature getter: replace peer.NetworkSignatureGetter with apiFetcher * get rid of unnecessary functionality * fix sprintf template for string representation of message type * add exceptions for linter * get rid of field index of Validator struct * move comment to another position * nolint:SA1019 * change format of linter exceptional comment * unslice; goimports; add G507 exception * setup exception comment for linter * setup #nosec exception comment for linter * disable G507 check * get rid of import deprecated ripemd160 library --------- Co-authored-by: Ivan Sukach --- .golangci.yml | 3 +- go.mod | 2 +- .../gvalidators/validator_state_client.go | 144 ++++ proto/validatorstate/validator_state.pb.go | 795 ++++++++++++++++++ proto/validatorstate/validator_state.proto | 68 ++ .../validatorstate/validator_state_grpc.pb.go | 272 ++++++ utils/crypto/bls/public.go | 66 +- utils/hashing/hashing.go | 14 + utils/ids/node_id.go | 84 ++ utils/ids/node_id_test.go | 218 +++++ utils/ids/short.go | 125 +++ utils/math/safe_math.go | 60 ++ utils/set/bits.go | 102 +++ utils/set/bits_test.go | 508 +++++++++++ 
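// A minimal sketch of the block-hash flow used by GetBlockAggregateSignature:
// wrap a block ID in a hash payload and build the unsigned warp message from
// its bytes. The networkID, chainID and blockID values here are placeholders;
// only packages introduced by this patch series are used.
package main

import (
	"fmt"

	"github.com/landslidenetwork/slide-sdk/utils/ids"
	warputils "github.com/landslidenetwork/slide-sdk/utils/warp"
	"github.com/landslidenetwork/slide-sdk/utils/warp/payload"
)

func blockHashMessage(networkID uint32, chainID, blockID ids.ID) (*warputils.UnsignedMessage, error) {
	blockHashPayload, err := payload.NewHash(blockID)
	if err != nil {
		return nil, err
	}
	return warputils.NewUnsignedMessage(networkID, chainID, blockHashPayload.Bytes())
}

func main() {
	// Zero IDs stand in for real chain and block IDs.
	msg, err := blockHashMessage(1, ids.ID{}, ids.ID{})
	if err != nil {
		panic(err)
	}
	fmt.Println("unsigned message ID:", msg.ID())
}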
utils/validators/state.go | 116 +++ utils/validators/validator.go | 37 + utils/version/application.go | 68 ++ utils/warp/aggregator/aggregator.go | 174 ++++ utils/warp/aggregator/aggregator_test.go | 10 + utils/warp/aggregator/signature_getter.go | 18 + utils/warp/message.go | 59 ++ utils/warp/message_test.go | 44 + utils/warp/messages/payload.go | 2 +- utils/warp/signature.go | 165 ++++ utils/warp/validator.go | 131 +++ utils/warp/validator_test.go | 69 ++ utils/warp/validators/state.go | 58 ++ utils/warp/validators/state_test.go | 12 + vm/rpc.go | 1 + vm/types/config.go | 13 +- vm/vm.go | 22 +- vm/warp_service.go | 85 +- warp/backend.go | 26 +- warp/client.go | 90 ++ warp/fetcher.go | 54 ++ 35 files changed, 3692 insertions(+), 23 deletions(-) create mode 100644 grpcutils/gvalidators/validator_state_client.go create mode 100644 proto/validatorstate/validator_state.pb.go create mode 100644 proto/validatorstate/validator_state.proto create mode 100644 proto/validatorstate/validator_state_grpc.pb.go create mode 100644 utils/ids/node_id.go create mode 100644 utils/ids/node_id_test.go create mode 100644 utils/ids/short.go create mode 100644 utils/math/safe_math.go create mode 100644 utils/set/bits.go create mode 100644 utils/set/bits_test.go create mode 100644 utils/validators/state.go create mode 100644 utils/validators/validator.go create mode 100644 utils/version/application.go create mode 100644 utils/warp/aggregator/aggregator.go create mode 100644 utils/warp/aggregator/aggregator_test.go create mode 100644 utils/warp/aggregator/signature_getter.go create mode 100644 utils/warp/message.go create mode 100644 utils/warp/message_test.go create mode 100644 utils/warp/validator.go create mode 100644 utils/warp/validator_test.go create mode 100644 utils/warp/validators/state.go create mode 100644 utils/warp/validators/state_test.go create mode 100644 warp/client.go create mode 100644 warp/fetcher.go diff --git a/.golangci.yml b/.golangci.yml index d61f6395..54b0247e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,6 +19,7 @@ linters-settings: gosec: excludes: - G115 # integer overflow conversion int -> uint8 (gosec) + - G507 revive: ignore-generated-header: true severity: warning @@ -364,4 +365,4 @@ issues: - goconst - gosec - noctx - - wrapcheck \ No newline at end of file + - wrapcheck diff --git a/go.mod b/go.mod index 5b78e8ca..4fbf7978 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/supranational/blst v0.3.13 go.uber.org/mock v0.4.0 + golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 golang.org/x/sync v0.7.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 @@ -188,7 +189,6 @@ require ( go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/trace v1.24.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.25.0 // indirect golang.org/x/net v0.27.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/sys v0.22.0 // indirect diff --git a/grpcutils/gvalidators/validator_state_client.go b/grpcutils/gvalidators/validator_state_client.go new file mode 100644 index 00000000..64f89fe9 --- /dev/null +++ b/grpcutils/gvalidators/validator_state_client.go @@ -0,0 +1,144 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package gvalidators + +import ( + "context" + "errors" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" + "github.com/landslidenetwork/slide-sdk/utils/validators" + + pb "github.com/landslidenetwork/slide-sdk/proto/validatorstate" +) + +var ( + _ validators.State = (*Client)(nil) + errFailedPublicKeyDeserialize = errors.New("couldn't deserialize public key") +) + +type Client struct { + client pb.ValidatorStateClient +} + +func NewClient(client pb.ValidatorStateClient) *Client { + return &Client{client: client} +} + +func (c *Client) GetMinimumHeight(ctx context.Context) (uint64, error) { + resp, err := c.client.GetMinimumHeight(ctx, &emptypb.Empty{}) + if err != nil { + return 0, err + } + return resp.Height, nil +} + +func (c *Client) GetCurrentHeight(ctx context.Context) (uint64, error) { + resp, err := c.client.GetCurrentHeight(ctx, &emptypb.Empty{}) + if err != nil { + return 0, err + } + return resp.Height, nil +} + +func (c *Client) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { + resp, err := c.client.GetSubnetID(ctx, &pb.GetSubnetIDRequest{ + ChainId: chainID[:], + }) + if err != nil { + return ids.Empty, err + } + return ids.ToID(resp.SubnetId) +} + +func (c *Client) GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, +) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + resp, err := c.client.GetValidatorSet(ctx, &pb.GetValidatorSetRequest{ + Height: height, + SubnetId: subnetID[:], + }) + if err != nil { + return nil, err + } + + vdrs := make(map[ids.NodeID]*validators.GetValidatorOutput, len(resp.Validators)) + for _, validator := range resp.Validators { + nodeID, err := ids.ToNodeID(validator.NodeId) + if err != nil { + return nil, err + } + var publicKey *bls.PublicKey + if len(validator.PublicKey) > 0 { + // PublicKeyFromValidUncompressedBytes is used rather than + // PublicKeyFromCompressedBytes because it is significantly faster + // due to the avoidance of decompression and key re-verification. We + // can safely assume that the BLS Public Keys are verified before + // being added to the P-Chain and served by the gRPC server. + publicKey = bls.PublicKeyFromValidUncompressedBytes(validator.PublicKey) + if publicKey == nil { + return nil, errFailedPublicKeyDeserialize + } + } + vdrs[nodeID] = &validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: publicKey, + Weight: validator.Weight, + } + } + return vdrs, nil +} + +func (c *Client) GetCurrentValidatorSet( + ctx context.Context, + subnetID ids.ID, +) (map[ids.ID]*validators.GetCurrentValidatorOutput, uint64, error) { + resp, err := c.client.GetCurrentValidatorSet(ctx, &pb.GetCurrentValidatorSetRequest{ + SubnetId: subnetID[:], + }) + if err != nil { + return nil, 0, err + } + + vdrs := make(map[ids.ID]*validators.GetCurrentValidatorOutput, len(resp.Validators)) + for _, validator := range resp.Validators { + nodeID, err := ids.ToNodeID(validator.NodeId) + if err != nil { + return nil, 0, err + } + var publicKey *bls.PublicKey + if len(validator.PublicKey) > 0 { + // PublicKeyFromValidUncompressedBytes is used rather than + // PublicKeyFromCompressedBytes because it is significantly faster + // due to the avoidance of decompression and key re-verification. We + // can safely assume that the BLS Public Keys are verified before + // being added to the P-Chain and served by the gRPC server. 
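// A minimal sketch of wiring the validator-state gRPC client end to end: dial
// a connection, wrap the generated stub in the gvalidators client, and fetch a
// validator set. The target address, dial options and the zero subnetID are
// placeholders; the client types and methods are the ones added in this patch.
package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/landslidenetwork/slide-sdk/grpcutils/gvalidators"
	pb "github.com/landslidenetwork/slide-sdk/proto/validatorstate"
	"github.com/landslidenetwork/slide-sdk/utils/ids"
)

func main() {
	conn, err := grpc.NewClient("localhost:9650",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	state := gvalidators.NewClient(pb.NewValidatorStateClient(conn))

	height, err := state.GetCurrentHeight(context.Background())
	if err != nil {
		panic(err)
	}

	var subnetID ids.ID // placeholder subnet
	vdrs, err := state.GetValidatorSet(context.Background(), height, subnetID)
	if err != nil {
		panic(err)
	}
	fmt.Printf("validator set at height %d has %d entries\n", height, len(vdrs))
}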
+ publicKey = bls.PublicKeyFromValidUncompressedBytes(validator.PublicKey) + if publicKey == nil { + return nil, 0, errFailedPublicKeyDeserialize + } + } + validationID, err := ids.ToID(validator.ValidationId) + if err != nil { + return nil, 0, err + } + + vdrs[validationID] = &validators.GetCurrentValidatorOutput{ + ValidationID: validationID, + NodeID: nodeID, + PublicKey: publicKey, + Weight: validator.Weight, + StartTime: validator.StartTime, + MinNonce: validator.MinNonce, + IsActive: validator.IsActive, + IsL1Validator: validator.IsL1Validator, + } + } + return vdrs, resp.GetCurrentHeight(), nil +} diff --git a/proto/validatorstate/validator_state.pb.go b/proto/validatorstate/validator_state.pb.go new file mode 100644 index 00000000..739d4698 --- /dev/null +++ b/proto/validatorstate/validator_state.pb.go @@ -0,0 +1,795 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.33.0 +// protoc (unknown) +// source: validatorstate/validator_state.proto + +package validatorstate + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type GetMinimumHeightResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Height uint64 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` +} + +func (x *GetMinimumHeightResponse) Reset() { + *x = GetMinimumHeightResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetMinimumHeightResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetMinimumHeightResponse) ProtoMessage() {} + +func (x *GetMinimumHeightResponse) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetMinimumHeightResponse.ProtoReflect.Descriptor instead. 
+func (*GetMinimumHeightResponse) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{0} +} + +func (x *GetMinimumHeightResponse) GetHeight() uint64 { + if x != nil { + return x.Height + } + return 0 +} + +type GetCurrentHeightResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Height uint64 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` +} + +func (x *GetCurrentHeightResponse) Reset() { + *x = GetCurrentHeightResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetCurrentHeightResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCurrentHeightResponse) ProtoMessage() {} + +func (x *GetCurrentHeightResponse) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCurrentHeightResponse.ProtoReflect.Descriptor instead. +func (*GetCurrentHeightResponse) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{1} +} + +func (x *GetCurrentHeightResponse) GetHeight() uint64 { + if x != nil { + return x.Height + } + return 0 +} + +type GetSubnetIDRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` +} + +func (x *GetSubnetIDRequest) Reset() { + *x = GetSubnetIDRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetSubnetIDRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSubnetIDRequest) ProtoMessage() {} + +func (x *GetSubnetIDRequest) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSubnetIDRequest.ProtoReflect.Descriptor instead. 
+func (*GetSubnetIDRequest) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{2} +} + +func (x *GetSubnetIDRequest) GetChainId() []byte { + if x != nil { + return x.ChainId + } + return nil +} + +type GetSubnetIDResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SubnetId []byte `protobuf:"bytes,1,opt,name=subnet_id,json=subnetId,proto3" json:"subnet_id,omitempty"` +} + +func (x *GetSubnetIDResponse) Reset() { + *x = GetSubnetIDResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetSubnetIDResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSubnetIDResponse) ProtoMessage() {} + +func (x *GetSubnetIDResponse) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSubnetIDResponse.ProtoReflect.Descriptor instead. +func (*GetSubnetIDResponse) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{3} +} + +func (x *GetSubnetIDResponse) GetSubnetId() []byte { + if x != nil { + return x.SubnetId + } + return nil +} + +type GetValidatorSetRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Height uint64 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` + SubnetId []byte `protobuf:"bytes,2,opt,name=subnet_id,json=subnetId,proto3" json:"subnet_id,omitempty"` +} + +func (x *GetValidatorSetRequest) Reset() { + *x = GetValidatorSetRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetValidatorSetRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetValidatorSetRequest) ProtoMessage() {} + +func (x *GetValidatorSetRequest) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetValidatorSetRequest.ProtoReflect.Descriptor instead. 
+func (*GetValidatorSetRequest) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{4} +} + +func (x *GetValidatorSetRequest) GetHeight() uint64 { + if x != nil { + return x.Height + } + return 0 +} + +func (x *GetValidatorSetRequest) GetSubnetId() []byte { + if x != nil { + return x.SubnetId + } + return nil +} + +type GetCurrentValidatorSetRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SubnetId []byte `protobuf:"bytes,1,opt,name=subnet_id,json=subnetId,proto3" json:"subnet_id,omitempty"` +} + +func (x *GetCurrentValidatorSetRequest) Reset() { + *x = GetCurrentValidatorSetRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetCurrentValidatorSetRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCurrentValidatorSetRequest) ProtoMessage() {} + +func (x *GetCurrentValidatorSetRequest) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCurrentValidatorSetRequest.ProtoReflect.Descriptor instead. +func (*GetCurrentValidatorSetRequest) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{5} +} + +func (x *GetCurrentValidatorSetRequest) GetSubnetId() []byte { + if x != nil { + return x.SubnetId + } + return nil +} + +type Validator struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + NodeId []byte `protobuf:"bytes,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + Weight uint64 `protobuf:"varint,2,opt,name=weight,proto3" json:"weight,omitempty"` + PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` // Uncompressed public key, can be empty + StartTime uint64 `protobuf:"varint,4,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` // can be empty + MinNonce uint64 `protobuf:"varint,5,opt,name=min_nonce,json=minNonce,proto3" json:"min_nonce,omitempty"` // can be empty + IsActive bool `protobuf:"varint,6,opt,name=is_active,json=isActive,proto3" json:"is_active,omitempty"` // can be empty + ValidationId []byte `protobuf:"bytes,7,opt,name=validation_id,json=validationId,proto3" json:"validation_id,omitempty"` // can be empty + IsL1Validator bool `protobuf:"varint,8,opt,name=is_l1_validator,json=isL1Validator,proto3" json:"is_l1_validator,omitempty"` // can be empty +} + +func (x *Validator) Reset() { + *x = Validator{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Validator) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Validator) ProtoMessage() {} + +func (x *Validator) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms 
+ } + return mi.MessageOf(x) +} + +// Deprecated: Use Validator.ProtoReflect.Descriptor instead. +func (*Validator) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{6} +} + +func (x *Validator) GetNodeId() []byte { + if x != nil { + return x.NodeId + } + return nil +} + +func (x *Validator) GetWeight() uint64 { + if x != nil { + return x.Weight + } + return 0 +} + +func (x *Validator) GetPublicKey() []byte { + if x != nil { + return x.PublicKey + } + return nil +} + +func (x *Validator) GetStartTime() uint64 { + if x != nil { + return x.StartTime + } + return 0 +} + +func (x *Validator) GetMinNonce() uint64 { + if x != nil { + return x.MinNonce + } + return 0 +} + +func (x *Validator) GetIsActive() bool { + if x != nil { + return x.IsActive + } + return false +} + +func (x *Validator) GetValidationId() []byte { + if x != nil { + return x.ValidationId + } + return nil +} + +func (x *Validator) GetIsL1Validator() bool { + if x != nil { + return x.IsL1Validator + } + return false +} + +type GetValidatorSetResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Validators []*Validator `protobuf:"bytes,1,rep,name=validators,proto3" json:"validators,omitempty"` +} + +func (x *GetValidatorSetResponse) Reset() { + *x = GetValidatorSetResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetValidatorSetResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetValidatorSetResponse) ProtoMessage() {} + +func (x *GetValidatorSetResponse) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetValidatorSetResponse.ProtoReflect.Descriptor instead. 
+func (*GetValidatorSetResponse) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{7} +} + +func (x *GetValidatorSetResponse) GetValidators() []*Validator { + if x != nil { + return x.Validators + } + return nil +} + +type GetCurrentValidatorSetResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Validators []*Validator `protobuf:"bytes,1,rep,name=validators,proto3" json:"validators,omitempty"` + CurrentHeight uint64 `protobuf:"varint,2,opt,name=current_height,json=currentHeight,proto3" json:"current_height,omitempty"` +} + +func (x *GetCurrentValidatorSetResponse) Reset() { + *x = GetCurrentValidatorSetResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_validatorstate_validator_state_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetCurrentValidatorSetResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCurrentValidatorSetResponse) ProtoMessage() {} + +func (x *GetCurrentValidatorSetResponse) ProtoReflect() protoreflect.Message { + mi := &file_validatorstate_validator_state_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCurrentValidatorSetResponse.ProtoReflect.Descriptor instead. +func (*GetCurrentValidatorSetResponse) Descriptor() ([]byte, []int) { + return file_validatorstate_validator_state_proto_rawDescGZIP(), []int{8} +} + +func (x *GetCurrentValidatorSetResponse) GetValidators() []*Validator { + if x != nil { + return x.Validators + } + return nil +} + +func (x *GetCurrentValidatorSetResponse) GetCurrentHeight() uint64 { + if x != nil { + return x.CurrentHeight + } + return 0 +} + +var File_validatorstate_validator_state_proto protoreflect.FileDescriptor + +var file_validatorstate_validator_state_proto_rawDesc = []byte{ + 0x0a, 0x24, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, + 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, + 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0x32, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x69, 0x6e, 0x69, 0x6d, 0x75, + 0x6d, 0x48, 0x65, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x16, 0x0a, 0x06, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x06, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x32, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x43, 0x75, + 0x72, 0x72, 0x65, 0x6e, 0x74, 0x48, 0x65, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x06, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x2f, 0x0a, 0x12, 0x47, + 0x65, 0x74, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x44, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x19, 0x0a, 0x08, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x22, 0x32, 
0x0a, 0x13, + 0x47, 0x65, 0x74, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x44, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x5f, 0x69, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x64, + 0x22, 0x4d, 0x0a, 0x16, 0x47, 0x65, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, + 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x68, 0x65, + 0x69, 0x67, 0x68, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x68, 0x65, 0x69, 0x67, + 0x68, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x5f, 0x69, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x64, 0x22, + 0x3c, 0x0a, 0x1d, 0x47, 0x65, 0x74, 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x64, 0x22, 0x81, 0x02, + 0x0a, 0x09, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x17, 0x0a, 0x07, 0x6e, + 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x6e, 0x6f, + 0x64, 0x65, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x12, 0x1d, 0x0a, 0x0a, + 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1d, 0x0a, 0x0a, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x6d, 0x69, + 0x6e, 0x5f, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x04, 0x52, 0x08, 0x6d, + 0x69, 0x6e, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x61, 0x63, + 0x74, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x41, 0x63, + 0x74, 0x69, 0x76, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x76, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x0f, 0x69, 0x73, 0x5f, + 0x6c, 0x31, 0x5f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0d, 0x69, 0x73, 0x4c, 0x31, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, + 0x72, 0x22, 0x54, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, + 0x72, 0x53, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x0a, + 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x19, 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, + 0x65, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x52, 0x0a, 0x76, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x1e, 0x47, 0x65, 0x74, 0x43, + 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, + 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x0a, 0x76, 0x61, + 0x6c, 0x69, 
0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x19, + 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, + 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x52, 0x0a, 0x76, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x6f, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, + 0x5f, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0d, 0x63, + 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x48, 0x65, 0x69, 0x67, 0x68, 0x74, 0x32, 0xf1, 0x03, 0x0a, + 0x0e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x54, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x69, 0x6e, 0x69, 0x6d, 0x75, 0x6d, 0x48, 0x65, 0x69, + 0x67, 0x68, 0x74, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x28, 0x2e, 0x76, 0x61, + 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x47, 0x65, 0x74, + 0x4d, 0x69, 0x6e, 0x69, 0x6d, 0x75, 0x6d, 0x48, 0x65, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x54, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, 0x75, 0x72, 0x72, + 0x65, 0x6e, 0x74, 0x48, 0x65, 0x69, 0x67, 0x68, 0x74, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x28, 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x48, 0x65, 0x69, + 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x56, 0x0a, 0x0b, 0x47, + 0x65, 0x74, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x22, 0x2e, 0x76, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x53, + 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x44, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, + 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, + 0x47, 0x65, 0x74, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x49, 0x44, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x62, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, + 0x74, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x12, 0x26, 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, + 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x27, + 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, + 0x47, 0x65, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x77, 0x0a, 0x16, 0x47, 0x65, 0x74, 0x43, 0x75, + 0x72, 0x72, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, 0x65, + 0x74, 0x12, 0x2d, 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x2e, 0x2e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, + 0x65, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x6f, 0x72, 0x53, 
0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x42, 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, + 0x61, 0x6e, 0x64, 0x73, 0x6c, 0x69, 0x64, 0x65, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2f, + 0x73, 0x6c, 0x69, 0x64, 0x65, 0x2d, 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, + 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x6f, 0x72, 0x73, 0x74, 0x61, 0x74, 0x65, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_validatorstate_validator_state_proto_rawDescOnce sync.Once + file_validatorstate_validator_state_proto_rawDescData = file_validatorstate_validator_state_proto_rawDesc +) + +func file_validatorstate_validator_state_proto_rawDescGZIP() []byte { + file_validatorstate_validator_state_proto_rawDescOnce.Do(func() { + file_validatorstate_validator_state_proto_rawDescData = protoimpl.X.CompressGZIP(file_validatorstate_validator_state_proto_rawDescData) + }) + return file_validatorstate_validator_state_proto_rawDescData +} + +var file_validatorstate_validator_state_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_validatorstate_validator_state_proto_goTypes = []interface{}{ + (*GetMinimumHeightResponse)(nil), // 0: validatorstate.GetMinimumHeightResponse + (*GetCurrentHeightResponse)(nil), // 1: validatorstate.GetCurrentHeightResponse + (*GetSubnetIDRequest)(nil), // 2: validatorstate.GetSubnetIDRequest + (*GetSubnetIDResponse)(nil), // 3: validatorstate.GetSubnetIDResponse + (*GetValidatorSetRequest)(nil), // 4: validatorstate.GetValidatorSetRequest + (*GetCurrentValidatorSetRequest)(nil), // 5: validatorstate.GetCurrentValidatorSetRequest + (*Validator)(nil), // 6: validatorstate.Validator + (*GetValidatorSetResponse)(nil), // 7: validatorstate.GetValidatorSetResponse + (*GetCurrentValidatorSetResponse)(nil), // 8: validatorstate.GetCurrentValidatorSetResponse + (*emptypb.Empty)(nil), // 9: google.protobuf.Empty +} +var file_validatorstate_validator_state_proto_depIdxs = []int32{ + 6, // 0: validatorstate.GetValidatorSetResponse.validators:type_name -> validatorstate.Validator + 6, // 1: validatorstate.GetCurrentValidatorSetResponse.validators:type_name -> validatorstate.Validator + 9, // 2: validatorstate.ValidatorState.GetMinimumHeight:input_type -> google.protobuf.Empty + 9, // 3: validatorstate.ValidatorState.GetCurrentHeight:input_type -> google.protobuf.Empty + 2, // 4: validatorstate.ValidatorState.GetSubnetID:input_type -> validatorstate.GetSubnetIDRequest + 4, // 5: validatorstate.ValidatorState.GetValidatorSet:input_type -> validatorstate.GetValidatorSetRequest + 5, // 6: validatorstate.ValidatorState.GetCurrentValidatorSet:input_type -> validatorstate.GetCurrentValidatorSetRequest + 0, // 7: validatorstate.ValidatorState.GetMinimumHeight:output_type -> validatorstate.GetMinimumHeightResponse + 1, // 8: validatorstate.ValidatorState.GetCurrentHeight:output_type -> validatorstate.GetCurrentHeightResponse + 3, // 9: validatorstate.ValidatorState.GetSubnetID:output_type -> validatorstate.GetSubnetIDResponse + 7, // 10: validatorstate.ValidatorState.GetValidatorSet:output_type -> validatorstate.GetValidatorSetResponse + 8, // 11: validatorstate.ValidatorState.GetCurrentValidatorSet:output_type -> validatorstate.GetCurrentValidatorSetResponse + 7, // [7:12] is the sub-list for method output_type + 2, // [2:7] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list 
for field type_name +} + +func init() { file_validatorstate_validator_state_proto_init() } +func file_validatorstate_validator_state_proto_init() { + if File_validatorstate_validator_state_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_validatorstate_validator_state_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetMinimumHeightResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetCurrentHeightResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetSubnetIDRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetSubnetIDResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetValidatorSetRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetCurrentValidatorSetRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Validator); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetValidatorSetResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_validatorstate_validator_state_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetCurrentValidatorSetResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_validatorstate_validator_state_proto_rawDesc, + NumEnums: 0, + NumMessages: 9, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_validatorstate_validator_state_proto_goTypes, + DependencyIndexes: file_validatorstate_validator_state_proto_depIdxs, + MessageInfos: file_validatorstate_validator_state_proto_msgTypes, + }.Build() + File_validatorstate_validator_state_proto = out.File + file_validatorstate_validator_state_proto_rawDesc = nil + file_validatorstate_validator_state_proto_goTypes = nil + file_validatorstate_validator_state_proto_depIdxs = nil +} diff --git 
a/proto/validatorstate/validator_state.proto b/proto/validatorstate/validator_state.proto new file mode 100644 index 00000000..c1219fdf --- /dev/null +++ b/proto/validatorstate/validator_state.proto @@ -0,0 +1,68 @@ +syntax = "proto3"; + +package validatorstate; + +import "google/protobuf/empty.proto"; + +option go_package = "github.com/landslidenetwork/slide-sdk/proto/validatorstate"; + +service ValidatorState { + // GetMinimumHeight returns the minimum height of the blocks in the optimal + // proposal window. + rpc GetMinimumHeight(google.protobuf.Empty) returns (GetMinimumHeightResponse); + // GetCurrentHeight returns the current height of the P-chain. + rpc GetCurrentHeight(google.protobuf.Empty) returns (GetCurrentHeightResponse); + // GetSubnetID returns the subnetID of the provided chain. + rpc GetSubnetID(GetSubnetIDRequest) returns (GetSubnetIDResponse); + // GetValidatorSet returns the weights of the nodeIDs for the provided + // subnet at the requested P-chain height. + rpc GetValidatorSet(GetValidatorSetRequest) returns (GetValidatorSetResponse); + // GetCurrentValidatorSet returns the validator set for the provided subnet at + // the current P-chain height. + rpc GetCurrentValidatorSet(GetCurrentValidatorSetRequest) returns (GetCurrentValidatorSetResponse); +} + +message GetMinimumHeightResponse { + uint64 height = 1; +} + +message GetCurrentHeightResponse { + uint64 height = 1; +} + +message GetSubnetIDRequest { + bytes chain_id = 1; +} + +message GetSubnetIDResponse { + bytes subnet_id = 1; +} + +message GetValidatorSetRequest { + uint64 height = 1; + bytes subnet_id = 2; +} + +message GetCurrentValidatorSetRequest { + bytes subnet_id = 1; +} + +message Validator { + bytes node_id = 1; + uint64 weight = 2; + bytes public_key = 3; // Uncompressed public key, can be empty + uint64 start_time = 4; // can be empty + uint64 min_nonce = 5; // can be empty + bool is_active = 6; // can be empty + bytes validation_id = 7; // can be empty + bool is_l1_validator = 8; // can be empty +} + +message GetValidatorSetResponse { + repeated Validator validators = 1; +} + +message GetCurrentValidatorSetResponse { + repeated Validator validators = 1; + uint64 current_height = 2; +} diff --git a/proto/validatorstate/validator_state_grpc.pb.go b/proto/validatorstate/validator_state_grpc.pb.go new file mode 100644 index 00000000..5d351fb5 --- /dev/null +++ b/proto/validatorstate/validator_state_grpc.pb.go @@ -0,0 +1,272 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.3.0 +// - protoc (unknown) +// source: validatorstate/validator_state.proto + +package validatorstate + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. 
+const _ = grpc.SupportPackageIsVersion7 + +const ( + ValidatorState_GetMinimumHeight_FullMethodName = "/validatorstate.ValidatorState/GetMinimumHeight" + ValidatorState_GetCurrentHeight_FullMethodName = "/validatorstate.ValidatorState/GetCurrentHeight" + ValidatorState_GetSubnetID_FullMethodName = "/validatorstate.ValidatorState/GetSubnetID" + ValidatorState_GetValidatorSet_FullMethodName = "/validatorstate.ValidatorState/GetValidatorSet" + ValidatorState_GetCurrentValidatorSet_FullMethodName = "/validatorstate.ValidatorState/GetCurrentValidatorSet" +) + +// ValidatorStateClient is the client API for ValidatorState service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ValidatorStateClient interface { + // GetMinimumHeight returns the minimum height of the blocks in the optimal + // proposal window. + GetMinimumHeight(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMinimumHeightResponse, error) + // GetCurrentHeight returns the current height of the P-chain. + GetCurrentHeight(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetCurrentHeightResponse, error) + // GetSubnetID returns the subnetID of the provided chain. + GetSubnetID(ctx context.Context, in *GetSubnetIDRequest, opts ...grpc.CallOption) (*GetSubnetIDResponse, error) + // GetValidatorSet returns the weights of the nodeIDs for the provided + // subnet at the requested P-chain height. + GetValidatorSet(ctx context.Context, in *GetValidatorSetRequest, opts ...grpc.CallOption) (*GetValidatorSetResponse, error) + // GetCurrentValidatorSet returns the validator set for the provided subnet at + // the current P-chain height. + GetCurrentValidatorSet(ctx context.Context, in *GetCurrentValidatorSetRequest, opts ...grpc.CallOption) (*GetCurrentValidatorSetResponse, error) +} + +type validatorStateClient struct { + cc grpc.ClientConnInterface +} + +func NewValidatorStateClient(cc grpc.ClientConnInterface) ValidatorStateClient { + return &validatorStateClient{cc} +} + +func (c *validatorStateClient) GetMinimumHeight(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMinimumHeightResponse, error) { + out := new(GetMinimumHeightResponse) + err := c.cc.Invoke(ctx, ValidatorState_GetMinimumHeight_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *validatorStateClient) GetCurrentHeight(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetCurrentHeightResponse, error) { + out := new(GetCurrentHeightResponse) + err := c.cc.Invoke(ctx, ValidatorState_GetCurrentHeight_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *validatorStateClient) GetSubnetID(ctx context.Context, in *GetSubnetIDRequest, opts ...grpc.CallOption) (*GetSubnetIDResponse, error) { + out := new(GetSubnetIDResponse) + err := c.cc.Invoke(ctx, ValidatorState_GetSubnetID_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *validatorStateClient) GetValidatorSet(ctx context.Context, in *GetValidatorSetRequest, opts ...grpc.CallOption) (*GetValidatorSetResponse, error) { + out := new(GetValidatorSetResponse) + err := c.cc.Invoke(ctx, ValidatorState_GetValidatorSet_FullMethodName, in, out, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *validatorStateClient) GetCurrentValidatorSet(ctx context.Context, in *GetCurrentValidatorSetRequest, opts ...grpc.CallOption) (*GetCurrentValidatorSetResponse, error) { + out := new(GetCurrentValidatorSetResponse) + err := c.cc.Invoke(ctx, ValidatorState_GetCurrentValidatorSet_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ValidatorStateServer is the server API for ValidatorState service. +// All implementations should embed UnimplementedValidatorStateServer +// for forward compatibility +type ValidatorStateServer interface { + // GetMinimumHeight returns the minimum height of the blocks in the optimal + // proposal window. + GetMinimumHeight(context.Context, *emptypb.Empty) (*GetMinimumHeightResponse, error) + // GetCurrentHeight returns the current height of the P-chain. + GetCurrentHeight(context.Context, *emptypb.Empty) (*GetCurrentHeightResponse, error) + // GetSubnetID returns the subnetID of the provided chain. + GetSubnetID(context.Context, *GetSubnetIDRequest) (*GetSubnetIDResponse, error) + // GetValidatorSet returns the weights of the nodeIDs for the provided + // subnet at the requested P-chain height. + GetValidatorSet(context.Context, *GetValidatorSetRequest) (*GetValidatorSetResponse, error) + // GetCurrentValidatorSet returns the validator set for the provided subnet at + // the current P-chain height. + GetCurrentValidatorSet(context.Context, *GetCurrentValidatorSetRequest) (*GetCurrentValidatorSetResponse, error) +} + +// UnimplementedValidatorStateServer should be embedded to have forward compatible implementations. +type UnimplementedValidatorStateServer struct { +} + +func (UnimplementedValidatorStateServer) GetMinimumHeight(context.Context, *emptypb.Empty) (*GetMinimumHeightResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetMinimumHeight not implemented") +} +func (UnimplementedValidatorStateServer) GetCurrentHeight(context.Context, *emptypb.Empty) (*GetCurrentHeightResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetCurrentHeight not implemented") +} +func (UnimplementedValidatorStateServer) GetSubnetID(context.Context, *GetSubnetIDRequest) (*GetSubnetIDResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetSubnetID not implemented") +} +func (UnimplementedValidatorStateServer) GetValidatorSet(context.Context, *GetValidatorSetRequest) (*GetValidatorSetResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetValidatorSet not implemented") +} +func (UnimplementedValidatorStateServer) GetCurrentValidatorSet(context.Context, *GetCurrentValidatorSetRequest) (*GetCurrentValidatorSetResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetCurrentValidatorSet not implemented") +} + +// UnsafeValidatorStateServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ValidatorStateServer will +// result in compilation errors. 
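// A minimal sketch of a server that satisfies the generated
// ValidatorStateServer interface by embedding UnimplementedValidatorStateServer
// and overriding a single method. The listen address and the constant height
// are placeholders.
package main

import (
	"context"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/protobuf/types/known/emptypb"

	pb "github.com/landslidenetwork/slide-sdk/proto/validatorstate"
)

type stubValidatorState struct {
	pb.UnimplementedValidatorStateServer
}

func (stubValidatorState) GetCurrentHeight(context.Context, *emptypb.Empty) (*pb.GetCurrentHeightResponse, error) {
	return &pb.GetCurrentHeightResponse{Height: 42}, nil // placeholder height
}

func main() {
	lis, err := net.Listen("tcp", "localhost:9650") // placeholder address
	if err != nil {
		panic(err)
	}
	s := grpc.NewServer()
	pb.RegisterValidatorStateServer(s, stubValidatorState{})
	if err := s.Serve(lis); err != nil {
		panic(err)
	}
}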
+type UnsafeValidatorStateServer interface { + mustEmbedUnimplementedValidatorStateServer() +} + +func RegisterValidatorStateServer(s grpc.ServiceRegistrar, srv ValidatorStateServer) { + s.RegisterService(&ValidatorState_ServiceDesc, srv) +} + +func _ValidatorState_GetMinimumHeight_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(emptypb.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ValidatorStateServer).GetMinimumHeight(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ValidatorState_GetMinimumHeight_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ValidatorStateServer).GetMinimumHeight(ctx, req.(*emptypb.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _ValidatorState_GetCurrentHeight_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(emptypb.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ValidatorStateServer).GetCurrentHeight(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ValidatorState_GetCurrentHeight_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ValidatorStateServer).GetCurrentHeight(ctx, req.(*emptypb.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _ValidatorState_GetSubnetID_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetSubnetIDRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ValidatorStateServer).GetSubnetID(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ValidatorState_GetSubnetID_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ValidatorStateServer).GetSubnetID(ctx, req.(*GetSubnetIDRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ValidatorState_GetValidatorSet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetValidatorSetRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ValidatorStateServer).GetValidatorSet(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ValidatorState_GetValidatorSet_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ValidatorStateServer).GetValidatorSet(ctx, req.(*GetValidatorSetRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ValidatorState_GetCurrentValidatorSet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetCurrentValidatorSetRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ValidatorStateServer).GetCurrentValidatorSet(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ValidatorState_GetCurrentValidatorSet_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + 
return srv.(ValidatorStateServer).GetCurrentValidatorSet(ctx, req.(*GetCurrentValidatorSetRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// ValidatorState_ServiceDesc is the grpc.ServiceDesc for ValidatorState service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ValidatorState_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "validatorstate.ValidatorState", + HandlerType: (*ValidatorStateServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetMinimumHeight", + Handler: _ValidatorState_GetMinimumHeight_Handler, + }, + { + MethodName: "GetCurrentHeight", + Handler: _ValidatorState_GetCurrentHeight_Handler, + }, + { + MethodName: "GetSubnetID", + Handler: _ValidatorState_GetSubnetID_Handler, + }, + { + MethodName: "GetValidatorSet", + Handler: _ValidatorState_GetValidatorSet_Handler, + }, + { + MethodName: "GetCurrentValidatorSet", + Handler: _ValidatorState_GetCurrentValidatorSet_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "validatorstate/validator_state.proto", +} diff --git a/utils/crypto/bls/public.go b/utils/crypto/bls/public.go index acc3649c..2043b083 100644 --- a/utils/crypto/bls/public.go +++ b/utils/crypto/bls/public.go @@ -1,10 +1,23 @@ package bls -import blst "github.com/supranational/blst/bindings/go" +import ( + "errors" -var ciphersuiteSignature = []byte("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_") + blst "github.com/supranational/blst/bindings/go" +) -type PublicKey = blst.P1Affine +var ( + ciphersuiteSignature = []byte("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_") + ErrNoPublicKeys = errors.New("no public keys") + ErrFailedPublicKeyDecompress = errors.New("couldn't decompress public key") + errInvalidPublicKey = errors.New("invalid public key") + errFailedPublicKeyAggregation = errors.New("couldn't aggregate public keys") +) + +type ( + PublicKey = blst.P1Affine + AggregatePublicKey = blst.P1Aggregate +) // Verify the [sig] of [msg] against the [pk]. // The [sig] and [pk] may have been an aggregation of other signatures and keys. @@ -12,3 +25,50 @@ type PublicKey = blst.P1Affine func Verify(pk *PublicKey, sig *Signature, msg []byte) bool { return sig.Verify(false, pk, false, msg, ciphersuiteSignature) } + +// PublicKeyFromValidUncompressedBytes parses the uncompressed big-endian format +// of the public key into a public key. It is assumed that the provided bytes +// are valid. +func PublicKeyFromValidUncompressedBytes(pkBytes []byte) *PublicKey { + return new(PublicKey).Deserialize(pkBytes) +} + +// PublicKeyToUncompressedBytes returns the uncompressed big-endian format of +// the public key. +func PublicKeyToUncompressedBytes(key *PublicKey) []byte { + return key.Serialize() +} + +// PublicKeyToCompressedBytes returns the compressed big-endian format of the +// public key. +func PublicKeyToCompressedBytes(pk *PublicKey) []byte { + return pk.Compress() +} + +// PublicKeyFromCompressedBytes parses the compressed big-endian format of the +// public key into a public key. +func PublicKeyFromCompressedBytes(pkBytes []byte) (*PublicKey, error) { + pk := new(PublicKey).Uncompress(pkBytes) + if pk == nil { + return nil, ErrFailedPublicKeyDecompress + } + if !pk.KeyValidate() { + return nil, errInvalidPublicKey + } + return pk, nil +} + +// AggregatePublicKeys aggregates a non-zero number of public keys into a single +// aggregated public key. +// Invariant: all [pks] have been validated. 
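// A minimal sketch combining the compressed-key helpers with
// AggregatePublicKeys: each key is decompressed and validated before the
// aggregate is rebuilt in compressed form. Keys arriving compressed (e.g. from
// a registry) is the only assumption.
package blsexample

import (
	"fmt"

	"github.com/landslidenetwork/slide-sdk/utils/crypto/bls"
)

// AggregateCompressed decompresses and validates each public key, aggregates
// them, and returns the aggregate key compressed again.
func AggregateCompressed(compressedKeys [][]byte) ([]byte, error) {
	pks := make([]*bls.PublicKey, 0, len(compressedKeys))
	for _, keyBytes := range compressedKeys {
		// PublicKeyFromCompressedBytes also key-validates, satisfying the
		// AggregatePublicKeys invariant that all keys are already validated.
		pk, err := bls.PublicKeyFromCompressedBytes(keyBytes)
		if err != nil {
			return nil, fmt.Errorf("invalid public key: %w", err)
		}
		pks = append(pks, pk)
	}
	aggPK, err := bls.AggregatePublicKeys(pks)
	if err != nil {
		return nil, err
	}
	return bls.PublicKeyToCompressedBytes(aggPK), nil
}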
+func AggregatePublicKeys(pks []*PublicKey) (*PublicKey, error) { + if len(pks) == 0 { + return nil, ErrNoPublicKeys + } + + var agg AggregatePublicKey + if !agg.Aggregate(pks, false) { + return nil, errFailedPublicKeyAggregation + } + return agg.ToAffine(), nil +} diff --git a/utils/hashing/hashing.go b/utils/hashing/hashing.go index e9ec2177..419c0152 100644 --- a/utils/hashing/hashing.go +++ b/utils/hashing/hashing.go @@ -8,6 +8,8 @@ import ( const ( HashLen = sha256.Size + // The size of the checksum in bytes. + RIPEMD160Size = 20 ) var ErrInvalidHashLen = errors.New("invalid hash length") @@ -15,6 +17,9 @@ var ErrInvalidHashLen = errors.New("invalid hash length") // Hash256 A 256 bit long hash value. type Hash256 = [HashLen]byte +// Hash160 A 160 bit long hash value. +type Hash160 = [RIPEMD160Size]byte + // ComputeHash256Array computes a cryptographically strong 256 bit hash of the // input byte slice. func ComputeHash256Array(buf []byte) Hash256 { @@ -62,3 +67,12 @@ func ToHash256(bytes []byte) (Hash256, error) { copy(hash[:], bytes) return hash, nil } + +func ToHash160(bytes []byte) (Hash160, error) { + hash := Hash160{} + if bytesLen := len(bytes); bytesLen != RIPEMD160Size { + return hash, fmt.Errorf("%w: expected 20 bytes but got %d", ErrInvalidHashLen, bytesLen) + } + copy(hash[:], bytes) + return hash, nil +} diff --git a/utils/ids/node_id.go b/utils/ids/node_id.go new file mode 100644 index 00000000..ab00af13 --- /dev/null +++ b/utils/ids/node_id.go @@ -0,0 +1,84 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package ids + +import ( + "bytes" + "errors" + "fmt" + + "github.com/landslidenetwork/slide-sdk/utils" +) + +const ( + NodeIDPrefix = "NodeID-" + NodeIDLen = ShortIDLen +) + +var ( + EmptyNodeID = NodeID{} + + errShortNodeID = errors.New("insufficient NodeID length") + + _ utils.Sortable[NodeID] = NodeID{} +) + +type NodeID ShortID + +func (id NodeID) String() string { + return ShortID(id).PrefixedString(NodeIDPrefix) +} + +func (id NodeID) Bytes() []byte { + return id[:] +} + +func (id NodeID) MarshalJSON() ([]byte, error) { + return []byte(`"` + id.String() + `"`), nil +} + +func (id NodeID) MarshalText() ([]byte, error) { + return []byte(id.String()), nil +} + +func (id *NodeID) UnmarshalJSON(b []byte) error { + str := string(b) + if str == nullStr { // If "null", do nothing + return nil + } else if len(str) <= 2+len(NodeIDPrefix) { + return fmt.Errorf("%w: expected to be > %d", errShortNodeID, 2+len(NodeIDPrefix)) + } + + lastIndex := len(str) - 1 + if str[0] != '"' || str[lastIndex] != '"' { + return errMissingQuotes + } + + var err error + *id, err = NodeIDFromString(str[1:lastIndex]) + return err +} + +func (id *NodeID) UnmarshalText(text []byte) error { + return id.UnmarshalJSON(text) +} + +func (id NodeID) Compare(other NodeID) int { + return bytes.Compare(id[:], other[:]) +} + +// ToNodeID attempt to convert a byte slice into a node id +func ToNodeID(bytes []byte) (NodeID, error) { + nodeID, err := ToShortID(bytes) + return NodeID(nodeID), err +} + +// NodeIDFromString is the inverse of NodeID.String() +func NodeIDFromString(nodeIDStr string) (NodeID, error) { + asShort, err := ShortFromPrefixedString(nodeIDStr, NodeIDPrefix) + if err != nil { + return NodeID{}, err + } + return NodeID(asShort), nil +} diff --git a/utils/ids/node_id_test.go b/utils/ids/node_id_test.go new file mode 100644 index 00000000..dfb7d7c3 --- /dev/null +++ b/utils/ids/node_id_test.go @@ -0,0 +1,218 @@ +// Copyright (C) 
2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package ids + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/landslidenetwork/slide-sdk/utils/cb58" +) + +func TestNodeIDEquality(t *testing.T) { + require := require.New(t) + + id := NodeID{24} + idCopy := NodeID{24} + require.Equal(id, idCopy) + id2 := NodeID{} + require.NotEqual(id, id2) +} + +func TestNodeIDFromString(t *testing.T) { + require := require.New(t) + + id := NodeID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'} + idStr := id.String() + id2, err := NodeIDFromString(idStr) + require.NoError(err) + require.Equal(id, id2) + expected := "NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz" + require.Equal(expected, idStr) +} + +func TestNodeIDFromStringError(t *testing.T) { + tests := []struct { + in string + expectedErr error + }{ + { + in: "", + expectedErr: cb58.ErrBase58Decoding, + }, + { + in: "foo", + expectedErr: cb58.ErrMissingChecksum, + }, + { + in: "foobar", + expectedErr: cb58.ErrBadChecksum, + }, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + _, err := FromString(tt.in) + require.ErrorIs(t, err, tt.expectedErr) + }) + } +} + +func TestNodeIDMarshalJSON(t *testing.T) { + tests := []struct { + label string + in NodeID + out []byte + err error + }{ + { + "NodeID{}", + NodeID{}, + []byte(`"NodeID-111111111111111111116DBWJs"`), + nil, + }, + { + `ID("ava labs")`, + NodeID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}, + []byte(`"NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz"`), + nil, + }, + } + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + require := require.New(t) + + out, err := tt.in.MarshalJSON() + require.ErrorIs(err, tt.err) + require.Equal(tt.out, out) + }) + } +} + +func TestNodeIDUnmarshalJSON(t *testing.T) { + tests := []struct { + label string + in []byte + out NodeID + expectedErr error + }{ + { + "NodeID{}", + []byte("null"), + NodeID{}, + nil, + }, + { + `NodeID("ava labs")`, + []byte(`"NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz"`), + NodeID{'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}, + nil, + }, + { + "missing start quote", + []byte(`NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz"`), + NodeID{}, + errMissingQuotes, + }, + { + "missing end quote", + []byte(`"NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz`), + NodeID{}, + errMissingQuotes, + }, + { + "NodeID-", + []byte(`"NodeID-"`), + NodeID{}, + errShortNodeID, + }, + { + "NodeID-1", + []byte(`"NodeID-1"`), + NodeID{}, + cb58.ErrMissingChecksum, + }, + { + "NodeID-9tLMkeWFhWXd8QZc4rSiS5meuVXF5kRsz1", + []byte(`"NodeID-1"`), + NodeID{}, + cb58.ErrMissingChecksum, + }, + } + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + require := require.New(t) + + foo := NodeID{} + err := foo.UnmarshalJSON(tt.in) + require.ErrorIs(err, tt.expectedErr) + require.Equal(tt.out, foo) + }) + } +} + +func TestNodeIDString(t *testing.T) { + tests := []struct { + label string + id NodeID + expected string + }{ + {"NodeID{}", NodeID{}, "NodeID-111111111111111111116DBWJs"}, + {"NodeID{24}", NodeID{24}, "NodeID-3BuDc2d1Efme5Apba6SJ8w3Tz7qeh6mHt"}, + } + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + require.Equal(t, tt.expected, tt.id.String()) + }) + } +} + +func TestNodeIDMapMarshalling(t *testing.T) { + require := require.New(t) + + originalMap := map[NodeID]int{ + {'e', 'v', 'a', ' ', 'l', 'a', 'b', 's'}: 1, + {'a', 'v', 'a', ' ', 'l', 'a', 'b', 's'}: 2, + } + mapJSON, err := json.Marshal(originalMap) + require.NoError(err) + + var 
unmarshalledMap map[NodeID]int + require.NoError(json.Unmarshal(mapJSON, &unmarshalledMap)) + require.Equal(originalMap, unmarshalledMap) +} + +func TestNodeIDCompare(t *testing.T) { + tests := []struct { + a NodeID + b NodeID + expected int + }{ + { + a: NodeID{1}, + b: NodeID{0}, + expected: 1, + }, + { + a: NodeID{1}, + b: NodeID{1}, + expected: 0, + }, + { + a: NodeID{1, 0}, + b: NodeID{1, 2}, + expected: -1, + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%s_%s_%d", test.a, test.b, test.expected), func(t *testing.T) { + require := require.New(t) + + require.Equal(test.expected, test.a.Compare(test.b)) + require.Equal(-test.expected, test.b.Compare(test.a)) + }) + } +} diff --git a/utils/ids/short.go b/utils/ids/short.go new file mode 100644 index 00000000..7a266038 --- /dev/null +++ b/utils/ids/short.go @@ -0,0 +1,125 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package ids + +import ( + "bytes" + "encoding/hex" + "fmt" + "strings" + + "github.com/landslidenetwork/slide-sdk/utils" + "github.com/landslidenetwork/slide-sdk/utils/cb58" + "github.com/landslidenetwork/slide-sdk/utils/hashing" +) + +const ShortIDLen = 20 + +// ShortEmpty is a useful all zero value +var ( + ShortEmpty = ShortID{} + + _ utils.Sortable[ShortID] = ShortID{} +) + +// ShortID wraps a 20 byte hash as an identifier +type ShortID [ShortIDLen]byte + +// ToShortID attempt to convert a byte slice into an id +func ToShortID(bytes []byte) (ShortID, error) { + return hashing.ToHash160(bytes) +} + +// ShortFromString is the inverse of ShortID.String() +func ShortFromString(idStr string) (ShortID, error) { + bytes, err := cb58.Decode(idStr) + if err != nil { + return ShortID{}, err + } + return ToShortID(bytes) +} + +// ShortFromPrefixedString returns a ShortID assuming the cb58 format is +// prefixed +func ShortFromPrefixedString(idStr, prefix string) (ShortID, error) { + if !strings.HasPrefix(idStr, prefix) { + return ShortID{}, fmt.Errorf("ID: %s is missing the prefix: %s", idStr, prefix) + } + return ShortFromString(strings.TrimPrefix(idStr, prefix)) +} + +func (id ShortID) MarshalJSON() ([]byte, error) { + str, err := cb58.Encode(id[:]) + if err != nil { + return nil, err + } + return []byte(`"` + str + `"`), nil +} + +func (id *ShortID) UnmarshalJSON(b []byte) error { + str := string(b) + if str == nullStr { // If "null", do nothing + return nil + } else if len(str) < 2 { + return errMissingQuotes + } + + lastIndex := len(str) - 1 + if str[0] != '"' || str[lastIndex] != '"' { + return errMissingQuotes + } + + // Parse CB58 formatted string to bytes + bytes, err := cb58.Decode(str[1:lastIndex]) + if err != nil { + return fmt.Errorf("couldn't decode ID to bytes: %w", err) + } + *id, err = ToShortID(bytes) + return err +} + +func (id *ShortID) UnmarshalText(text []byte) error { + return id.UnmarshalJSON(text) +} + +// Bytes returns the 20 byte hash as a slice. It is assumed this slice is not +// modified. +func (id ShortID) Bytes() []byte { + return id[:] +} + +// Hex returns a hex encoded string of this id. 
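+// The output comes from hex.EncodeToString, so it is lowercase and carries no
+// "0x" prefix.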
+func (id ShortID) Hex() string { + return hex.EncodeToString(id.Bytes()) +} + +func (id ShortID) String() string { + // We assume that the maximum size of a byte slice that + // can be stringified is at least the length of an ID + str, _ := cb58.Encode(id.Bytes()) + return str +} + +// PrefixedString returns the String representation with a prefix added +func (id ShortID) PrefixedString(prefix string) string { + return prefix + id.String() +} + +func (id ShortID) MarshalText() ([]byte, error) { + return []byte(id.String()), nil +} + +func (id ShortID) Compare(other ShortID) int { + return bytes.Compare(id[:], other[:]) +} + +// ShortIDsToStrings converts an array of shortIDs to an array of their string +// representations +func ShortIDsToStrings(ids []ShortID) []string { + idStrs := make([]string, len(ids)) + for i, id := range ids { + idStrs[i] = id.String() + } + return idStrs +} diff --git a/utils/math/safe_math.go b/utils/math/safe_math.go new file mode 100644 index 00000000..1a3aa0bb --- /dev/null +++ b/utils/math/safe_math.go @@ -0,0 +1,60 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package math + +import ( + "errors" + + "golang.org/x/exp/constraints" +) + +var ( + ErrOverflow = errors.New("overflow") + ErrUnderflow = errors.New("underflow") + + // Deprecated: Add64 is deprecated. Use Add[uint64] instead. + Add64 = Add[uint64] + + // Deprecated: Mul64 is deprecated. Use Mul[uint64] instead. + Mul64 = Mul[uint64] +) + +// MaxUint returns the maximum value of an unsigned integer of type T. +func MaxUint[T constraints.Unsigned]() T { + return ^T(0) +} + +// Add returns: +// 1) a + b +// 2) If there is overflow, an error +func Add[T constraints.Unsigned](a, b T) (T, error) { + if a > MaxUint[T]()-b { + return 0, ErrOverflow + } + return a + b, nil +} + +// Sub returns: +// 1) a - b +// 2) If there is underflow, an error +func Sub[T constraints.Unsigned](a, b T) (T, error) { + if a < b { + return 0, ErrUnderflow + } + return a - b, nil +} + +// Mul returns: +// 1) a * b +// 2) If there is overflow, an error +func Mul[T constraints.Unsigned](a, b T) (T, error) { + if b != 0 && a > MaxUint[T]()/b { + return 0, ErrOverflow + } + return a * b, nil +} + +func AbsDiff[T constraints.Unsigned](a, b T) T { + return max(a, b) - min(a, b) +} diff --git a/utils/set/bits.go b/utils/set/bits.go new file mode 100644 index 00000000..a6e74fb6 --- /dev/null +++ b/utils/set/bits.go @@ -0,0 +1,102 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package set + +import ( + "encoding/hex" + "math/big" + "math/bits" +) + +// Bits is a bit-set backed by a big.Int +// Holds values ranging from [0, INT_MAX] (arch-dependent) +// Trying to use negative values will result in a panic. +// This implementation is NOT thread-safe. +type Bits struct { + bits *big.Int +} + +// NewBits returns a new instance of Bits with [bits] set to 1. +// +// Invariants: +// 1. Negative bits will cause a panic. +// 2. Duplicate bits are allowed but will cause a no-op. +func NewBits(bits ...int) Bits { + b := Bits{new(big.Int)} + for _, bit := range bits { + b.Add(bit) + } + return b +} + +// Add sets the [i]'th bit to 1 +func (b Bits) Add(i int) { + b.bits.SetBit(b.bits, i, 1) +} + +// Union performs the set union with another set. 
+// This adds all elements in [other] to [b] +func (b Bits) Union(other Bits) { + b.bits.Or(b.bits, other.bits) +} + +// Intersection performs the set intersection with another set +// This sets [b] to include only elements in both [b] and [other] +func (b Bits) Intersection(other Bits) { + b.bits.And(b.bits, other.bits) +} + +// Difference removes all the elements in [other] from this set +func (b Bits) Difference(other Bits) { + b.bits.AndNot(b.bits, other.bits) +} + +// Remove sets the [i]'th bit to 0 +func (b Bits) Remove(i int) { + b.bits.SetBit(b.bits, i, 0) +} + +// Clear empties out the bitset +func (b Bits) Clear() { + b.bits.SetUint64(0) +} + +// Contains returns true if the [i]'th bit is 1, and false otherwise +func (b Bits) Contains(i int) bool { + return b.bits.Bit(i) == 1 +} + +// BitLen returns the bit length of this bitset +func (b Bits) BitLen() int { + return b.bits.BitLen() +} + +// Len returns the amount of 1's in the bitset +// +// This is typically referred to as the "Hamming Weight" +// of a set of bits. +func (b Bits) Len() int { + result := 0 + for _, word := range b.bits.Bits() { + result += bits.OnesCount(uint(word)) + } + return result +} + +// Returns the byte representation of this bitset +func (b Bits) Bytes() []byte { + return b.bits.Bytes() +} + +// Inverse of Bits.Bytes() +func BitsFromBytes(bytes []byte) Bits { + return Bits{ + bits: new(big.Int).SetBytes(bytes), + } +} + +// String returns the hex representation of this bitset +func (b Bits) String() string { + return hex.EncodeToString(b.bits.Bytes()) +} diff --git a/utils/set/bits_test.go b/utils/set/bits_test.go new file mode 100644 index 00000000..c4c838d5 --- /dev/null +++ b/utils/set/bits_test.go @@ -0,0 +1,508 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package set + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Bits_New(t *testing.T) { + tests := []struct { + name string + bits []int + length int + }{ + { + name: "empty", + bits: []int{}, + length: 0, + }, + { + name: "populated", + bits: []int{0, 9, 99, 999, 9999}, + length: 10_000, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + b := NewBits(test.bits...) 
+ + for _, bit := range test.bits { + require.True(b.Contains(bit)) + } + + require.Equal(test.length, b.BitLen()) + }) + } +} + +func Test_Bits_AddRemove(t *testing.T) { + tests := []struct { + name string + toAdd []int + toRemove []int + expectedElements []int + expectedLen int + }{ + { + name: "empty sets", + toAdd: []int{}, + toRemove: []int{}, + expectedElements: []int{}, // [] + expectedLen: 0, + }, + { + name: "add only", + toAdd: []int{0, 1, 2}, + toRemove: []int{}, + expectedElements: []int{0, 1, 2}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "remove left-most", + toAdd: []int{0, 1, 2}, + toRemove: []int{0}, + expectedElements: []int{1, 2}, // [1, 1, 0] + expectedLen: 3, + }, + { + name: "remove middle", + toAdd: []int{0, 1, 2}, + toRemove: []int{1}, + expectedElements: []int{2, 0}, // [1, 0, 1] + expectedLen: 3, + }, + { + name: "remove right-most", + toAdd: []int{0, 1, 2}, + toRemove: []int{2}, + expectedElements: []int{0, 1}, // [1, 1] + expectedLen: 2, + }, + { + name: "remove all", + toAdd: []int{0, 1, 2}, + toRemove: []int{0, 1, 2}, + expectedElements: []int{}, // [1, 1, 1] + expectedLen: 0, + }, + { + name: "remove reverse-order", + toAdd: []int{0, 1, 2}, + toRemove: []int{2, 1, 0}, + expectedElements: []int{}, // [] + expectedLen: 0, + }, + { + name: "remove non-existent elements", + toAdd: []int{0, 1, 2}, + toRemove: []int{3, 4, 5}, + expectedElements: []int{0, 1, 2}, // [1, 1, 1] + expectedLen: 3, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + b := NewBits() + + for _, add := range test.toAdd { + b.Add(add) + } + + for _, remove := range test.toRemove { + b.Remove(remove) + } + + for _, element := range test.expectedElements { + require.True(b.Contains(element)) + } + + require.Equal(test.expectedLen, b.BitLen()) + }) + } +} + +func Test_Bits_Union(t *testing.T) { + tests := []struct { + name string + left []int + right []int + expected []int + expectedLen int + }{ + { + name: "empty sets", + left: []int{}, + right: []int{}, + expected: []int{}, // [] + expectedLen: 0, + }, + { + name: "left and right are same", + left: []int{2, 1, 0}, + right: []int{2, 1, 0}, + expected: []int{2, 1, 0}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "left and no right", + left: []int{2, 1, 0}, + right: []int{}, + expected: []int{2, 1, 0}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "right and no left", + left: []int{}, + right: []int{2, 1, 0}, + expected: []int{2, 1, 0}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "left and right overlap", + left: []int{2, 1}, + right: []int{1, 0}, + expected: []int{2, 1, 0}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "left and right overlap different sizes", + left: []int{5, 3, 1}, + right: []int{8, 6, 4, 2, 0}, + expected: []int{8, 6, 5, 4, 3, 2, 1, 0}, // [1, 0, 1, 1, 1, 1, 1, 1, 1] + expectedLen: 9, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + b := NewBits() + + for _, add := range test.left { + b.Add(add) + } + for _, add := range test.right { + b.Add(add) + } + + for _, element := range test.expected { + require.True(b.Contains(element)) + } + + require.Equal(test.expectedLen, b.BitLen()) + }) + } +} + +func Test_Bits_Intersection(t *testing.T) { + tests := []struct { + name string + left []int + right []int + expected []int + expectedLen int + }{ + { + name: "empty sets", + left: []int{}, + right: []int{}, + expected: []int{}, // [] + expectedLen: 0, + }, + { + name: "left and right are same", + 
left: []int{2, 1, 0}, + right: []int{2, 1, 0}, + expected: []int{2, 1, 0}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "left and no right", + left: []int{2, 1, 0}, + right: []int{}, + expected: []int{}, // [] + expectedLen: 0, + }, + { + name: "right and no left", + left: []int{}, + right: []int{2, 1, 0}, + expected: []int{}, // [] + expectedLen: 0, + }, + { + name: "left and right overlap", + left: []int{2, 1}, + right: []int{1, 0}, + expected: []int{1}, // [1, 0] + expectedLen: 2, + }, + { + name: "left and right overlap different sizes", + left: []int{5, 3, 1}, + right: []int{8, 6, 4, 2, 0}, + expected: []int{}, // [] + expectedLen: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + left := NewBits() + right := NewBits() + for _, add := range test.left { + left.Add(add) + } + for _, add := range test.right { + right.Add(add) + } + + left.Intersection(right) + + expected := NewBits() + for _, element := range test.expected { + expected.Add(element) + } + + require.ElementsMatch(left.bits.Bits(), expected.bits.Bits()) + }) + } +} + +func Test_Bits_Difference(t *testing.T) { + tests := []struct { + name string + left []int + right []int + expected []int + expectedLen int + }{ + { + name: "empty sets", + left: []int{}, + right: []int{}, + expected: []int{}, // [] + expectedLen: 0, + }, + { + name: "left and right are same", + left: []int{2, 1, 0}, + right: []int{2, 1, 0}, + expected: []int{}, // [] + expectedLen: 0, + }, + { + name: "left and no right", + left: []int{2, 1, 0}, + right: []int{}, + expected: []int{2, 1, 0}, // [1, 1, 1] + expectedLen: 3, + }, + { + name: "right and no left", + left: []int{}, + right: []int{2, 1, 0}, + expected: []int{}, // [] + expectedLen: 3, + }, + { + name: "left and right overlap", + left: []int{2, 1}, + right: []int{1, 0}, + expected: []int{2}, // [1, 0, 0] + expectedLen: 3, + }, + { + name: "left and right overlap different sizes", + left: []int{5, 3, 1}, + right: []int{8, 6, 4, 2, 0}, + expected: []int{5, 3, 1}, // [1, 0, 1, 0, 1, 0] + expectedLen: 6, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + left := NewBits() + right := NewBits() + for _, add := range test.left { + left.Add(add) + } + for _, add := range test.right { + right.Add(add) + } + + left.Difference(right) + + expected := NewBits() + for _, element := range test.expected { + expected.Add(element) + } + + require.ElementsMatch(left.bits.Bits(), expected.bits.Bits()) + }) + } +} + +func Test_Bits_Clear(t *testing.T) { + tests := []struct { + name string + bitset []int + }{ + { + name: "empty", + bitset: []int{}, // [] + }, + { + name: "populated", + bitset: []int{5, 4, 3, 2, 1}, // [1, 1, 1, 1, 1] + }, + { + name: "populated - big", + bitset: []int{255}, // [1, 0...] 
+ }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + b := NewBits() + + for bit := range test.bitset { + b.Add(bit) + } + + b.Clear() + + require.Zero(b.BitLen()) + }) + } +} + +func Test_Bits_String(t *testing.T) { + tests := []struct { + name string + bitset []int + expected string + }{ + { + name: "empty", + bitset: []int{}, + expected: "", // [] + }, + { + name: "populated", + bitset: []int{7, 6, 5, 4, 3, 2, 1, 0}, // [1, 1, 1, 1, 1, 1, 1, 1] + expected: "ff", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + b := NewBits() + + for _, bit := range test.bitset { + b.Add(bit) + } + + require.Equal(test.expected, b.String()) + }) + } +} + +func Test_Bits_Len(t *testing.T) { + tests := []struct { + name string + bitset []int + expected int + }{ + { + name: "empty", + bitset: []int{}, // [] + expected: 0, + }, + { + name: "populated - more than one word", + bitset: []int{255}, // [1, 0...] + expected: 1, + }, + { + name: "populated - all ones", + bitset: []int{5, 4, 3, 2, 1, 0}, // [1, 1, 1, 1, 1, 1] + expected: 6, + }, + { + name: "populated - trailing zeroes", + bitset: []int{5, 4, 3}, // [1, 1, 1, 0, 0, 0] + expected: 3, + }, + { + name: "populated - interwoven 1", + bitset: []int{4, 2, 0}, // [1, 0, 1, 0, 1] + expected: 3, + }, + { + name: "populated - interwoven 2", + bitset: []int{3, 1}, // [1, 0, 1, 0] + expected: 2, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + b := NewBits() + + for _, bit := range test.bitset { + b.Add(bit) + } + + require.Equal(test.expected, b.Len()) + }) + } +} + +func Test_Bits_Bytes(t *testing.T) { + type test struct { + name string + elts []int + } + + tests := []test{ + { + name: "empty", + elts: []int{}, + }, + { + name: "single; element > 63", + elts: []int{1337}, + }, + { + name: "multiple", + elts: []int{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + b := NewBits(tt.elts...) + bytes := b.Bytes() + fromBytes := BitsFromBytes(bytes) + + require.Equal(len(tt.elts), fromBytes.Len()) + for _, elt := range tt.elts { + require.True(fromBytes.Contains(elt)) + } + }) + } +} diff --git a/utils/validators/state.go b/utils/validators/state.go new file mode 100644 index 00000000..b864655b --- /dev/null +++ b/utils/validators/state.go @@ -0,0 +1,116 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "context" + "sync" + + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +var _ State = (*lockedState)(nil) + +// State allows the lookup of validator sets on specified subnets at the +// requested P-chain height. +type State interface { + // GetMinimumHeight returns the minimum height of the block still in the + // proposal window. + GetMinimumHeight(context.Context) (uint64, error) + // GetCurrentHeight returns the current height of the P-chain. + GetCurrentHeight(context.Context) (uint64, error) + + // GetSubnetID returns the subnetID of the provided chain. + GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) + + // GetValidatorSet returns the validators of the provided subnet at the + // requested P-chain height. + // The returned map should not be modified. 
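+	// Entries may have a nil PublicKey when a validator has not registered a
+	// BLS key; consumers such as FlattenValidatorSet skip those entries.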
+ GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, + ) (map[ids.NodeID]*GetValidatorOutput, error) + + // GetCurrentValidatorSet returns the current validators of the provided subnet + // and the current P-Chain height. + // Map is keyed by ValidationID. + GetCurrentValidatorSet( + ctx context.Context, + subnetID ids.ID, + ) (map[ids.ID]*GetCurrentValidatorOutput, uint64, error) +} + +type lockedState struct { + lock sync.Locker + s State +} + +func NewLockedState(lock sync.Locker, s State) State { + return &lockedState{ + lock: lock, + s: s, + } +} + +func (s *lockedState) GetMinimumHeight(ctx context.Context) (uint64, error) { + s.lock.Lock() + defer s.lock.Unlock() + + return s.s.GetMinimumHeight(ctx) +} + +func (s *lockedState) GetCurrentHeight(ctx context.Context) (uint64, error) { + s.lock.Lock() + defer s.lock.Unlock() + + return s.s.GetCurrentHeight(ctx) +} + +func (s *lockedState) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { + s.lock.Lock() + defer s.lock.Unlock() + + return s.s.GetSubnetID(ctx, chainID) +} + +func (s *lockedState) GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, +) (map[ids.NodeID]*GetValidatorOutput, error) { + s.lock.Lock() + defer s.lock.Unlock() + + return s.s.GetValidatorSet(ctx, height, subnetID) +} + +func (s *lockedState) GetCurrentValidatorSet( + ctx context.Context, + subnetID ids.ID, +) (map[ids.ID]*GetCurrentValidatorOutput, uint64, error) { + s.lock.Lock() + defer s.lock.Unlock() + + return s.s.GetCurrentValidatorSet(ctx, subnetID) +} + +type noValidators struct { + State +} + +func NewNoValidatorsState(state State) State { + return &noValidators{ + State: state, + } +} + +func (*noValidators) GetValidatorSet(context.Context, uint64, ids.ID) (map[ids.NodeID]*GetValidatorOutput, error) { + return nil, nil +} + +func (n *noValidators) GetCurrentValidatorSet(ctx context.Context, _ ids.ID) (map[ids.ID]*GetCurrentValidatorOutput, uint64, error) { + height, err := n.GetCurrentHeight(ctx) + return nil, height, err +} diff --git a/utils/validators/validator.go b/utils/validators/validator.go new file mode 100644 index 00000000..83572748 --- /dev/null +++ b/utils/validators/validator.go @@ -0,0 +1,37 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +// Validator is a struct that contains the base values representing a validator +// of the Avalanche Network. +type Validator struct { + NodeID ids.NodeID + PublicKey *bls.PublicKey + TxID ids.ID + Weight uint64 +} + +// GetValidatorOutput is a struct that contains the publicly relevant values of +// a validator of the Avalanche Network for the output of GetValidator. +type GetValidatorOutput struct { + NodeID ids.NodeID + PublicKey *bls.PublicKey + Weight uint64 +} + +type GetCurrentValidatorOutput struct { + ValidationID ids.ID + NodeID ids.NodeID + PublicKey *bls.PublicKey + Weight uint64 + StartTime uint64 + MinNonce uint64 + IsActive bool + IsL1Validator bool +} diff --git a/utils/version/application.go b/utils/version/application.go new file mode 100644 index 00000000..f3740df3 --- /dev/null +++ b/utils/version/application.go @@ -0,0 +1,68 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
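+//
+// Overview: Application models a client version of the form
+// "name/major.minor.patch"; Compatible only rejects a peer whose major version
+// is behind ours, and Compare orders versions by major, then minor, then patch.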
+ +package version + +import ( + "errors" + "fmt" + "sync" +) + +var ( + errDifferentMajor = errors.New("different major version") + + _ fmt.Stringer = (*Application)(nil) +) + +type Application struct { + Name string `json:"name" yaml:"name"` + Major int `json:"major" yaml:"major"` + Minor int `json:"minor" yaml:"minor"` + Patch int `json:"patch" yaml:"patch"` + + makeStrOnce sync.Once + str string +} + +// The only difference here between Application and Semantic is that Application +// prepends the client name rather than "v". +func (a *Application) String() string { + a.makeStrOnce.Do(a.initString) + return a.str +} + +func (a *Application) initString() { + a.str = fmt.Sprintf( + "%s/%d.%d.%d", + a.Name, + a.Major, + a.Minor, + a.Patch, + ) +} + +func (a *Application) Compatible(o *Application) error { + switch { + case a.Major > o.Major: + return errDifferentMajor + default: + return nil + } +} + +func (a *Application) Before(o *Application) bool { + return a.Compare(o) < 0 +} + +// Compare returns a positive number if s > o, 0 if s == o, or a negative number +// if s < o. +func (a *Application) Compare(o *Application) int { + if a.Major != o.Major { + return a.Major - o.Major + } + if a.Minor != o.Minor { + return a.Minor - o.Minor + } + return a.Patch - o.Patch +} diff --git a/utils/warp/aggregator/aggregator.go b/utils/warp/aggregator/aggregator.go new file mode 100644 index 00000000..a44340c9 --- /dev/null +++ b/utils/warp/aggregator/aggregator.go @@ -0,0 +1,174 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package aggregator + +import ( + "context" + "fmt" + + "github.com/cometbft/cometbft/libs/log" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/set" +) + +const WarpQuorumDenominator uint64 = 100 + +type AggregateSignatureResult struct { + // Weight of validators included in the aggregate signature. + SignatureWeight uint64 + // Total weight of all validators in the subnet. + TotalWeight uint64 + // The message with the aggregate signature. + Message *warputils.Message +} + +type signatureFetchResult struct { + sig *bls.Signature + index int + weight uint64 +} + +// Aggregator requests signatures from validators and +// aggregates them into a single signature. +type Aggregator struct { + logger log.Logger + validators []*warputils.Validator + totalWeight uint64 + client SignatureGetter +} + +// New returns a signature aggregator that will attempt to aggregate signatures from [validators]. +func New(client SignatureGetter, logger log.Logger, validators []*warputils.Validator, totalWeight uint64) *Aggregator { + return &Aggregator{ + client: client, + logger: logger, + validators: validators, + totalWeight: totalWeight, + } +} + +// Returns an aggregate signature over [unsignedMessage]. +// The returned signature's weight exceeds the threshold given by [quorumNum]. +func (a *Aggregator) AggregateSignatures(ctx context.Context, unsignedMessage *warputils.UnsignedMessage, quorumNum uint64) (*AggregateSignatureResult, error) { + // Create a child context to cancel signature fetching if we reach signature threshold. + signatureFetchCtx, signatureFetchCancel := context.WithCancel(ctx) + defer signatureFetchCancel() + + // Fetch signatures from validators concurrently. 
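+	// Each fetch goroutine sends exactly one result on this channel (nil when
+	// fetching or signature verification fails), so the collection loop below
+	// performs at most one receive per validator and can stop early once the
+	// requested quorum weight is reached.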
+ signatureFetchResultChan := make(chan *signatureFetchResult) + for i, validator := range a.validators { + var ( + i = i + validator = validator + // TODO: update from a single nodeID to the original slice and use extra nodeIDs as backup. + nodeID = validator.NodeIDs[0] + ) + go func() { + a.logger.Debug("Fetching warp signature", + "nodeID", nodeID, + "index", i, + "msgID", unsignedMessage.ID(), + ) + + signature, err := a.client.GetSignature(signatureFetchCtx, nodeID, unsignedMessage) + if err != nil { + a.logger.Debug("Failed to fetch warp signature", + "nodeID", nodeID, + "index", i, + "err", err, + "msgID", unsignedMessage.ID(), + ) + signatureFetchResultChan <- nil + return + } + + a.logger.Debug("Retrieved warp signature", + "nodeID", nodeID, + "msgID", unsignedMessage.ID(), + "index", i, + ) + + if !bls.Verify(validator.PublicKey, signature, unsignedMessage.Bytes()) { + a.logger.Debug("Failed to verify warp signature", + "nodeID", nodeID, + "index", i, + "msgID", unsignedMessage.ID(), + ) + signatureFetchResultChan <- nil + return + } + + signatureFetchResultChan <- &signatureFetchResult{ + sig: signature, + index: i, + weight: validator.Weight, + } + }() + } + + var ( + signatures = make([]*bls.Signature, 0, len(a.validators)) + signersBitset = set.NewBits() + signaturesWeight = uint64(0) + signaturesPassedThreshold = false + ) + + for i := 0; i < len(a.validators); i++ { + signatureFetchResult := <-signatureFetchResultChan + if signatureFetchResult == nil { + continue + } + + signatures = append(signatures, signatureFetchResult.sig) + signersBitset.Add(signatureFetchResult.index) + signaturesWeight += signatureFetchResult.weight + a.logger.Debug("Updated weight", + "totalWeight", signaturesWeight, + "addedWeight", signatureFetchResult.weight, + "msgID", unsignedMessage.ID(), + ) + + // If the signature weight meets the requested threshold, cancel signature fetching + if err := warputils.VerifyWeight(signaturesWeight, a.totalWeight, quorumNum, WarpQuorumDenominator); err == nil { + a.logger.Debug("Verify weight passed, exiting aggregation early", + "quorumNum", quorumNum, + "totalWeight", a.totalWeight, + "signatureWeight", signaturesWeight, + "msgID", unsignedMessage.ID(), + ) + signatureFetchCancel() + signaturesPassedThreshold = true + break + } + } + + // If I failed to fetch sufficient signature stake, return an error + if !signaturesPassedThreshold { + return nil, warputils.ErrInsufficientWeight + } + + // Otherwise, return the aggregate signature + aggregateSignature, err := bls.AggregateSignatures(signatures) + if err != nil { + return nil, fmt.Errorf("failed to aggregate BLS signatures: %w", err) + } + + warpSignature := &warputils.BitSetSignature{ + Signers: signersBitset.Bytes(), + } + copy(warpSignature.Signature[:], bls.SignatureToBytes(aggregateSignature)) + + msg, err := warputils.NewMessage(unsignedMessage, warpSignature) + if err != nil { + return nil, fmt.Errorf("failed to construct warp message: %w", err) + } + + return &AggregateSignatureResult{ + Message: msg, + SignatureWeight: signaturesWeight, + TotalWeight: a.totalWeight, + }, nil +} diff --git a/utils/warp/aggregator/aggregator_test.go b/utils/warp/aggregator/aggregator_test.go new file mode 100644 index 00000000..824d6ce0 --- /dev/null +++ b/utils/warp/aggregator/aggregator_test.go @@ -0,0 +1,10 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package aggregator + +import "testing" + +func TestAggregateSignatures(t *testing.T) { + // TODO: implement test +} diff --git a/utils/warp/aggregator/signature_getter.go b/utils/warp/aggregator/signature_getter.go new file mode 100644 index 00000000..47c0221a --- /dev/null +++ b/utils/warp/aggregator/signature_getter.go @@ -0,0 +1,18 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package aggregator + +import ( + "context" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" + avalancheWarp "github.com/landslidenetwork/slide-sdk/utils/warp" +) + +// SignatureGetter defines the minimum network interface to perform signature aggregation +type SignatureGetter interface { + // GetSignature attempts to fetch a BLS Signature from [nodeID] for [unsignedWarpMessage] + GetSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) +} diff --git a/utils/warp/message.go b/utils/warp/message.go new file mode 100644 index 00000000..edd8a71c --- /dev/null +++ b/utils/warp/message.go @@ -0,0 +1,59 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + "fmt" +) + +// Message defines the standard format for a Warp message. +type Message struct { + UnsignedMessage `serialize:"true"` + Signature Signature `serialize:"true"` + + bytes []byte +} + +// NewMessage creates a new *Message and initializes it. +func NewMessage( + unsignedMsg *UnsignedMessage, + signature Signature, +) (*Message, error) { + msg := &Message{ + UnsignedMessage: *unsignedMsg, + Signature: signature, + } + return msg, msg.Initialize() +} + +// ParseMessage converts a slice of bytes into an initialized *Message. +func ParseMessage(b []byte) (*Message, error) { + msg := &Message{ + bytes: b, + } + err := Codec.Unmarshal(b, msg) + if err != nil { + return nil, err + } + return msg, msg.UnsignedMessage.Initialize() +} + +// Initialize recalculates the result of Bytes(). It does not call Initialize() +// on the UnsignedMessage. +func (m *Message) Initialize() error { + bytes, err := Codec.Marshal(m) + m.bytes = bytes + return err +} + +// Bytes returns the binary representation of this message. It assumes that the +// message is initialized from either New, Parse, or an explicit call to +// Initialize. +func (m *Message) Bytes() []byte { + return m.bytes +} + +func (m *Message) String() string { + return fmt.Sprintf("WarpMessage(%s, %s)", m.UnsignedMessage.Bytes(), m.Signature) +} diff --git a/utils/warp/message_test.go b/utils/warp/message_test.go new file mode 100644 index 00000000..d3d9f8fe --- /dev/null +++ b/utils/warp/message_test.go @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package warp + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +func TestMessage(t *testing.T) { + require := require.New(t) + + unsignedMsg, err := NewUnsignedMessage( + UnitTestID, + ids.GenerateTestID(), + []byte("payload"), + ) + require.NoError(err) + + msg, err := NewMessage( + unsignedMsg, + &BitSetSignature{ + Signers: []byte{1, 2, 3}, + Signature: [bls.SignatureLen]byte{4, 5, 6}, + }, + ) + require.NoError(err) + + msgBytes := msg.Bytes() + msg2, err := ParseMessage(msgBytes) + require.NoError(err) + require.Equal(msg, msg2) +} + +func TestParseMessageJunk(t *testing.T) { + bytes := []byte{0, 1, 2, 3, 4, 5, 6, 7} + _, err := ParseMessage(bytes) + require.Error(t, err) +} diff --git a/utils/warp/messages/payload.go b/utils/warp/messages/payload.go index 85aba96e..e9ab55ac 100644 --- a/utils/warp/messages/payload.go +++ b/utils/warp/messages/payload.go @@ -20,7 +20,7 @@ type Payload interface { // Signable is an optional interface that payloads can implement to allow // on-the-fly signing of incoming messages by the warp backend. type Signable interface { - VerifyMesssage(sourceAddress []byte) error + VerifyMessage(sourceAddress []byte) error } func Parse(bytes []byte) (Payload, error) { diff --git a/utils/warp/signature.go b/utils/warp/signature.go index 9d35d6cd..03b82975 100644 --- a/utils/warp/signature.go +++ b/utils/warp/signature.go @@ -4,12 +4,177 @@ package warp import ( + "context" + "errors" + "fmt" + "math/big" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/set" + "github.com/landslidenetwork/slide-sdk/utils/validators" blst "github.com/supranational/blst/bindings/go" ) +var ( + _ Signature = (*BitSetSignature)(nil) + + ErrInvalidBitSet = errors.New("bitset is invalid") + ErrInsufficientWeight = errors.New("signature weight is insufficient") + ErrInvalidSignature = errors.New("signature is invalid") + ErrParseSignature = errors.New("failed to parse signature") +) + +type Signature interface { + fmt.Stringer + + // NumSigners is the number of [bls.PublicKeys] that participated in the + // [Signature]. This is exposed because users of these signatures typically + // impose a verification fee that is a function of the number of + // signers. + NumSigners() (int, error) + + // Verify that this signature was signed by at least [quorumNum]/[quorumDen] + // of the validators of [msg.SourceChainID] at [pChainHeight]. + // + // Invariant: [msg] is correctly initialized. + Verify( + ctx context.Context, + msg *UnsignedMessage, + networkID uint32, + pChainState validators.State, + pChainHeight uint64, + quorumNum uint64, + quorumDen uint64, + ) error +} + type BitSetSignature struct { // Signers is a big-endian byte slice encoding which validators signed this // message. Signers []byte `serialize:"true"` Signature [blst.BLST_P2_COMPRESS_BYTES]byte `serialize:"true"` } + +func (s *BitSetSignature) NumSigners() (int, error) { + // Parse signer bit vector + // + // We assert that the length of [signerIndices.Bytes()] is equal + // to [len(s.Signers)] to ensure that [s.Signers] does not have + // any unnecessary zero-padding to represent the [set.Bits]. 
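+	//
+	// Illustrative example: Signers = []byte{0x05} sets bits 0 and 2, so this
+	// returns 2; prepending a 0x00 byte would be rejected because the bit set
+	// re-serializes to a single byte.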
+ signerIndices := set.BitsFromBytes(s.Signers) + if len(signerIndices.Bytes()) != len(s.Signers) { + return 0, ErrInvalidBitSet + } + return signerIndices.Len(), nil +} + +func (s *BitSetSignature) Verify( + ctx context.Context, + msg *UnsignedMessage, + networkID uint32, + pChainState validators.State, + pChainHeight uint64, + quorumNum uint64, + quorumDen uint64, +) error { + if msg.NetworkID != networkID { + return ErrWrongNetworkID + } + + subnetID, err := pChainState.GetSubnetID(ctx, msg.SourceChainID) + if err != nil { + return err + } + + // Get the validator set at the given height. + vdrSet, err := pChainState.GetValidatorSet(ctx, pChainHeight, subnetID) + if err != nil { + return fmt.Errorf("failed to get validator set: %w", err) + } + + // Convert the validator set into the canonical ordering. + vdrs, totalWeight, err := FlattenValidatorSet(vdrSet) + if err != nil { + return err + } + + // Parse signer bit vector + // + // We assert that the length of [signerIndices.Bytes()] is equal + // to [len(s.Signers)] to ensure that [s.Signers] does not have + // any unnecessary zero-padding to represent the [set.Bits]. + signerIndices := set.BitsFromBytes(s.Signers) + if len(signerIndices.Bytes()) != len(s.Signers) { + return ErrInvalidBitSet + } + + // Get the validators that (allegedly) signed the message. + signers, err := FilterValidators(signerIndices, vdrs) + if err != nil { + return err + } + + // Because [signers] is a subset of [vdrs], this can never error. + sigWeight, _ := SumWeight(signers) + + // Make sure the signature's weight is sufficient. + err = VerifyWeight( + sigWeight, + totalWeight, + quorumNum, + quorumDen, + ) + if err != nil { + return err + } + + // Parse the aggregate signature + aggSig, err := bls.SignatureFromBytes(s.Signature[:]) + if err != nil { + return fmt.Errorf("%w: %w", ErrParseSignature, err) + } + + // Create the aggregate public key + aggPubKey, err := AggregatePublicKeys(signers) + if err != nil { + return err + } + + // Verify the signature + unsignedBytes := msg.Bytes() + if !bls.Verify(aggPubKey, aggSig, unsignedBytes) { + return ErrInvalidSignature + } + return nil +} + +func (s *BitSetSignature) String() string { + return fmt.Sprintf("BitSetSignature(Signers = %x, Signature = %x)", s.Signers, s.Signature) +} + +// VerifyWeight returns [nil] if [sigWeight] is at least [quorumNum]/[quorumDen] +// of [totalWeight]. +// If [sigWeight >= totalWeight * quorumNum / quorumDen] then return [nil] +func VerifyWeight( + sigWeight uint64, + totalWeight uint64, + quorumNum uint64, + quorumDen uint64, +) error { + // Verifies that quorumNum * totalWeight <= quorumDen * sigWeight + scaledTotalWeight := new(big.Int).SetUint64(totalWeight) + scaledTotalWeight.Mul(scaledTotalWeight, new(big.Int).SetUint64(quorumNum)) + scaledSigWeight := new(big.Int).SetUint64(sigWeight) + scaledSigWeight.Mul(scaledSigWeight, new(big.Int).SetUint64(quorumDen)) + if scaledTotalWeight.Cmp(scaledSigWeight) == 1 { + return fmt.Errorf( + "%w: %d*%d > %d*%d", + ErrInsufficientWeight, + quorumNum, + totalWeight, + quorumDen, + sigWeight, + ) + } + return nil +} diff --git a/utils/warp/validator.go b/utils/warp/validator.go new file mode 100644 index 00000000..a715198e --- /dev/null +++ b/utils/warp/validator.go @@ -0,0 +1,131 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
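+//
+// Overview: the helpers below flatten a validator set into a canonical
+// ordering sorted by BLS public key, filter it by a signer bit set, and sum or
+// aggregate the corresponding weights and public keys.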
+ +package warp + +import ( + "bytes" + "errors" + "fmt" + + "github.com/landslidenetwork/slide-sdk/utils/set" + + "golang.org/x/exp/maps" + + "github.com/landslidenetwork/slide-sdk/utils" + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" + "github.com/landslidenetwork/slide-sdk/utils/math" + "github.com/landslidenetwork/slide-sdk/utils/validators" +) + +var ( + _ utils.Sortable[*Validator] = (*Validator)(nil) + + ErrUnknownValidator = errors.New("unknown validator") + ErrWeightOverflow = errors.New("weight overflowed") +) + +type Validator struct { + PublicKey *bls.PublicKey + PublicKeyBytes []byte + Weight uint64 + NodeIDs []ids.NodeID +} + +func (v *Validator) Compare(o *Validator) int { + return bytes.Compare(v.PublicKeyBytes, o.PublicKeyBytes) +} + +// FlattenValidatorSet converts the provided [vdrSet] into a canonical ordering. +// Also returns the total weight of the validator set. +func FlattenValidatorSet(vdrSet map[ids.NodeID]*validators.GetValidatorOutput) ([]*Validator, uint64, error) { + var ( + vdrs = make(map[string]*Validator, len(vdrSet)) + totalWeight uint64 + err error + ) + for _, vdr := range vdrSet { + totalWeight, err = math.Add(totalWeight, vdr.Weight) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrWeightOverflow, err) + } + + if vdr.PublicKey == nil { + continue + } + + pkBytes := bls.PublicKeyToUncompressedBytes(vdr.PublicKey) + uniqueVdr, ok := vdrs[string(pkBytes)] + if !ok { + uniqueVdr = &Validator{ + PublicKey: vdr.PublicKey, + PublicKeyBytes: pkBytes, + } + vdrs[string(pkBytes)] = uniqueVdr + } + + uniqueVdr.Weight += vdr.Weight // Impossible to overflow here + uniqueVdr.NodeIDs = append(uniqueVdr.NodeIDs, vdr.NodeID) + } + + // Sort validators by public key + vdrList := maps.Values(vdrs) + utils.Sort(vdrList) + return vdrList, totalWeight, nil +} + +// FilterValidators returns the validators in [vdrs] whose bit is set to 1 in +// [indices]. +// +// Returns an error if [indices] references an unknown validator. +func FilterValidators( + indices set.Bits, + vdrs []*Validator, +) ([]*Validator, error) { + // Verify that all alleged signers exist + if indices.BitLen() > len(vdrs) { + return nil, fmt.Errorf( + "%w: NumIndices (%d) >= NumFilteredValidators (%d)", + ErrUnknownValidator, + indices.BitLen()-1, // -1 to convert from length to index + len(vdrs), + ) + } + + filteredVdrs := make([]*Validator, 0, len(vdrs)) + for i, vdr := range vdrs { + if !indices.Contains(i) { + continue + } + + filteredVdrs = append(filteredVdrs, vdr) + } + return filteredVdrs, nil +} + +// SumWeight returns the total weight of the provided validators. +func SumWeight(vdrs []*Validator) (uint64, error) { + var ( + weight uint64 + err error + ) + for _, vdr := range vdrs { + weight, err = math.Add(weight, vdr.Weight) + if err != nil { + return 0, fmt.Errorf("%w: %w", ErrWeightOverflow, err) + } + } + return weight, nil +} + +// AggregatePublicKeys returns the public key of the provided validators. +// +// Invariant: All of the public keys in [vdrs] are valid. +func AggregatePublicKeys(vdrs []*Validator) (*bls.PublicKey, error) { + pks := make([]*bls.PublicKey, len(vdrs)) + for i, vdr := range vdrs { + pks[i] = vdr.PublicKey + } + return bls.AggregatePublicKeys(pks) +} diff --git a/utils/warp/validator_test.go b/utils/warp/validator_test.go new file mode 100644 index 00000000..d5f1fc26 --- /dev/null +++ b/utils/warp/validator_test.go @@ -0,0 +1,69 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. 
All rights reserved. +// See the file LICENSE for licensing terms. + +package warp + +import ( + "math" + "testing" +) + +func TestGetCanonicalValidatorSet(t *testing.T) { + //TODO: implement test +} + +func TestFilterValidators(t *testing.T) { + //TODO: implement test +} + +func TestSumWeight(t *testing.T) { + vdr0 := &Validator{ + Weight: 1, + } + vdr1 := &Validator{ + Weight: 2, + } + vdr2 := &Validator{ + Weight: math.MaxUint64, + } + + type test struct { + name string + vdrs []*Validator + expectedSum uint64 + expectedErr error + } + + tests := []test{ + { + name: "empty", + vdrs: []*Validator{}, + expectedSum: 0, + }, + { + name: "one", + vdrs: []*Validator{vdr0}, + expectedSum: 1, + }, + { + name: "two", + vdrs: []*Validator{vdr0, vdr1}, + expectedSum: 3, + }, + { + name: "overflow", + vdrs: []*Validator{vdr0, vdr2}, + expectedErr: ErrWeightOverflow, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + //TODO: implement test + }) + } +} + +func BenchmarkGetCanonicalValidatorSet(b *testing.B) { + //TODO: implement test +} diff --git a/utils/warp/validators/state.go b/utils/warp/validators/state.go new file mode 100644 index 00000000..19ff807a --- /dev/null +++ b/utils/warp/validators/state.go @@ -0,0 +1,58 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "context" + + "github.com/landslidenetwork/slide-sdk/utils/ids" + "github.com/landslidenetwork/slide-sdk/utils/validators" +) + +var ( + _ validators.State = (*State)(nil) + PrimaryNetworkID = ids.Empty + PlatformChainID = ids.Empty +) + +// State provides a special case used to handle Avalanche Warp Message verification for messages sent +// from the Primary Network. Subnets have strictly fewer validators than the Primary Network, so we require +// signatures from a threshold of the RECEIVING subnet validator set rather than the full Primary Network +// since the receiving subnet already relies on a majority of its validators being correct. +type State struct { + validators.State + mySubnetID ids.ID + sourceChainID ids.ID + requirePrimaryNetworkSigners bool +} + +// NewState returns a wrapper of [validators.State] which special cases the handling of the Primary Network. +// +// The wrapped state will return the [mySubnetID's] validator set instead of the Primary Network when +// the Primary Network SubnetID is passed in. +func NewState(state validators.State, mySubnetID ids.ID, sourceChainID ids.ID, requirePrimaryNetworkSigners bool) *State { + return &State{ + State: state, + mySubnetID: mySubnetID, + sourceChainID: sourceChainID, + requirePrimaryNetworkSigners: requirePrimaryNetworkSigners, + } +} + +func (s *State) GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, +) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + // If the subnetID is anything other than the Primary Network, or Primary + // Network signers are required (except P-Chain), this is a direct passthrough. + usePrimary := s.requirePrimaryNetworkSigners && s.sourceChainID != PlatformChainID + if usePrimary || subnetID != PrimaryNetworkID { + return s.State.GetValidatorSet(ctx, height, subnetID) + } + + // If the requested subnet is the primary network, then we return the validator + // set for the Subnet that is receiving the message instead. 
+ return s.State.GetValidatorSet(ctx, height, s.mySubnetID) +} diff --git a/utils/warp/validators/state_test.go b/utils/warp/validators/state_test.go new file mode 100644 index 00000000..66721406 --- /dev/null +++ b/utils/warp/validators/state_test.go @@ -0,0 +1,12 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "testing" +) + +func TestGetValidatorSetPrimaryNetwork(t *testing.T) { + //TODO: implement test +} diff --git a/vm/rpc.go b/vm/rpc.go index 46db638d..cc95b2c3 100644 --- a/vm/rpc.go +++ b/vm/rpc.go @@ -69,6 +69,7 @@ func (rpc *RPC) Routes() map[string]*jsonrpc.RPCFunc { "warp_get_message": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessage, "messageID"), "warp_get_message_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessageSignature, "messageID"), "warp_get_message_aggregate_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessageAggregateSignature, "messageID,quorumNum,subnetID"), + "warp_get_block_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetBlockSignature, "blockID"), "warp_get_block_aggregate_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetBlockAggregateSignature, "blockID,quorumNum,subnetID"), } } diff --git a/vm/types/config.go b/vm/types/config.go index 5adaba1f..8ea233fb 100644 --- a/vm/types/config.go +++ b/vm/types/config.go @@ -26,12 +26,13 @@ type ( // VMConfig contains the configuration of the VM. VMConfig struct { - NetworkName string `json:"network_name"` - TimeoutBroadcastTxCommit uint16 `json:"timeout_broadcast_tx_commit"` - ConsensusParams ConsensusParams `json:"consensus_params"` - MaxSubscriptionClients int `json:"max_subscription_clients"` - MaxSubscriptionsPerClient int `json:"max_subscriptions_per_client"` - BLSSecretKey []byte `json:"bls_secret_key"` + NetworkName string `json:"network_name"` + TimeoutBroadcastTxCommit uint16 `json:"timeout_broadcast_tx_commit"` + ConsensusParams ConsensusParams `json:"consensus_params"` + MaxSubscriptionClients int `json:"max_subscription_clients"` + MaxSubscriptionsPerClient int `json:"max_subscriptions_per_client"` + BLSSecretKey []byte `json:"bls_secret_key"` + AddressBook map[string]string `json:"address_book"` } // ConsensusParams contains consensus critical parameters that determine the diff --git a/vm/vm.go b/vm/vm.go index aeeefcf0..8c0ce565 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -12,6 +12,8 @@ import ( "sync" "time" + "github.com/landslidenetwork/slide-sdk/grpcutils/gvalidators" + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" warputils "github.com/landslidenetwork/slide-sdk/utils/warp" "github.com/landslidenetwork/slide-sdk/warp" @@ -48,6 +50,7 @@ import ( httppb "github.com/landslidenetwork/slide-sdk/proto/http" messengerpb "github.com/landslidenetwork/slide-sdk/proto/messenger" "github.com/landslidenetwork/slide-sdk/proto/rpcdb" + validatorstatepb "github.com/landslidenetwork/slide-sdk/proto/validatorstate" vmpb "github.com/landslidenetwork/slide-sdk/proto/vm" "github.com/landslidenetwork/slide-sdk/utils/ids" vmtypes "github.com/landslidenetwork/slide-sdk/vm/types" @@ -58,7 +61,8 @@ import ( ) const ( - genesisChunkSize = 16 * 1024 * 1024 // 16 + genesisChunkSize = 16 * 1024 * 1024 // 16 + requirePrimaryNetworkSigners = true ) var ( @@ -243,6 +247,8 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest msgClient := messengerpb.NewMessengerClient(vm.clientConn) + validatorStateClient := gvalidators.NewClient(validatorstatepb.NewValidatorStateClient(vm.clientConn)) 
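+	// The validator-state client above wraps the host's gRPC validator state
+	// service; it is passed to NewAPI further down so warp signature
+	// aggregation can query the current P-Chain height and validator sets.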
+ vm.toEngine = make(chan messengerpb.Message, 1) vm.closed = make(chan struct{}) go func() { @@ -483,7 +489,19 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest if err != nil { return nil, err } - vm.warpService = NewAPI(vm, req.NetworkId, subnetID, chainID, vm.warpBackend) + rpcClients := make(map[ids.NodeID]warp.Client) + for id, nodeURI := range vm.config.AddressBook { + nodeID, err := ids.ToNodeID([]byte(id)) + if err != nil { + return nil, err + } + rpcClient, err := warp.NewClient(nodeURI, string(req.ChainId)) + if err != nil { + return nil, err + } + rpcClients[nodeID] = rpcClient + } + vm.warpService = NewAPI(vm, vm.logger, req.NetworkId, validatorStateClient, subnetID, chainID, vm.warpBackend, rpcClients, requirePrimaryNetworkSigners) return &vmpb.InitializeResponse{ LastAcceptedId: blk.Hash(), diff --git a/vm/warp_service.go b/vm/warp_service.go index 4621bda0..cf82748e 100644 --- a/vm/warp_service.go +++ b/vm/warp_service.go @@ -5,18 +5,25 @@ package vm import ( "context" + "errors" "fmt" tmbytes "github.com/cometbft/cometbft/libs/bytes" + "github.com/cometbft/cometbft/libs/log" rpctypes "github.com/cometbft/cometbft/rpc/jsonrpc/types" "github.com/landslidenetwork/slide-sdk/utils/ids" + "github.com/landslidenetwork/slide-sdk/utils/validators" warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + "github.com/landslidenetwork/slide-sdk/utils/warp/aggregator" "github.com/landslidenetwork/slide-sdk/utils/warp/payload" + warpValidators "github.com/landslidenetwork/slide-sdk/utils/warp/validators" "github.com/landslidenetwork/slide-sdk/warp" ) const failedParseIDPattern = "failed to parse ID %s with error %w" +var errNoValidators = errors.New("cannot aggregate signatures from subnet with no validators") + type ResultGetMessage struct { Message []byte `json:"message"` } @@ -28,18 +35,31 @@ type ResultGetMessageSignature struct { // API introduces snowman specific functionality to the evm type API struct { vm *LandslideVM + logger log.Logger networkID uint32 + valState *warpValidators.State sourceSubnetID, sourceChainID ids.ID backend warp.Backend + signatureGetter aggregator.SignatureGetter + // TODO: investigate necessity to set up value according to validation of Primary Network + // requirePrimaryNetworkSigners returns true if warp messages from the primary + // network must be signed by the primary network validators. + // This is necessary when the subnet is not validating the primary network. 
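+	// The flag is forwarded to warpValidators.NewState in NewAPI so that, when
+	// it is unset, validator-set lookups for the Primary Network are redirected
+	// to this subnet's own validator set.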
+ requirePrimaryNetworkSigners bool } -func NewAPI(vm *LandslideVM, networkID uint32, sourceSubnetID ids.ID, sourceChainID ids.ID, backend warp.Backend) *API { +func NewAPI(vm *LandslideVM, logger log.Logger, networkID uint32, state validators.State, sourceSubnetID ids.ID, sourceChainID ids.ID, + backend warp.Backend, rpcClients map[ids.NodeID]warp.Client, requirePrimaryNetworkSigners bool) *API { return &API{ - vm: vm, - networkID: networkID, - sourceSubnetID: sourceSubnetID, - sourceChainID: sourceChainID, - backend: backend, + vm: vm, + logger: logger, + networkID: networkID, + valState: warpValidators.NewState(state, sourceSubnetID, sourceChainID, requirePrimaryNetworkSigners), + sourceSubnetID: sourceSubnetID, + sourceChainID: sourceChainID, + backend: backend, + signatureGetter: warp.NewAPIFetcher(rpcClients), + requirePrimaryNetworkSigners: requirePrimaryNetworkSigners, } } @@ -82,6 +102,15 @@ func (a *API) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID return a.aggregateSignatures(ctx, unsignedMessage, quorumNum, subnetIDStr) } +// GetBlockSignature returns the BLS signature associated with a blockID. +func (a *API) GetBlockSignature(ctx context.Context, blockID ids.ID) (tmbytes.HexBytes, error) { + signature, err := a.backend.GetBlockSignature(blockID) + if err != nil { + return nil, fmt.Errorf("failed to get signature for block %s with error %w", blockID, err) + } + return signature, nil +} + // GetBlockAggregateSignature fetches the aggregate signature for the requested [blockID] func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) (signedMessageBytes tmbytes.HexBytes, err error) { blockHashPayload, err := payload.NewHash(blockID) @@ -97,6 +126,46 @@ func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, qu } func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warputils.UnsignedMessage, quorumNum uint64, subnetIDStr string) (tmbytes.HexBytes, error) { - // TODO: implement aggregateSignatures - return nil, nil + subnetID := a.sourceSubnetID + if len(subnetIDStr) > 0 { + sid, err := ids.FromString(subnetIDStr) + if err != nil { + return nil, fmt.Errorf("failed to parse subnetID: %q", subnetIDStr) + } + subnetID = sid + } + pChainHeight, err := a.valState.GetCurrentHeight(ctx) + if err != nil { + return nil, err + } + // Get the validator set at the given height. + vdrSet, err := a.valState.GetValidatorSet(ctx, pChainHeight, subnetID) + if err != nil { + return nil, fmt.Errorf("failed to get validator set: %w", err) + } + + // Convert the validator set into the canonical ordering. 
+ validators, totalWeight, err := warputils.FlattenValidatorSet(vdrSet) + if err != nil { + return nil, fmt.Errorf("failed to convert the validator set into the canonical ordering: %w", err) + } + if len(validators) == 0 { + return nil, fmt.Errorf("%w (SubnetID: %s, Height: %d)", errNoValidators, subnetID, pChainHeight) + } + + a.logger.Debug("Fetching signature", + "sourceSubnetID", subnetID, + "height", pChainHeight, + "numValidators", len(validators), + "totalWeight", totalWeight, + ) + agg := aggregator.New(a.signatureGetter, a.logger, validators, totalWeight) + signatureResult, err := agg.AggregateSignatures(ctx, unsignedMessage, quorumNum) + if err != nil { + return nil, err + } + // TODO: return the signature and total weight as well to the caller for more complete details + // Need to decide on the best UI for this and write up documentation with the potential + // gotchas that could impact signed messages becoming invalid. + return signatureResult.Message.Bytes(), nil } diff --git a/warp/backend.go b/warp/backend.go index 74e1aedb..3938f21e 100644 --- a/warp/backend.go +++ b/warp/backend.go @@ -19,6 +19,8 @@ type Backend interface { AddMessage(unsignedMessage *warputils.UnsignedMessage) error // GetMessageSignature returns the signature of the requested message. GetMessageSignature(message *warputils.UnsignedMessage) ([]byte, error) + // GetBlockSignature returns the signature of a hash payload containing blockID if it's the ID of an accepted block. + GetBlockSignature(blockID ids.ID) ([]byte, error) // GetMessage retrieves the [unsignedMessage] from the warp backend database if available // TODO: After E-Upgrade, the backend no longer needs to store the mapping from messageHash // to unsignedMessage (and this method can be removed). @@ -88,6 +90,28 @@ func (b *backend) GetMessageSignature(unsignedMessage *warputils.UnsignedMessage return b.signMessage(unsignedMessage) } +func (b *backend) GetBlockSignature(blockID ids.ID) ([]byte, error) { + b.logger.Debug("Getting block from backend", "blockID", blockID) + + blockHashPayload, err := payload.NewHash(blockID) + if err != nil { + return nil, fmt.Errorf("failed to create new block hash payload: %w", err) + } + + unsignedMessage, err := warputils.NewUnsignedMessage(b.networkID, b.sourceChainID, blockHashPayload.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to create new unsigned warp message: %w", err) + } + + //TODO: validate block by hash + + sig, err := b.warpSigner.Sign(unsignedMessage) + if err != nil { + return nil, fmt.Errorf("failed to sign warp message: %w", err) + } + return sig, nil +} + func (b *backend) ValidateMessage(unsignedMessage *warputils.UnsignedMessage) error { // Known on-chain messages should be signed if _, err := b.GetMessage(unsignedMessage.ID()); err == nil { @@ -113,7 +137,7 @@ func (b *backend) ValidateMessage(unsignedMessage *warputils.UnsignedMessage) er } // Check if the message should be signed according to its type - if err := signable.VerifyMesssage(addressedCall.SourceAddress); err != nil { + if err := signable.VerifyMessage(addressedCall.SourceAddress); err != nil { return fmt.Errorf("failed to verify Signable message: %w", err) } return nil diff --git a/warp/client.go b/warp/client.go new file mode 100644 index 00000000..a9773bd7 --- /dev/null +++ b/warp/client.go @@ -0,0 +1,90 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
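backend.GetBlockSignature above composes three utilities from this series: a Hash payload over the block ID, an unsigned warp message scoped to the network and source chain, and the node's BLS warp signer. The following standalone sketch repeats that composition outside the backend, assuming the constructors keep the signatures they have in this diff; the key bytes and IDs are caller-supplied placeholders:

package warpexample

import (
	"fmt"

	"github.com/landslidenetwork/slide-sdk/utils/crypto/bls"
	"github.com/landslidenetwork/slide-sdk/utils/ids"
	warputils "github.com/landslidenetwork/slide-sdk/utils/warp"
	"github.com/landslidenetwork/slide-sdk/utils/warp/payload"
)

// signBlockHash builds the same hash-payload message that
// backend.GetBlockSignature signs and signs it with the local BLS key.
func signBlockHash(keyBytes []byte, networkID uint32, chainID, blockID ids.ID) ([]byte, error) {
	secretKey, err := bls.SecretKeyFromBytes(keyBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse secret key from bytes: %w", err)
	}
	signer := warputils.NewSigner(secretKey, networkID, chainID)

	blockHashPayload, err := payload.NewHash(blockID)
	if err != nil {
		return nil, fmt.Errorf("failed to create new block hash payload: %w", err)
	}
	unsignedMessage, err := warputils.NewUnsignedMessage(networkID, chainID, blockHashPayload.Bytes())
	if err != nil {
		return nil, fmt.Errorf("failed to create new unsigned warp message: %w", err)
	}
	// The resulting bytes are what warp_get_block_signature serves; they can
	// later be parsed back into a *bls.Signature with bls.SignatureFromBytes.
	return signer.Sign(unsignedMessage)
}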
+ +package warp + +import ( + "context" + "fmt" + + tmbytes "github.com/cometbft/cometbft/libs/bytes" + + jsonrpc "github.com/cometbft/cometbft/rpc/jsonrpc/client" + "github.com/landslidenetwork/slide-sdk/utils/ids" +) + +var _ Client = (*client)(nil) + +type Client interface { + GetMessage(ctx context.Context, messageID ids.ID) ([]byte, error) + GetMessageSignature(ctx context.Context, messageID ids.ID) ([]byte, error) + GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) + GetBlockSignature(ctx context.Context, blockID ids.ID) ([]byte, error) + GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) +} + +// client implementation for interacting with EVM [chain] +type client struct { + *jsonrpc.Client +} + +// NewClient returns a Client for interacting with EVM [chain] +func NewClient(uri, chain string) (Client, error) { + rpcClient, err := jsonrpc.New(fmt.Sprintf("%s/ext/bc/%s/rpc", uri, chain)) + if err != nil { + return nil, fmt.Errorf("failed to dial client. err: %w", err) + } + return &client{ + Client: rpcClient, + }, nil +} + +func (c *client) GetMessage(ctx context.Context, messageID ids.ID) ([]byte, error) { + var res tmbytes.HexBytes + if _, err := c.Call(ctx, "warp_get_message", map[string]interface{}{"messageID": messageID}, &res); err != nil { + return nil, fmt.Errorf("call to warp_get_message failed. err: %w", err) + } + return res, nil +} + +func (c *client) GetMessageSignature(ctx context.Context, messageID ids.ID) ([]byte, error) { + var res tmbytes.HexBytes + if _, err := c.Call(ctx, "warp_get_message_signature", map[string]interface{}{"messageID": messageID}, &res); err != nil { + return nil, fmt.Errorf("call to warp_get_message_signature failed. err: %w", err) + } + return res, nil +} + +func (c *client) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) { + var res tmbytes.HexBytes + if _, err := c.Call(ctx, "warp_get_message_aggregate_signature", + map[string]interface{}{ + "messageID": messageID, + "quorumNum": quorumNum, + "subnetIDStr": subnetIDStr, + }, &res); err != nil { + return nil, fmt.Errorf("call to warp_get_message_aggregate_signature failed. err: %w", err) + } + return res, nil +} + +func (c *client) GetBlockSignature(ctx context.Context, blockID ids.ID) ([]byte, error) { + var res tmbytes.HexBytes + if _, err := c.Call(ctx, "warp_get_block_signature", map[string]interface{}{"blockID": blockID}, &res); err != nil { + return nil, fmt.Errorf("call to warp_get_block_signature failed. err: %w", err) + } + return res, nil +} + +func (c *client) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) { + var res tmbytes.HexBytes + if _, err := c.Call(ctx, "warp_get_block_aggregate_signature", + map[string]interface{}{ + "blockID": blockID, + "quorumNum": quorumNum, + "subnetIDStr": subnetIDStr, + }, &res); err != nil { + return nil, fmt.Errorf("call to warp_get_block_aggregate_signature failed. err: %w", err) + } + return res, nil +} diff --git a/warp/fetcher.go b/warp/fetcher.go new file mode 100644 index 00000000..b45894d5 --- /dev/null +++ b/warp/fetcher.go @@ -0,0 +1,54 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
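The client above mirrors the warp_* routes registered in vm/rpc.go one for one. A short usage sketch under the assumption that a node is reachable at the given URI; note that a later commit in this series replaces the (uri, chain) constructor with a single rpcAddr argument:

package warpexample

import (
	"context"
	"fmt"

	"github.com/landslidenetwork/slide-sdk/utils/ids"
	"github.com/landslidenetwork/slide-sdk/warp"
)

// fetchBlockSignature dials one node's warp API and asks it for its BLS
// signature over the given block ID.
func fetchBlockSignature(ctx context.Context, nodeURI, chain string, blockID ids.ID) ([]byte, error) {
	// NewClient joins the arguments into <uri>/ext/bc/<chain>/rpc, per the
	// constructor shown above.
	client, err := warp.NewClient(nodeURI, chain)
	if err != nil {
		return nil, fmt.Errorf("failed to create warp client: %w", err)
	}
	return client.GetBlockSignature(ctx, blockID)
}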
+ +package warp + +import ( + "context" + "fmt" + + "github.com/landslidenetwork/slide-sdk/utils/crypto/bls" + "github.com/landslidenetwork/slide-sdk/utils/ids" + warputils "github.com/landslidenetwork/slide-sdk/utils/warp" + "github.com/landslidenetwork/slide-sdk/utils/warp/aggregator" + "github.com/landslidenetwork/slide-sdk/utils/warp/payload" +) + +var _ aggregator.SignatureGetter = (*apiFetcher)(nil) + +type apiFetcher struct { + clients map[ids.NodeID]Client +} + +func NewAPIFetcher(clients map[ids.NodeID]Client) *apiFetcher { + return &apiFetcher{ + clients: clients, + } +} + +func (f *apiFetcher) GetSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *warputils.UnsignedMessage) (*bls.Signature, error) { + client, ok := f.clients[nodeID] + if !ok { + return nil, fmt.Errorf("no warp client for nodeID: %s", nodeID) + } + var signatureBytes []byte + parsedPayload, err := payload.Parse(unsignedWarpMessage.Payload) + if err != nil { + return nil, fmt.Errorf("failed to parse unsigned message payload: %w", err) + } + switch p := parsedPayload.(type) { + case *payload.AddressedCall: + signatureBytes, err = client.GetMessageSignature(ctx, unsignedWarpMessage.ID()) + case *payload.Hash: + signatureBytes, err = client.GetBlockSignature(ctx, p.Hash) + } + if err != nil { + return nil, err + } + + signature, err := bls.SignatureFromBytes(signatureBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse signature from client %s: %w", nodeID, err) + } + return signature, nil +} From 7adf3ba43e5ab54d163e1cd64abe9ff267aa38f6 Mon Sep 17 00:00:00 2001 From: ivansukach <47761294+ivansukach@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:43:15 +0100 Subject: [PATCH 07/10] Warp rpc add message (#78) * AddMessage RPC method * change error template * goimports warp_service --------- Co-authored-by: Ivan Sukach --- vm/rpc.go | 2 ++ vm/warp_service.go | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/vm/rpc.go b/vm/rpc.go index cc95b2c3..a6b26af1 100644 --- a/vm/rpc.go +++ b/vm/rpc.go @@ -66,6 +66,8 @@ func (rpc *RPC) Routes() map[string]*jsonrpc.RPCFunc { "abci_info": jsonrpc.NewRPCFunc(rpc.ABCIInfo, "", jsonrpc.Cacheable()), // warp + // TODO: implement add message functionality through block Accept process + "warp_add_message": jsonrpc.NewRPCFunc(rpc.vm.warpService.AddMessage, "message"), "warp_get_message": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessage, "messageID"), "warp_get_message_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessageSignature, "messageID"), "warp_get_message_aggregate_signature": jsonrpc.NewRPCFunc(rpc.vm.warpService.GetMessageAggregateSignature, "messageID,quorumNum,subnetID"), diff --git a/vm/warp_service.go b/vm/warp_service.go index cf82748e..bc2c7d2d 100644 --- a/vm/warp_service.go +++ b/vm/warp_service.go @@ -20,7 +20,14 @@ import ( "github.com/landslidenetwork/slide-sdk/warp" ) -const failedParseIDPattern = "failed to parse ID %s with error %w" +const ( + failedParseIDPattern = "failed to parse ID %s with error %w" + failedParseWARPMessage = "failed to parse warp message %s with error %w" +) + +type ResultAddMessage struct { + MessageID string `json:"messageID"` +} var errNoValidators = errors.New("cannot aggregate signatures from subnet with no validators") @@ -63,6 +70,19 @@ func NewAPI(vm *LandslideVM, logger log.Logger, networkID uint32, state validato } } +// AddMessage returns the Warp message associated with a messageID. 
+func (a *API) AddMessage(_ *rpctypes.Context, message []byte) (*ResultAddMessage, error) { + msg, err := warputils.ParseUnsignedMessage(message) + if err != nil { + return nil, fmt.Errorf(failedParseWARPMessage, message, err) + } + err = a.backend.AddMessage(msg) + if err != nil { + return nil, fmt.Errorf("failed to add message {ID: %s} with error %w", msg.ID().String(), err) + } + return &ResultAddMessage{MessageID: msg.ID().String()}, nil +} + // GetMessage returns the Warp message associated with a messageID. func (a *API) GetMessage(_ *rpctypes.Context, messageID string) (*ResultGetMessage, error) { msgID, err := ids.FromString(messageID) From da74573058c7d4087510aa65f00b8533d3cfcdad Mon Sep 17 00:00:00 2001 From: Ivan Sukach Date: Fri, 10 Jan 2025 02:10:42 +0100 Subject: [PATCH 08/10] warp client: rpcAddr as single param for client, addMessage method --- vm/vm.go | 2 +- warp/client.go | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vm/vm.go b/vm/vm.go index 8c0ce565..6212bdb2 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -495,7 +495,7 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest if err != nil { return nil, err } - rpcClient, err := warp.NewClient(nodeURI, string(req.ChainId)) + rpcClient, err := warp.NewClient(fmt.Sprintf("%s/ext/bc/%s/rpc", nodeURI, string(req.ChainId))) if err != nil { return nil, err } diff --git a/warp/client.go b/warp/client.go index a9773bd7..a646201b 100644 --- a/warp/client.go +++ b/warp/client.go @@ -16,6 +16,7 @@ import ( var _ Client = (*client)(nil) type Client interface { + AddMessage(ctx context.Context, message []byte) ([]byte, error) GetMessage(ctx context.Context, messageID ids.ID) ([]byte, error) GetMessageSignature(ctx context.Context, messageID ids.ID) ([]byte, error) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID, quorumNum uint64, subnetIDStr string) ([]byte, error) @@ -29,8 +30,8 @@ type client struct { } // NewClient returns a Client for interacting with EVM [chain] -func NewClient(uri, chain string) (Client, error) { - rpcClient, err := jsonrpc.New(fmt.Sprintf("%s/ext/bc/%s/rpc", uri, chain)) +func NewClient(rpcAddr string) (Client, error) { + rpcClient, err := jsonrpc.New(rpcAddr) if err != nil { return nil, fmt.Errorf("failed to dial client. err: %w", err) } @@ -39,6 +40,14 @@ func NewClient(uri, chain string) (Client, error) { }, nil } +func (c *client) AddMessage(ctx context.Context, message []byte) ([]byte, error) { + var res tmbytes.HexBytes + if _, err := c.Call(ctx, "warp_add_message", map[string]interface{}{"message": message}, &res); err != nil { + return nil, fmt.Errorf("call to warp_add_message failed. 
err: %w", err) + } + return res, nil +} + func (c *client) GetMessage(ctx context.Context, messageID ids.ID) ([]byte, error) { var res tmbytes.HexBytes if _, err := c.Call(ctx, "warp_get_message", map[string]interface{}{"messageID": messageID}, &res); err != nil { From c8e3acbefa8a5c00c4f41ec08aea27dfd5b5d858 Mon Sep 17 00:00:00 2001 From: Ivan Sukach Date: Fri, 10 Jan 2025 20:47:34 +0100 Subject: [PATCH 09/10] add descriptional error messages --- vm/vm.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vm/vm.go b/vm/vm.go index 6212bdb2..f27a3066 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -456,7 +456,6 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest return nil, err } // vm.logger.Debug("initialize block", "bytes ", blockBytes) - vm.logger.Info("vm initialization completed") parentHash := block.ParentHash(blk) @@ -469,12 +468,12 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest // } chainID, err := ids.ToID(req.ChainId) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse chain ID: %w", err) } fmt.Println(vm.config.BLSSecretKey) secretKey, err := bls.SecretKeyFromBytes(vm.config.BLSSecretKey) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse secret key from bytes: %w", err) } vm.warpSigner = warputils.NewSigner(secretKey, req.NetworkId, chainID) vm.warpBackend = warp.NewBackend( @@ -487,22 +486,23 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest subnetID, err := ids.ToID(req.SubnetId) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse subnet ID: %w", err) } rpcClients := make(map[ids.NodeID]warp.Client) for id, nodeURI := range vm.config.AddressBook { nodeID, err := ids.ToNodeID([]byte(id)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse node ID: %w", err) } rpcClient, err := warp.NewClient(fmt.Sprintf("%s/ext/bc/%s/rpc", nodeURI, string(req.ChainId))) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create warp client: %w", err) } rpcClients[nodeID] = rpcClient } vm.warpService = NewAPI(vm, vm.logger, req.NetworkId, validatorStateClient, subnetID, chainID, vm.warpBackend, rpcClients, requirePrimaryNetworkSigners) + vm.logger.Info("vm initialization completed") return &vmpb.InitializeResponse{ LastAcceptedId: blk.Hash(), LastAcceptedParentId: parentHash[:], From 612e009998cd58a6b54e9dd0b538404438cf3200 Mon Sep 17 00:00:00 2001 From: Ivan Sukach Date: Mon, 13 Jan 2025 18:58:11 +0100 Subject: [PATCH 10/10] fix warp RPC service --- vm/vm.go | 4 ++-- vm/warp_service.go | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vm/vm.go b/vm/vm.go index f27a3066..6747b346 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -490,11 +490,11 @@ func (vm *LandslideVM) Initialize(_ context.Context, req *vmpb.InitializeRequest } rpcClients := make(map[ids.NodeID]warp.Client) for id, nodeURI := range vm.config.AddressBook { - nodeID, err := ids.ToNodeID([]byte(id)) + nodeID, err := ids.NodeIDFromString(id) if err != nil { return nil, fmt.Errorf("failed to parse node ID: %w", err) } - rpcClient, err := warp.NewClient(fmt.Sprintf("%s/ext/bc/%s/rpc", nodeURI, string(req.ChainId))) + rpcClient, err := warp.NewClient(fmt.Sprintf("%s/ext/bc/%s/rpc", nodeURI, chainID.String())) if err != nil { return nil, fmt.Errorf("failed to create warp client: %w", err) } diff --git a/vm/warp_service.go b/vm/warp_service.go index bc2c7d2d..80989311 100644 --- 
a/vm/warp_service.go +++ b/vm/warp_service.go @@ -39,6 +39,10 @@ type ResultGetMessageSignature struct { Signature []byte `json:"signature"` } +type ResultGetBlockSignature struct { + Signature []byte `json:"signature"` +} + // API introduces snowman specific functionality to the evm type API struct { vm *LandslideVM @@ -123,12 +127,12 @@ func (a *API) GetMessageAggregateSignature(ctx context.Context, messageID ids.ID } // GetBlockSignature returns the BLS signature associated with a blockID. -func (a *API) GetBlockSignature(ctx context.Context, blockID ids.ID) (tmbytes.HexBytes, error) { +func (a *API) GetBlockSignature(ctx context.Context, blockID ids.ID) (*ResultGetBlockSignature, error) { signature, err := a.backend.GetBlockSignature(blockID) if err != nil { return nil, fmt.Errorf("failed to get signature for block %s with error %w", blockID, err) } - return signature, nil + return &ResultGetBlockSignature{Signature: signature}, nil } // GetBlockAggregateSignature fetches the aggregate signature for the requested [blockID]
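Taken together, the final state of the series exposes a submit-then-aggregate flow over JSON-RPC: a raw unsigned warp message is registered via warp_add_message, and its aggregate signature is requested via warp_get_message_aggregate_signature. The sketch below shows that flow from a client's point of view, assuming the single-argument NewClient(rpcAddr) from the last commits; the message bytes, quorum numerator, and subnet ID string are caller-supplied placeholders, and the messageID returned by warp_add_message is recomputed locally rather than decoded from the RPC result:

package warpexample

import (
	"context"
	"fmt"

	warputils "github.com/landslidenetwork/slide-sdk/utils/warp"
	"github.com/landslidenetwork/slide-sdk/warp"
)

// submitAndAggregate registers an unsigned warp message with one node via
// warp_add_message, then asks the same node for an aggregate signature over it
// via warp_get_message_aggregate_signature.
func submitAndAggregate(ctx context.Context, rpcAddr string, unsignedMsgBytes []byte, quorumNum uint64, subnetIDStr string) ([]byte, error) {
	client, err := warp.NewClient(rpcAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to create warp client: %w", err)
	}

	if _, err := client.AddMessage(ctx, unsignedMsgBytes); err != nil {
		return nil, fmt.Errorf("failed to add warp message: %w", err)
	}

	// The aggregate endpoint is keyed by message ID; parse the message locally
	// to compute the same ID the node derived when it stored the message.
	msg, err := warputils.ParseUnsignedMessage(unsignedMsgBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse warp message: %w", err)
	}

	// An empty subnetIDStr falls back to the node's own source subnet, as
	// aggregateSignatures shows above; quorumNum is forwarded to the
	// aggregator unchanged.
	return client.GetMessageAggregateSignature(ctx, msg.ID(), quorumNum, subnetIDStr)
}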