// SPDX-FileCopyrightText: 2021 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package fupoidc

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"

	"github.com/coreos/go-oidc/v3/oidc"
	"golang.org/x/oauth2"
	"git.lukegb.com/lukegb/depot/web/fup/fuphttp"
)

// cookieOAuthState names the cookie that carries the random OAuth2 state value
// from redirectToProvider to the /loginz callback.
const cookieOAuthState = "fup_state"

// cookieToken names the cookie that stores the raw OIDC ID token after a
// successful login.
const cookieToken = "fup_auth"

// Middleware authenticates HTTP requests using OpenID Connect ID tokens.
// Construct it with New.
type Middleware struct {
	// oidc is the provider built by New; its Endpoint feeds oauth2Config.
	oidc *oidc.Provider
	// verifier checks ID tokens presented by clients and returned at /loginz.
	verifier *oidc.IDTokenVerifier

	// OAuth2 client credentials, followed by the application's base URL
	// (used to build the redirect URL and as the post-login destination).
	clientID, clientSecret, baseURL string
}

// Config holds the settings New needs to build a Middleware.
type Config struct {
	// ClientID and ClientSecret are the OAuth2 client credentials; ClientID
	// also configures the ID token verifier. ProviderURL is passed to
	// oidc.NewProvider.
	ClientID, ClientSecret, ProviderURL string

	// BaseURL is the application's externally visible root. The OAuth2
	// redirect URL is BaseURL joined with "loginz", and users are redirected
	// back to BaseURL after logging in.
	BaseURL string
}

// New builds a Middleware from c. It constructs the OIDC provider for
// c.ProviderURL with oidc.NewProvider and derives an ID token verifier keyed
// on c.ClientID.
//
// It returns a wrapped error if oidc.NewProvider fails.
func New(ctx context.Context, c Config) (*Middleware, error) {
	mw := &Middleware{
		clientID:     c.ClientID,
		clientSecret: c.ClientSecret,
		baseURL:      c.BaseURL,
	}

	var err error
	if mw.oidc, err = oidc.NewProvider(ctx, c.ProviderURL); err != nil {
		return nil, fmt.Errorf("oidc.NewProvider: %w", err)
	}
	mw.verifier = mw.oidc.Verifier(&oidc.Config{ClientID: c.ClientID})

	return mw, nil
}

// oauth2Config returns the OAuth2 client configuration for the
// authorization-code flow, requesting only the "openid" scope.
//
// The redirect URL is baseURL joined with "loginz". Exactly one "/" is
// inserted unless baseURL already ends with one.
func (mw *Middleware) oauth2Config() *oauth2.Config {
	callback := strings.TrimSuffix(mw.baseURL, "/") + "/loginz"

	cfg := oauth2.Config{
		ClientID:     mw.clientID,
		ClientSecret: mw.clientSecret,
		RedirectURL:  callback,
		Endpoint:     mw.oidc.Endpoint(),
		Scopes:       []string{oidc.ScopeOpenID},
	}
	return &cfg
}

// errNoCredentials is returned by isAuthenticated when the request carries
// neither an auth cookie nor a non-empty bearer token.
var errNoCredentials = errors.New("fupoidc: no auth credentials")

// isAuthenticated checks whether r carries a valid OIDC ID token.
//
// The token is read from the fup_auth cookie or from an
// "Authorization: Bearer <token>" header. The header wins when both are
// present and the bearer token is non-empty. The scheme name is matched
// case-insensitively, because HTTP authentication schemes are
// case-insensitive (RFC 7235 §2.1).
//
// It returns nil if the token verifies, errNoCredentials if no token was
// supplied, and otherwise the error from the ID token verifier.
func (mw *Middleware) isAuthenticated(r *http.Request) error {
	var token string
	if cookie, err := r.Cookie(cookieToken); err == nil {
		token = cookie.Value
	}

	const bearerScheme = "Bearer "
	authHeader := r.Header.Get("Authorization")
	if len(authHeader) >= len(bearerScheme) && strings.EqualFold(authHeader[:len(bearerScheme)], bearerScheme) {
		// Don't let an empty bearer token override a usable cookie.
		if bearer := strings.TrimSpace(authHeader[len(bearerScheme):]); bearer != "" {
			token = bearer
		}
	}

	if token == "" {
		return errNoCredentials
	}

	_, err := mw.verifier.Verify(r.Context(), token)
	return err
}

// setStateCookie stores value in the fup_state cookie, or deletes that cookie
// when value is empty.
//
// The cookie is pinned to Path "/". It is set on whichever request triggered
// the login redirect, but it must reach the /loginz callback and be cleared
// there. Without an explicit Path, the browser scopes the cookie to the
// directory of the request that set it (RFC 6265 §5.1.4). A redirect
// triggered from a nested path would then create a cookie the callback never
// receives, and the clearing cookie would not delete it.
func (mw *Middleware) setStateCookie(rw http.ResponseWriter, value string) {
	cookie := &http.Cookie{
		Name:     cookieOAuthState,
		Value:    value,
		Path:     "/",
		Secure:   true,
		HttpOnly: true,
		// NOTE(review): presumably None so the cookie survives the cross-site
		// return from the provider; Lax also allows top-level GET redirects.
		SameSite: http.SameSiteNoneMode,
	}
	if value == "" {
		// A negative MaxAge makes net/http emit "Max-Age=0", deleting the cookie.
		cookie.MaxAge = -1
	}
	http.SetCookie(rw, cookie)
}

