198 lines
6.2 KiB
Python
198 lines
6.2 KiB
Python
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, EmailStr
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.security import verify_password, get_password_hash, create_access_token, decode_token
from app.db.database import get_db
from app.db.models import User
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Every authentication endpoint in this module is registered on this router.
router = APIRouter()

# Extracts the `Authorization: Bearer <token>` header for get_current_user.
# auto_error=True is the library default, spelled out here for clarity.
security = HTTPBearer(auto_error=True)
|
|
|
|
# Pydantic models
|
|
class UserRegistration(BaseModel):
    # Request body model for the /register route on this router.
    # (Documented with comments rather than a docstring: pydantic copies a
    # class docstring into the generated JSON/OpenAPI schema description.)
    username: str
    email: EmailStr
    # Plaintext password; register_user hashes it before storing.
    password: str
    # Optional display name. Annotated Optional so that an explicit JSON
    # `null` is accepted: under pydantic v2 a bare `str = None` only works
    # when the key is omitted, and rejects an explicitly sent None.
    fullName: Optional[str] = None
|
|
|
|
class UserLogin(BaseModel):
    # Request body model for the /login route on this router.
    # login_user looks the user up by `username` and checks `password`
    # against the stored hash with verify_password.
    username: str
    # Plaintext password as submitted by the client.
    password: str
|
|
|
|
class UserResponse(BaseModel):
    # Public view of a User returned by the auth endpoints (camelCase keys).
    id: int
    username: str
    email: str
    # Optional because the handlers build this from the ORM's `full_name`,
    # which is None for users who registered without a name. With a bare
    # `str` annotation pydantic v2 rejects that explicit None, and the
    # ValidationError turns register/login/me into server errors.
    fullName: Optional[str] = None
    # ISO-8601 timestamp string; callers pass `created_at.isoformat()`.
    createdAt: str

    class Config:
        # Lets model_validate read fields from object attributes (e.g. ORM rows).
        from_attributes = True
|
|
|
|
class TokenResponse(BaseModel):
    # Response body model for the /register and /login routes.
    # Signed access token produced by create_access_token.
    access_token: str
    # Token scheme; the handlers in this module never override the default.
    token_type: str = "bearer"
    # Profile of the user the token was issued for.
    user: UserResponse
|
|
|
|
# Dependency to get current user
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the request's bearer token to a stored ``User``.

    Meant to be used as a FastAPI dependency, e.g.
    ``current_user: User = Depends(get_current_user)``.

    Args:
        credentials: Parsed bearer credentials from the ``security`` scheme.
        db: Database session provided by ``get_db``.

    Returns:
        The ``User`` whose username equals the value ``decode_token`` returns.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) when
            ``decode_token`` yields None, or when no user has that username.
    """
    def _unauthorized(message: str) -> HTTPException:
        # Both failure paths share the same status code and challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=message,
            headers={"WWW-Authenticate": "Bearer"},
        )

    subject = decode_token(credentials.credentials)
    if subject is None:
        raise _unauthorized("Invalid authentication credentials")

    matched = db.query(User).filter(User.username == subject).first()
    if matched is None:
        raise _unauthorized("User not found")
    return matched
|
|
|
|
@router.post("/register", response_model=TokenResponse)
async def register_user(user_data: UserRegistration, db: Session = Depends(get_db)):
    """Create a new account and return an access token for it.

    Args:
        user_data: Validated registration payload.
        db: Database session provided by ``get_db``.

    Returns:
        TokenResponse carrying a bearer token for the new user plus the
        user's public profile.

    Raises:
        HTTPException: 400 if the username or email is already registered,
            including when a concurrent request claims it between the
            pre-checks and the commit; 500 for any other unexpected failure.
    """
    try:
        # Pre-checks give a precise message in the common (non-racing) case.
        if db.query(User).filter(User.username == user_data.username).first():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username already registered",
            )
        if db.query(User).filter(User.email == user_data.email).first():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered",
            )

        # Only the hash is persisted, never the plaintext password.
        new_user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=get_password_hash(user_data.password),
            full_name=user_data.fullName,
        )
        db.add(new_user)
        db.commit()
        # Reload so DB-populated columns (id, created_at) are available below.
        db.refresh(new_user)

        access_token = create_access_token(
            subject=new_user.username,
            expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        logger.info("New user registered: %s", new_user.username)

        return TokenResponse(
            access_token=access_token,
            user=UserResponse(
                id=new_user.id,
                username=new_user.username,
                email=new_user.email,
                fullName=new_user.full_name,
                createdAt=new_user.created_at.isoformat(),
            ),
        )
    except HTTPException:
        # Deliberate client errors raised above pass through unchanged.
        raise
    except IntegrityError:
        # The pre-checks are check-then-insert, so two concurrent registrations
        # can both pass them. If the table enforces uniqueness (presumably on
        # username/email — TODO confirm against the User model), the second
        # commit fails here; report it as a client error instead of a 500.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username or email already registered",
        )
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Registration error: %s", e)
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Registration failed",
        ) from e
|
|
|
|
@router.post("/login", response_model=TokenResponse)
async def login_user(login_data: UserLogin, db: Session = Depends(get_db)):
    """Verify a username/password pair and issue an access token.

    Also records the login time in ``user.last_login``.

    Args:
        login_data: Submitted credentials.
        db: Database session provided by ``get_db``.

    Returns:
        TokenResponse carrying a bearer token and the user's public profile.

    Raises:
        HTTPException: 401 if the username is unknown or the password does not
            match (same message for both, so the response does not reveal
            which one failed); 500 for any other unexpected failure.
    """
    try:
        user = db.query(User).filter(User.username == login_data.username).first()
        # Short-circuit: verify_password only runs when the user exists.
        if user is None or not verify_password(login_data.password, user.hashed_password):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
            )

        # Naive UTC, the same value datetime.utcnow() produced; utcnow() is
        # deprecated since Python 3.12.
        user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
        db.commit()

        access_token = create_access_token(
            subject=user.username,
            expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        logger.info("User logged in: %s", user.username)

        return TokenResponse(
            access_token=access_token,
            user=UserResponse(
                id=user.id,
                username=user.username,
                email=user.email,
                fullName=user.full_name,
                createdAt=user.created_at.isoformat(),
            ),
        )
    except HTTPException:
        # Deliberate client errors raised above pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Login error: %s", e)
        # A failed commit of last_login leaves the session's transaction
        # unusable until it is rolled back.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Login failed",
        ) from e
|
|
|
|
@router.post("/logout")
async def logout_user(current_user: User = Depends(get_current_user)):
    """Acknowledge a logout request from an authenticated user.

    Nothing is changed server-side: the token is not revoked or recorded
    anywhere here, so the client is expected to discard it. Requests whose
    token get_current_user rejects never reach this body.

    Returns:
        A confirmation message dict.
    """
    who = current_user.username
    logger.info(f"User logged out: {who}")
    response = {"message": "Successfully logged out"}
    return response
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(current_user: User = Depends(get_current_user)):
    """Return the public profile of the user owning the request's token.

    Args:
        current_user: User resolved by the get_current_user dependency.

    Returns:
        UserResponse mapping the ORM's snake_case attributes to the API's
        camelCase keys, with ``created_at`` rendered via ``isoformat()``.
    """
    me = current_user
    profile = dict(
        id=me.id,
        username=me.username,
        email=me.email,
        fullName=me.full_name,
        createdAt=me.created_at.isoformat(),
    )
    return UserResponse(**profile)