diff --git a/.gitignore b/.gitignore index fe5e8279..dcfbf800 100644 --- a/.gitignore +++ b/.gitignore @@ -74,3 +74,4 @@ complement/ docs/_site media_store/ +build \ No newline at end of file diff --git a/appservice/appservice.go b/appservice/appservice.go index d13d9eb1..1f6037ee 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -16,10 +16,7 @@ package appservice import ( "context" - "crypto/tls" - "net/http" "sync" - "time" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -44,20 +41,10 @@ func NewInternalAPI( userAPI userapi.AppserviceUserAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceInternalAPI { - client := &http.Client{ - Timeout: time.Second * 30, - Transport: &http.Transport{ - DisableKeepAlives: true, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: cfg.AppServiceAPI.DisableTLSValidation, - }, - Proxy: http.ProxyFromEnvironment, - }, - } + // Create appserivce query API with an HTTP client that will be used for all // outbound and inbound requests (inbound only for the internal API) appserviceQueryAPI := &query.AppServiceQueryAPI{ - HTTPClient: client, Cfg: &cfg.AppServiceAPI, ProtocolCache: map[string]appserviceAPI.ASProtocolResponse{}, CacheMu: sync.Mutex{}, @@ -84,7 +71,7 @@ func NewInternalAPI( js, _ := natsInstance.Prepare(processContext, &cfg.Global.JetStream) consumer := consumers.NewOutputRoomEventConsumer( processContext, &cfg.AppServiceAPI, - client, js, rsAPI, + js, rsAPI, ) if err := consumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start appservice roomserver consumer") diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 752901a9..5189bdf9 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -3,14 +3,19 @@ package appservice_test import ( "context" "encoding/json" + "fmt" + "net" "net/http" "net/http/httptest" + "path" "reflect" "regexp" "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/caching" @@ -114,20 +119,20 @@ func TestAppserviceInternalAPI(t *testing.T) { defer close() // Create a dummy application service - cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ - { - ID: "someID", - URL: srv.URL, - ASToken: "", - HSToken: "", - SenderLocalpart: "senderLocalPart", - NamespaceMap: map[string][]config.ApplicationServiceNamespace{ - "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, - "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, - }, - Protocols: []string{existingProtocol}, + as := &config.ApplicationService{ + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, }, + Protocols: []string{existingProtocol}, } + as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation) + cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as} t.Cleanup(func() { ctx.ShutdownDendrite() @@ -145,6 +150,103 @@ func TestAppserviceInternalAPI(t *testing.T) { }) } +func TestAppserviceInternalAPI_UnixSocket_Simple(t *testing.T) { + + // Set expected results + existingProtocol := "irc" + wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}} + + // create a dummy AS url, handling some cases + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "location"): + // Check if we've got an existing protocol, if so, return a proper response. + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "user"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "protocol"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode(nil); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + default: + t.Logf("hit location: %s", r.URL.Path) + } + })) + + tmpDir := t.TempDir() + socket := path.Join(tmpDir, "socket") + l, err := net.Listen("unix", socket) + assert.NoError(t, err) + _ = srv.Listener.Close() + srv.Listener = l + srv.Start() + defer srv.Close() + + cfg, ctx, tearDown := testrig.CreateConfig(t, test.DBTypeSQLite) + defer tearDown() + + // Create a dummy application service + as := &config.ApplicationService{ + ID: "someID", + URL: fmt.Sprintf("unix://%s", socket), + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + } + as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation) + cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as} + + t.Cleanup(func() { + ctx.ShutdownDendrite() + ctx.WaitForShutdown() + }) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + // Create required internal APIs + natsInstance := jetstream.NATSInstance{} + cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) + asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) + + t.Run("UserIDExists", func(t *testing.T) { + testUserIDExists(t, asAPI, "@as-testing:test", true) + testUserIDExists(t, asAPI, "@as1-testing:test", false) + }) + +} + func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { ctx := context.Background() userResp := &api.UserIDExistsResponse{} @@ -254,20 +356,21 @@ func TestRoomserverConsumerOneInvite(t *testing.T) { })) defer srv.Close() - // Create a dummy application service - cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ - { - ID: "someID", - URL: srv.URL, - ASToken: "", - HSToken: "", - SenderLocalpart: "senderLocalPart", - NamespaceMap: map[string][]config.ApplicationServiceNamespace{ - "users": {{RegexpObject: regexp.MustCompile(bob.ID)}}, - "aliases": {{RegexpObject: regexp.MustCompile(room.ID)}}, - }, + as := &config.ApplicationService{ + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile(bob.ID)}}, + "aliases": {{RegexpObject: regexp.MustCompile(room.ID)}}, }, } + as.CreateHTTPClient(cfg.AppServiceAPI.DisableTLSValidation) + + // Create a dummy application service + cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{*as} caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) // Create required internal APIs diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 16b3b823..308b0367 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -40,7 +40,6 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context cfg *config.AppServiceAPI - client *http.Client jetstream nats.JetStreamContext topic string rsAPI api.AppserviceRoomserverAPI @@ -56,14 +55,12 @@ type appserviceState struct { func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.AppServiceAPI, - client *http.Client, js nats.JetStreamContext, rsAPI api.AppserviceRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), cfg: cfg, - client: client, jetstream: js, topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), rsAPI: rsAPI, @@ -189,13 +186,13 @@ func (s *OutputRoomEventConsumer) sendEvents( // Send the transaction to the appservice. // https://matrix.org/docs/spec/application_service/r0.1.2#put-matrix-app-v1-transactions-txnid - address := fmt.Sprintf("%s/transactions/%s?access_token=%s", state.URL, txnID, url.QueryEscape(state.HSToken)) + address := fmt.Sprintf("%s/transactions/%s?access_token=%s", state.RequestUrl(), txnID, url.QueryEscape(state.HSToken)) req, err := http.NewRequestWithContext(ctx, "PUT", address, bytes.NewBuffer(transaction)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") - resp, err := s.client.Do(req) + resp, err := state.HTTPClient.Do(req) if err != nil { return state.backoffAndPause(err) } @@ -206,7 +203,7 @@ func (s *OutputRoomEventConsumer) sendEvents( case http.StatusOK: state.backoff = 0 default: - return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice", resp.StatusCode)) + return state.backoffAndPause(fmt.Errorf("received HTTP status code %d from appservice url %s", resp.StatusCode, address)) } return nil } diff --git a/appservice/query/query.go b/appservice/query/query.go index 0466f81d..ca8d7b3a 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -37,7 +37,6 @@ const userIDExistsPath = "/users/" // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI type AppServiceQueryAPI struct { - HTTPClient *http.Client Cfg *config.AppServiceAPI ProtocolCache map[string]api.ASProtocolResponse CacheMu sync.Mutex @@ -57,7 +56,7 @@ func (a *AppServiceQueryAPI) RoomAliasExists( for _, appservice := range a.Cfg.Derived.ApplicationServices { if appservice.URL != "" && appservice.IsInterestedInRoomAlias(request.Alias) { // The full path to the rooms API, includes hs token - URL, err := url.Parse(appservice.URL + roomAliasExistsPath) + URL, err := url.Parse(appservice.RequestUrl() + roomAliasExistsPath) if err != nil { return err } @@ -73,7 +72,7 @@ func (a *AppServiceQueryAPI) RoomAliasExists( } req = req.WithContext(ctx) - resp, err := a.HTTPClient.Do(req) + resp, err := appservice.HTTPClient.Do(req) if resp != nil { defer func() { err = resp.Body.Close() @@ -124,7 +123,7 @@ func (a *AppServiceQueryAPI) UserIDExists( for _, appservice := range a.Cfg.Derived.ApplicationServices { if appservice.URL != "" && appservice.IsInterestedInUserID(request.UserID) { // The full path to the rooms API, includes hs token - URL, err := url.Parse(appservice.URL + userIDExistsPath) + URL, err := url.Parse(appservice.RequestUrl() + userIDExistsPath) if err != nil { return err } @@ -137,7 +136,7 @@ func (a *AppServiceQueryAPI) UserIDExists( if err != nil { return err } - resp, err := a.HTTPClient.Do(req.WithContext(ctx)) + resp, err := appservice.HTTPClient.Do(req.WithContext(ctx)) if resp != nil { defer func() { err = resp.Body.Close() @@ -212,12 +211,12 @@ func (a *AppServiceQueryAPI) Locations( var asLocations []api.ASLocationResponse params.Set("access_token", as.HSToken) - url := as.URL + api.ASLocationPath + url := as.RequestUrl() + api.ASLocationPath if req.Protocol != "" { url += "/" + req.Protocol } - if err := requestDo[[]api.ASLocationResponse](a.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { + if err := requestDo[[]api.ASLocationResponse](as.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { log.WithError(err).Error("unable to get 'locations' from application service") continue } @@ -247,12 +246,12 @@ func (a *AppServiceQueryAPI) User( var asUsers []api.ASUserResponse params.Set("access_token", as.HSToken) - url := as.URL + api.ASUserPath + url := as.RequestUrl() + api.ASUserPath if req.Protocol != "" { url += "/" + req.Protocol } - if err := requestDo[[]api.ASUserResponse](a.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { + if err := requestDo[[]api.ASUserResponse](as.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { log.WithError(err).Error("unable to get 'user' from application service") continue } @@ -290,7 +289,7 @@ func (a *AppServiceQueryAPI) Protocols( response := api.ASProtocolResponse{} for _, as := range a.Cfg.Derived.ApplicationServices { var proto api.ASProtocolResponse - if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+req.Protocol, &proto); err != nil { + if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+req.Protocol, &proto); err != nil { log.WithError(err).Error("unable to get 'protocol' from application service") continue } @@ -320,7 +319,7 @@ func (a *AppServiceQueryAPI) Protocols( for _, as := range a.Cfg.Derived.ApplicationServices { for _, p := range as.Protocols { var proto api.ASProtocolResponse - if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+p, &proto); err != nil { + if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+p, &proto); err != nil { log.WithError(err).Error("unable to get 'protocol' from application service") continue } diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index 37e20a97..ef10649d 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -15,16 +15,23 @@ package config import ( + "context" + "crypto/tls" "fmt" + "net" + "net/http" "os" "path/filepath" "regexp" "strings" + "time" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) +const UnixSocketPrefix = "unix://" + type AppServiceAPI struct { Matrix *Global `yaml:"-"` Derived *Derived `yaml:"-"` // TODO: Nuke Derived from orbit @@ -80,7 +87,41 @@ type ApplicationService struct { // Whether rate limiting is applied to each application service user RateLimited bool `yaml:"rate_limited"` // Any custom protocols that this application service provides (e.g. IRC) - Protocols []string `yaml:"protocols"` + Protocols []string `yaml:"protocols"` + HTTPClient *http.Client + isUnixSocket bool + unixSocket string +} + +func (a *ApplicationService) CreateHTTPClient(insecureSkipVerify bool) { + client := &http.Client{ + Timeout: time.Second * 30, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: insecureSkipVerify, + }, + Proxy: http.ProxyFromEnvironment, + }, + } + if strings.HasPrefix(a.URL, UnixSocketPrefix) { + a.isUnixSocket = true + a.unixSocket = "http://unix" + client.Transport = &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimPrefix(a.URL, UnixSocketPrefix)) + }, + } + } + a.HTTPClient = client +} + +func (a *ApplicationService) RequestUrl() string { + if a.isUnixSocket { + return a.unixSocket + } else { + return a.URL + } } // IsInterestedInRoomID returns a bool on whether an application service's @@ -152,7 +193,7 @@ func (a *ApplicationService) IsInterestedInRoomAlias( func loadAppServices(config *AppServiceAPI, derived *Derived) error { for _, configPath := range config.ConfigFiles { // Create a new application service with default options - appservice := ApplicationService{ + appservice := &ApplicationService{ RateLimited: true, } @@ -169,13 +210,13 @@ func loadAppServices(config *AppServiceAPI, derived *Derived) error { } // Load the config data into our struct - if err = yaml.Unmarshal(configData, &appservice); err != nil { + if err = yaml.Unmarshal(configData, appservice); err != nil { return err } - + appservice.CreateHTTPClient(config.DisableTLSValidation) // Append the parsed application service to the global config derived.ApplicationServices = append( - derived.ApplicationServices, appservice, + derived.ApplicationServices, *appservice, ) }