Implement /keys/query locally (#1204)

* Implement /keys/query locally

* Fix sqlite tests and close rows
This commit is contained in:
Kegsay 2020-07-15 18:40:41 +01:00 committed by GitHub
parent df8d6823ee
commit f5e7e7513c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 183 additions and 27 deletions

View File

@ -17,6 +17,7 @@ package routing
import (
"encoding/json"
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -25,18 +26,6 @@ import (
"github.com/matrix-org/util"
)
func QueryKeys(
req *http.Request,
) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusOK,
JSON: map[string]interface{}{
"failures": map[string]interface{}{},
"device_keys": map[string]interface{}{},
},
}
}
type uploadKeysRequest struct {
DeviceKeys json.RawMessage `json:"device_keys"`
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
@ -94,3 +83,37 @@ func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.De
}{keyCount},
}
}
type queryKeysRequest struct {
Timeout int `json:"timeout"`
Token string `json:"token"`
DeviceKeys map[string][]string `json:"device_keys"`
}
func (r *queryKeysRequest) GetTimeout() time.Duration {
if r.Timeout == 0 {
return 10 * time.Second
}
return time.Duration(r.Timeout) * time.Millisecond
}
func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse {
var r queryKeysRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
queryRes := api.QueryKeysResponse{}
keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
UserToDevices: r.DeviceKeys,
Timeout: r.GetTimeout(),
// TODO: Token?
}, &queryRes)
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"device_keys": queryRes.DeviceKeys,
"failures": queryRes.Failures,
},
}
}

View File

@ -698,12 +698,6 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/keys/query",
httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -715,4 +709,9 @@ func Setup(
return UploadKeys(req, keyAPI, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
}

View File

@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"strings"
"time"
)
type KeyInternalAPI interface {
@ -108,8 +109,16 @@ type PerformClaimKeysResponse struct {
}
type QueryKeysRequest struct {
// Maps user IDs to a list of devices
UserToDevices map[string][]string
Timeout time.Duration
}
type QueryKeysResponse struct {
// Map of remote server domain to error JSON
Failures map[string]interface{}
// Map of user_id to device_id to device_key
DeviceKeys map[string]map[string]json.RawMessage
// Set if there was a fatal error processing this query
Error *KeyError
}

View File

@ -17,15 +17,19 @@ package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
type KeyInternalAPI struct {
DB storage.Database
DB storage.Database
ThisServer gomatrixserverlib.ServerName
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
@ -37,7 +41,45 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
}
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// make a map from domain to device keys
domainToUserToDevice := make(map[string][]api.DeviceKeys)
for userID, deviceIDs := range req.UserToDevices {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
continue // ignore invalid users
}
domain := string(serverName)
// query local devices
if serverName == a.ThisServer {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
}
return
}
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, dk := range deviceKeys {
// inject an empty 'unsigned' key which should be used for display names
// (but not via this API? unsure when they should be added)
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
}
} else {
for _, deviceID := range deviceIDs {
domainToUserToDevice[domain] = append(domainToUserToDevice[domain], api.DeviceKeys{
UserID: userID,
DeviceID: deviceID,
})
}
}
}
// TODO: set device display names when they are known
// TODO: perform key queries for remote devices
}
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {

View File

@ -41,6 +41,7 @@ func NewInternalAPI(cfg *config.Dendrite) api.KeyInternalAPI {
logrus.WithError(err).Panicf("failed to connect to key server database")
}
return &internal.KeyInternalAPI{
DB: db,
DB: db,
ThisServer: cfg.Matrix.ServerName,
}
}

View File

@ -35,4 +35,8 @@ type Database interface {
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced.
// Returns an error if there was a problem storing the keys.
StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
}

View File

@ -19,6 +19,7 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
@ -45,10 +46,14 @@ const upsertDeviceKeysSQL = "" +
const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -65,6 +70,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
}
@ -95,3 +103,30 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.
return nil
})
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
var result []api.DeviceKeys
for rows.Next() {
var dk api.DeviceKeys
dk.UserID = userID
var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
}
}
return result, rows.Err()
}

View File

@ -44,3 +44,7 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) er
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys)
}
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
}

View File

@ -19,6 +19,7 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
@ -45,10 +46,14 @@ const upsertDeviceKeysSQL = "" +
const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -65,9 +70,39 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceKeys
for rows.Next() {
var dk api.DeviceKeys
dk.UserID = userID
var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
}
}
return result, rows.Err()
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
for i, key := range keys {
var keyJSONStr string

View File

@ -29,4 +29,5 @@ type OneTimeKeys interface {
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
}

View File

@ -121,6 +121,9 @@ local user can join room with version 1
User can invite local user to room with version 1
Can upload device keys
Should reject keys claiming to belong to a different user
Can query device keys using POST
Can query specific device keys using POST
query for user with no keys returns empty key dict
Can add account data
Can add account data to room
Can get account data without syncing