// redirectToProvider starts the OAuth2 authorization-code flow.
//
// It generates 16 random bytes from crypto/rand as the hex-encoded state and
// stores that state in the fup_state cookie. It then answers with a 302 to
// the provider's authorization URL carrying the same state. If randomness
// cannot be read, it responds with 500 instead.
func (mw *Middleware) redirectToProvider(rw http.ResponseWriter, r *http.Request) {
	var nonce [16]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		http.Error(rw, "rand.Read in redirectToProvider", http.StatusInternalServerError)
		return
	}

	state := hex.EncodeToString(nonce[:])
	mw.setStateCookie(rw, state)

	target := mw.oauth2Config().AuthCodeURL(state)
	http.Redirect(rw, r, target, http.StatusFound)
}

// handleLoginzRequest completes the OAuth2 authorization-code flow at the
// /loginz redirect target. In order, it:
//   - checks that the "state" query parameter matches the fup_state cookie
//     set by redirectToProvider, then clears that cookie;
//   - exchanges the "code" query parameter for tokens;
//   - verifies the returned ID token and decodes its claims;
//   - stores the raw ID token in the fup_auth cookie and redirects to baseURL.
//
// Any failure is answered with 400 Bad Request.
func (mw *Middleware) handleLoginzRequest(rw http.ResponseWriter, r *http.Request) {
	stateCookie, err := r.Cookie(cookieOAuthState)
	if err != nil {
		http.Error(rw, "fup_state cookie missing or invalid", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()
	stateQuery := query.Get("state")
	if stateQuery == "" {
		http.Error(rw, "state query param missing or invalid", http.StatusBadRequest)
		return
	}
	if stateQuery != stateCookie.Value {
		http.Error(rw, "state mismatch", http.StatusBadRequest)
		return
	}

	// The state has matched; clear it so it cannot be used again.
	mw.setStateCookie(rw, "")

	ctx := r.Context()
	oauth2Token, err := mw.oauth2Config().Exchange(ctx, query.Get("code"))
	if err != nil {
		log.Printf("oauth2 code exchange failed: %v", err)
		http.Error(rw, "code invalid", http.StatusBadRequest)
		return
	}

	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
	if !ok {
		http.Error(rw, "missing id_token", http.StatusBadRequest)
		return
	}

	idToken, err := mw.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		log.Printf("id_token verification failed: %v", err)
		http.Error(rw, "id_token invalid", http.StatusBadRequest)
		return
	}

	var claims struct {
		Username string `json:"preferred_username"`
		Sub      string `json:"sub"`
	}
	if err := idToken.Claims(&claims); err != nil {
		http.Error(rw, "id_token claims invalid", http.StatusBadRequest)
		return
	}
	// Never log rawIDToken. It is the same credential isAuthenticated accepts
	// from the fup_auth cookie or a bearer header, so anyone who can read the
	// logs could replay it.
	log.Printf("%s (%s) logged in", claims.Username, claims.Sub)

	// Persist the ID token as the session credential.
	http.SetCookie(rw, &http.Cookie{
		Name:     cookieToken,
		Value:    rawIDToken,
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	http.Redirect(rw, r, mw.baseURL, http.StatusFound)
}

// Middleware wraps handler with OIDC authentication.
//
// Requests whose path ends in "/loginz" are the OAuth2 callback and are
// handled here, not passed to handler. For all other requests the
// credentials are checked:
//   - valid credentials: handler serves the request;
//   - no credentials on a request that is not a mutation (fuphttp.IsMutate):
//     handler serves the request anonymously;
//   - no credentials on a mutation: API requests get 401, others are
//     redirected to the provider to log in;
//   - invalid credentials: the error is logged, then API requests get 401 and
//     others are redirected through the login flow again.
func (mw *Middleware) Middleware(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, "/loginz") {
			// This is the OAuth2 redirect target; finish the login flow.
			mw.handleLoginzRequest(rw, r)
			return
		}

		ctx := r.Context()
		authnErr := mw.isAuthenticated(r)
		switch {
		case authnErr == nil:
			// Authn successful.
			handler.ServeHTTP(rw, r)
		case errors.Is(authnErr, errNoCredentials) && !fuphttp.IsMutate(ctx):
			// No credentials provided, but none needed.
			handler.ServeHTTP(rw, r)
		case errors.Is(authnErr, errNoCredentials):
			// No credentials provided, but required.
			if fuphttp.IsAPIRequest(ctx) {
				http.Error(rw, "need OAuth credentials", http.StatusUnauthorized)
				return
			}
			mw.redirectToProvider(rw, r)
		default:
			// Credential validation error. API clients need a 401 here too; the
			// previous guard sent them to a 500 "slipped through" fallback.
			log.Printf("invalid OIDC credential: %v", authnErr)
			if fuphttp.IsAPIRequest(ctx) {
				http.Error(rw, "oauth credentials invalid", http.StatusUnauthorized)
				return
			}
			// Try redirecting them through the login loop.
			mw.redirectToProvider(rw, r)
		}
	})
}