diff --git a/internal/server/api/api.go b/internal/server/api/api.go index 8228e12..ba9fcd7 100644 --- a/internal/server/api/api.go +++ b/internal/server/api/api.go @@ -26,7 +26,7 @@ func Kickoff(logger *slog.Logger, env bootstrap.Environment, db *gorm.DB) { if env.Authentication { slog.Debug("injecting authentication middleware") // only log when actually doign the thing it logs to do - r.Use(middleware.AuthMiddleware()) + r.Use(middleware.AuthMiddleware(db)) } api := r.Group("/api") diff --git a/internal/server/api/assets/response.go b/internal/server/api/assets/response.go index fc05849..e2a4147 100644 --- a/internal/server/api/assets/response.go +++ b/internal/server/api/assets/response.go @@ -1,7 +1,6 @@ package assets import ( - "fmt" "net/http" "time" @@ -50,27 +49,12 @@ func FileDownloadResponse(c *gin.Context, fp string) { c.File(fp) } -func BasicResponse(c *gin.Context, data ...any) { - fmt.Println(len(data)) - switch len(data) { - case 0: - // empty slice so just an ok message - c.JSON(http.StatusOK, ResponseObject{ - Msg: OkMes, - }) +func BasicResponse(c *gin.Context, data any) { // single object in our slice (even if the object itself is a slice) - case 1: - c.JSON(http.StatusOK, ResponseObject{ - Msg: OkMes, - Data: data[0], - }) - // multiple objects inside our slice - default: - c.JSON(http.StatusOK, ResponseObject{ - Msg: OkMes, - Data: data, - }) - } + c.JSON(http.StatusOK, ResponseObject{ + Msg: OkMes, + Data: data, + }) } func CreationResponse(c *gin.Context, data any) { diff --git a/internal/server/api/middleware/middleware.go b/internal/server/api/middleware/middleware.go index f213d9c..feca7ea 100644 --- a/internal/server/api/middleware/middleware.go +++ b/internal/server/api/middleware/middleware.go @@ -4,10 +4,13 @@ import ( "log/slog" "net/http" "orbits-server/internal/server/api/assets" + "orbits-server/internal/server/service" + "orbits-server/internal/shared/security" "strings" "time" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) func SlogMiddleware(logger *slog.Logger) gin.HandlerFunc { @@ -38,7 +41,9 @@ func SlogMiddleware(logger *slog.Logger) gin.HandlerFunc { } } -func AuthMiddleware() gin.HandlerFunc { +func AuthMiddleware(db *gorm.DB) gin.HandlerFunc { + keyService := service.NewKeyService(db) + return func(c *gin.Context) { authorizationHeader := c.GetHeader("Authorization") if len(authorizationHeader) == 0 { @@ -58,6 +63,22 @@ func AuthMiddleware() gin.HandlerFunc { return } - //givenKey := headerParts[1] + candidateKey := headerParts[1] + storedKeys, err := keyService.ListValidKeyHashes() + if err != nil { + slog.Error("failed to retrieve key hashes", "error", err) + assets.InternalErrorResponse(c) + } + + for _, key := range storedKeys { + if match := security.CompareKey(key, candidateKey); match { + c.Next() + return + } + } + + c.AbortWithStatusJSON(http.StatusUnauthorized, assets.ResponseObject{ + Msg: "invalid key", + }) } } diff --git a/internal/server/service/keyservice.go b/internal/server/service/keyservice.go index f6fb1e5..6f1fe98 100644 --- a/internal/server/service/keyservice.go +++ b/internal/server/service/keyservice.go @@ -23,6 +23,20 @@ func NewKeyService(db *gorm.DB) *KeyService { } } +func (s *KeyService) ListValidKeyHashes() ([]string, error) { + keyRecords, err := database.ListKeys(s.db) + if err != nil { + return nil, err + } + + hashList := make([]string, 0, len(keyRecords)) + for _, k := range keyRecords { + hashList = append(hashList, k.KeyHash) + } + + return hashList, nil +} + func (s *KeyService) Create(name string, expiresAt time.Time) (assets.KeyResponse, error) { keyContent := security.GenerateChars(accessKeyLen) diff --git a/internal/shared/security/hash.go b/internal/shared/security/hash.go index f026eeb..2ca499d 100644 --- a/internal/shared/security/hash.go +++ b/internal/shared/security/hash.go @@ -30,6 +30,7 @@ func HashFileReader(r io.Reader) (string, error) { return base64.StdEncoding.EncodeToString(h.Sum(nil)), nil } +// we use argon2 for key hashing - since it won the key encryption "war" func HashKey(key string) (string, error) { salt := make([]byte, argonSaltLen) rand.Read(salt) @@ -46,8 +47,8 @@ func HashKey(key string) (string, error) { return encoded, nil } -func CompareKey(key, candidate string) bool { - parts := strings.Split(candidate, ":") +func CompareKey(storedKey, candidate string) bool { + parts := strings.Split(storedKey, ":") if len(parts) != 3 { return false } @@ -65,7 +66,7 @@ func CompareKey(key, candidate string) bool { return false } - actual := argon2.IDKey([]byte(key), salt, argonTime, argonMemory, argonThreads, argonKeyLen) + actual := argon2.IDKey([]byte(candidate), salt, argonTime, argonMemory, argonThreads, argonKeyLen) return subtle.ConstantTimeCompare(actual, expected) == 1 }