package appservice_test import ( "context" "encoding/json" "net/http" "net/http/httptest" "reflect" "regexp" "strings" "testing" "time" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/test/testrig" ) func TestAppserviceInternalAPI(t *testing.T) { // Set expected results existingProtocol := "irc" wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}} wantProtocolResult := map[string]api.ASProtocolResponse{ existingProtocol: wantProtocolResponse, } // create a dummy AS url, handling some cases srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "location"): // Check if we've got an existing protocol, if so, return a proper response. if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil { t.Fatalf("failed to encode response: %s", err) } return } if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil { t.Fatalf("failed to encode response: %s", err) } return case strings.Contains(r.URL.Path, "user"): if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil { t.Fatalf("failed to encode response: %s", err) } return } if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil { t.Fatalf("failed to encode response: %s", err) } return case strings.Contains(r.URL.Path, "protocol"): if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil { t.Fatalf("failed to encode response: %s", err) } return } if err := json.NewEncoder(w).Encode(nil); err != nil { t.Fatalf("failed to encode response: %s", err) } return default: t.Logf("hit location: %s", r.URL.Path) } })) // The test cases to run runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) { t.Run("UserIDExists", func(t *testing.T) { testUserIDExists(t, testAPI, "@as-testing:test", true) testUserIDExists(t, testAPI, "@as1-testing:test", false) }) t.Run("AliasExists", func(t *testing.T) { testAliasExists(t, testAPI, "@asroom-testing:test", true) testAliasExists(t, testAPI, "@asroom1-testing:test", false) }) t.Run("Locations", func(t *testing.T) { testLocations(t, testAPI, existingProtocol, wantLocationResponse) testLocations(t, testAPI, "abc", nil) }) t.Run("User", func(t *testing.T) { testUser(t, testAPI, existingProtocol, wantUserResponse) testUser(t, testAPI, "abc", nil) }) t.Run("Protocols", func(t *testing.T) { testProtocol(t, testAPI, existingProtocol, wantProtocolResult) testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache testProtocol(t, testAPI, "", wantProtocolResult) // tests getting all protocols testProtocol(t, testAPI, "abc", nil) }) } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, ctx, close := testrig.CreateConfig(t, dbType) defer close() // Create a dummy application service cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ { ID: "someID", URL: srv.URL, ASToken: "", HSToken: "", SenderLocalpart: "senderLocalPart", NamespaceMap: map[string][]config.ApplicationServiceNamespace{ "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, }, Protocols: []string{existingProtocol}, }, } t.Cleanup(func() { ctx.ShutdownDendrite() ctx.WaitForShutdown() }) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) // Create required internal APIs natsInstance := jetstream.NATSInstance{} cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) runCases(t, asAPI) }) } func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { ctx := context.Background() userResp := &api.UserIDExistsResponse{} if err := asAPI.UserIDExists(ctx, &api.UserIDExistsRequest{ UserID: userID, }, userResp); err != nil { t.Errorf("failed to get userID: %s", err) } if userResp.UserIDExists != wantExists { t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, userResp.UserIDExists, wantExists) } } func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) { ctx := context.Background() aliasResp := &api.RoomAliasExistsResponse{} if err := asAPI.RoomAliasExists(ctx, &api.RoomAliasExistsRequest{ Alias: alias, }, aliasResp); err != nil { t.Errorf("failed to get alias: %s", err) } if aliasResp.AliasExists != wantExists { t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, aliasResp.AliasExists, wantExists) } } func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) { ctx := context.Background() locationResp := &api.LocationResponse{} if err := asAPI.Locations(ctx, &api.LocationRequest{ Protocol: proto, }, locationResp); err != nil { t.Errorf("failed to get locations: %s", err) } if !reflect.DeepEqual(locationResp.Locations, wantResult) { t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, locationResp.Locations, wantResult) } } func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) { ctx := context.Background() userResp := &api.UserResponse{} if err := asAPI.User(ctx, &api.UserRequest{ Protocol: proto, }, userResp); err != nil { t.Errorf("failed to get user: %s", err) } if !reflect.DeepEqual(userResp.Users, wantResult) { t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, userResp.Users, wantResult) } } func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) { ctx := context.Background() protoResp := &api.ProtocolResponse{} if err := asAPI.Protocols(ctx, &api.ProtocolRequest{ Protocol: proto, }, protoResp); err != nil { t.Errorf("failed to get Protocols: %s", err) } if !reflect.DeepEqual(protoResp.Protocols, wantResult) { t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult) } }