diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index f44eed89..848ac79f 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -21,20 +21,17 @@ import ( "crypto/tls" "encoding/hex" "fmt" - "io" "net" "net/http" "os" "path/filepath" "strings" - "sync" "time" - "go.uber.org/atomic" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conduit" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" @@ -291,16 +288,14 @@ func (m *DendriteMonolith) DisconnectPort(port int) { m.PineconeRouter.Disconnect(types.SwitchPortID(port), nil) } -func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) { +func (m *DendriteMonolith) Conduit(zone string, peertype int) (*conduit.Conduit, error) { l, r := net.Pipe() - conduit := &Conduit{conn: r, port: 0} + newConduit := conduit.NewConduit(r, 0) go func() { - conduit.portMutex.Lock() - defer conduit.portMutex.Unlock() - logrus.Errorf("Attempting authenticated connect") + var port types.SwitchPortID var err error - if conduit.port, err = m.PineconeRouter.Connect( + if port, err = m.PineconeRouter.Connect( l, pineconeRouter.ConnectionZone(zone), pineconeRouter.ConnectionPeerType(peertype), @@ -308,12 +303,13 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) logrus.Errorf("Authenticated connect failed: %s", err) _ = l.Close() _ = r.Close() - _ = conduit.Close() + _ = newConduit.Close() return } - logrus.Infof("Authenticated connect succeeded (port %d)", conduit.port) + newConduit.SetPort(port) + logrus.Infof("Authenticated connect succeeded (port %d)", newConduit.Port()) }() - return conduit, nil + return &newConduit, nil } func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, error) { @@ -602,52 +598,3 @@ func (m *DendriteMonolith) Stop() { _ = m.PineconeQUIC.Close() _ = m.PineconeRouter.Close() } - -const MaxFrameSize = types.MaxFrameSize - -type Conduit struct { - closed atomic.Bool - conn net.Conn - port types.SwitchPortID - portMutex sync.Mutex -} - -func (c *Conduit) Port() int { - c.portMutex.Lock() - defer c.portMutex.Unlock() - return int(c.port) -} - -func (c *Conduit) Read(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Read(b) -} - -func (c *Conduit) ReadCopy() ([]byte, error) { - if c.closed.Load() { - return nil, io.EOF - } - var buf [65535 * 2]byte - n, err := c.conn.Read(buf[:]) - if err != nil { - return nil, err - } - return buf[:n], nil -} - -func (c *Conduit) Write(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Write(b) -} - -func (c *Conduit) Close() error { - if c.closed.Load() { - return io.ErrClosedPipe - } - c.closed.Store(true) - return c.conn.Close() -} diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index f6bf2ef0..434e07ef 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -15,113 +15,12 @@ package gobind import ( - "fmt" - "net" "strings" "testing" "github.com/matrix-org/gomatrixserverlib" - "github.com/stretchr/testify/assert" ) -var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} - -type TestNetConn struct { - net.Conn - shouldFail bool -} - -func (t *TestNetConn) Read(b []byte) (int, error) { - if t.shouldFail { - return 0, fmt.Errorf("Failed") - } else { - n := copy(b, TestBuf) - return n, nil - } -} - -func (t *TestNetConn) Write(b []byte) (int, error) { - if t.shouldFail { - return 0, fmt.Errorf("Failed") - } else { - return len(b), nil - } -} - -func (t *TestNetConn) Close() error { - if t.shouldFail { - return fmt.Errorf("Failed") - } else { - return nil - } -} - -func TestConduitStoresPort(t *testing.T) { - conduit := Conduit{port: 7} - assert.Equal(t, 7, conduit.Port()) -} - -func TestConduitRead(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - b := make([]byte, len(TestBuf)) - bytes, err := conduit.Read(b) - assert.NoError(t, err) - assert.Equal(t, len(TestBuf), bytes) - assert.Equal(t, TestBuf, b) -} - -func TestConduitReadCopy(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - result, err := conduit.ReadCopy() - assert.NoError(t, err) - assert.Equal(t, TestBuf, result) -} - -func TestConduitWrite(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - bytes, err := conduit.Write(TestBuf) - assert.NoError(t, err) - assert.Equal(t, len(TestBuf), bytes) -} - -func TestConduitClose(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - assert.True(t, conduit.closed.Load()) -} - -func TestConduitReadClosed(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - b := make([]byte, len(TestBuf)) - _, err = conduit.Read(b) - assert.Error(t, err) -} - -func TestConduitReadCopyClosed(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - _, err = conduit.ReadCopy() - assert.Error(t, err) -} - -func TestConduitWriteClosed(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - _, err = conduit.Write(TestBuf) - assert.Error(t, err) -} - -func TestConduitReadCopyFails(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{shouldFail: true}} - _, err := conduit.ReadCopy() - assert.Error(t, err) -} - func TestMonolithStarts(t *testing.T) { monolith := DendriteMonolith{} monolith.Start() diff --git a/cmd/dendrite-demo-pinecone/conduit/conduit.go b/cmd/dendrite-demo-pinecone/conduit/conduit.go new file mode 100644 index 00000000..be139c19 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/conduit/conduit.go @@ -0,0 +1,84 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conduit + +import ( + "io" + "net" + "sync" + + "github.com/matrix-org/pinecone/types" + "go.uber.org/atomic" +) + +type Conduit struct { + closed atomic.Bool + conn net.Conn + portMutex sync.Mutex + port types.SwitchPortID +} + +func NewConduit(conn net.Conn, port int) Conduit { + return Conduit{ + conn: conn, + port: types.SwitchPortID(port), + } +} + +func (c *Conduit) Port() int { + c.portMutex.Lock() + defer c.portMutex.Unlock() + return int(c.port) +} + +func (c *Conduit) SetPort(port types.SwitchPortID) { + c.portMutex.Lock() + defer c.portMutex.Unlock() + c.port = port +} + +func (c *Conduit) Read(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } + return c.conn.Read(b) +} + +func (c *Conduit) ReadCopy() ([]byte, error) { + if c.closed.Load() { + return nil, io.EOF + } + var buf [65535 * 2]byte + n, err := c.conn.Read(buf[:]) + if err != nil { + return nil, err + } + return buf[:n], nil +} + +func (c *Conduit) Write(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } + return c.conn.Write(b) +} + +func (c *Conduit) Close() error { + if c.closed.Load() { + return io.ErrClosedPipe + } + c.closed.Store(true) + return c.conn.Close() +} diff --git a/cmd/dendrite-demo-pinecone/conduit/conduit_test.go b/cmd/dendrite-demo-pinecone/conduit/conduit_test.go new file mode 100644 index 00000000..d8cd3133 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/conduit/conduit_test.go @@ -0,0 +1,121 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conduit + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +type TestNetConn struct { + net.Conn + shouldFail bool +} + +func (t *TestNetConn) Read(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + n := copy(b, TestBuf) + return n, nil + } +} + +func (t *TestNetConn) Write(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + return len(b), nil + } +} + +func (t *TestNetConn) Close() error { + if t.shouldFail { + return fmt.Errorf("Failed") + } else { + return nil + } +} + +func TestConduitStoresPort(t *testing.T) { + conduit := Conduit{port: 7} + assert.Equal(t, 7, conduit.Port()) +} + +func TestConduitRead(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + b := make([]byte, len(TestBuf)) + bytes, err := conduit.Read(b) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) + assert.Equal(t, TestBuf, b) +} + +func TestConduitReadCopy(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + result, err := conduit.ReadCopy() + assert.NoError(t, err) + assert.Equal(t, TestBuf, result) +} + +func TestConduitWrite(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + bytes, err := conduit.Write(TestBuf) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) +} + +func TestConduitClose(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + assert.True(t, conduit.closed.Load()) +} + +func TestConduitReadClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + b := make([]byte, len(TestBuf)) + _, err = conduit.Read(b) + assert.Error(t, err) +} + +func TestConduitReadCopyClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.ReadCopy() + assert.Error(t, err) +} + +func TestConduitWriteClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.Write(TestBuf) + assert.Error(t, err) +} + +func TestConduitReadCopyFails(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{shouldFail: true}} + _, err := conduit.ReadCopy() + assert.Error(t, err) +}