Add JWT-based user authentication to backend
- Create User model with bcrypt password hashing - Add auth routes: register, login, refresh, me - Implement JWT access and refresh tokens - Add get_current_user dependency for protected routes - Update Task model with user_id foreign key for data isolation - Update TaskService to filter tasks by authenticated user - Add auth configuration (secret key, token expiry)
This commit is contained in:
parent
5cd79e096d
commit
911f192c38
5
backend/app/auth/__init__.py
Normal file
5
backend/app/auth/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from .models import User
|
||||||
|
from .routes import router as auth_router
|
||||||
|
from .dependencies import get_current_user
|
||||||
|
|
||||||
|
__all__ = ["User", "auth_router", "get_current_user"]
|
||||||
46
backend/app/auth/dependencies.py
Normal file
46
backend/app/auth/dependencies.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ..db import get_db
|
||||||
|
from .models import User
|
||||||
|
from .utils import decode_token
|
||||||
|
|
||||||
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> User:
|
||||||
|
token = credentials.credentials
|
||||||
|
payload = decode_token(token)
|
||||||
|
|
||||||
|
if payload is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token type",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
if user_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token payload",
|
||||||
|
)
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
18
backend/app/auth/models.py
Normal file
18
backend/app/auth/models.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import Column, String, DateTime
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from ..db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
email = Column(String, unique=True, nullable=False, index=True)
|
||||||
|
password_hash = Column(String, nullable=False)
|
||||||
|
name = Column(String, nullable=False)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
tasks = relationship("Task", back_populates="owner", cascade="all, delete-orphan")
|
||||||
67
backend/app/auth/routes.py
Normal file
67
backend/app/auth/routes.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ..db import get_db
|
||||||
|
from .schemas import UserCreate, UserLogin, UserResponse, TokenResponse, TokenRefresh
|
||||||
|
from .services import AuthService
|
||||||
|
from .dependencies import get_current_user
|
||||||
|
from .models import User
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service(db: Session = Depends(get_db)) -> AuthService:
|
||||||
|
return AuthService(db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=201)
|
||||||
|
def register(
|
||||||
|
user_data: UserCreate,
|
||||||
|
service: AuthService = Depends(get_auth_service),
|
||||||
|
):
|
||||||
|
existing_user = service.get_user_by_email(user_data.email)
|
||||||
|
if existing_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Email already registered",
|
||||||
|
)
|
||||||
|
|
||||||
|
user = service.create_user(user_data)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
def login(
|
||||||
|
credentials: UserLogin,
|
||||||
|
service: AuthService = Depends(get_auth_service),
|
||||||
|
):
|
||||||
|
user = service.authenticate_user(credentials.email, credentials.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
return service.create_tokens(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
def refresh(
|
||||||
|
token_data: TokenRefresh,
|
||||||
|
service: AuthService = Depends(get_auth_service),
|
||||||
|
):
|
||||||
|
tokens = service.refresh_tokens(token_data.refresh_token)
|
||||||
|
if tokens is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
def get_current_user_info(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
return current_user
|
||||||
39
backend/app/auth/schemas.py
Normal file
39
backend/app/auth/schemas.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
name: str = Field(..., min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
password: str = Field(..., min_length=8, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class UserLogin(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(UserBase):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRefresh(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
email: Optional[str] = None
|
||||||
58
backend/app/auth/services.py
Normal file
58
backend/app/auth/services.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from .models import User
|
||||||
|
from .schemas import UserCreate
|
||||||
|
from .utils import get_password_hash, verify_password, create_access_token, create_refresh_token, decode_token
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_user_by_email(self, email: str) -> Optional[User]:
|
||||||
|
return self.db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
|
def get_user_by_id(self, user_id: str) -> Optional[User]:
|
||||||
|
return self.db.query(User).filter(User.id == user_id).first()
|
||||||
|
|
||||||
|
def create_user(self, user_data: UserCreate) -> User:
|
||||||
|
password_hash = get_password_hash(user_data.password)
|
||||||
|
user = User(
|
||||||
|
email=user_data.email,
|
||||||
|
name=user_data.name,
|
||||||
|
password_hash=password_hash,
|
||||||
|
)
|
||||||
|
self.db.add(user)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def authenticate_user(self, email: str, password: str) -> Optional[User]:
|
||||||
|
user = self.get_user_by_email(email)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if not verify_password(password, user.password_hash):
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
def create_tokens(self, user: User) -> dict:
|
||||||
|
access_token = create_access_token(data={"sub": user.id, "email": user.email})
|
||||||
|
refresh_token = create_refresh_token(data={"sub": user.id, "email": user.email})
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
}
|
||||||
|
|
||||||
|
def refresh_tokens(self, refresh_token: str) -> Optional[dict]:
|
||||||
|
payload = decode_token(refresh_token)
|
||||||
|
if payload is None or payload.get("type") != "refresh":
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
user = self.get_user_by_id(user_id)
|
||||||
|
if user is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.create_tokens(user)
|
||||||
45
backend/app/auth/utils.py
Normal file
45
backend/app/auth/utils.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from ..config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
return bcrypt.checkpw(
|
||||||
|
plain_password.encode('utf-8'),
|
||||||
|
hashed_password.encode('utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_password_hash(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(
|
||||||
|
password.encode('utf-8'),
|
||||||
|
bcrypt.gensalt()
|
||||||
|
).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||||
|
to_encode = data.copy()
|
||||||
|
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=settings.access_token_expire_minutes))
|
||||||
|
to_encode.update({"exp": expire, "type": "access"})
|
||||||
|
return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(data: dict) -> str:
|
||||||
|
to_encode = data.copy()
|
||||||
|
expire = datetime.utcnow() + timedelta(days=settings.refresh_token_expire_days)
|
||||||
|
to_encode.update({"exp": expire, "type": "refresh"})
|
||||||
|
return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str) -> Optional[dict]:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
|
||||||
|
return payload
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
@ -6,6 +6,11 @@ class Settings(BaseSettings):
|
|||||||
database_url: str = "sqlite:///./tasks.db"
|
database_url: str = "sqlite:///./tasks.db"
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
|
||||||
|
secret_key: str = "your-secret-key-change-in-production-min-32-chars"
|
||||||
|
algorithm: str = "HS256"
|
||||||
|
access_token_expire_minutes: int = 30
|
||||||
|
refresh_token_expire_days: int = 7
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,10 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from .db import Base, engine
|
from .db import Base, engine
|
||||||
|
from .models import Task # noqa: F401 - needed for table creation
|
||||||
|
from .auth.models import User # noqa: F401 - needed for table creation
|
||||||
from .routes import router
|
from .routes import router
|
||||||
|
from .auth.routes import router as auth_router
|
||||||
from .config import get_settings
|
from .config import get_settings
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
@ -30,6 +33,7 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.include_router(auth_router)
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Index
|
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from .db import Base
|
from .db import Base
|
||||||
|
|
||||||
@ -17,3 +18,6 @@ class Task(Base):
|
|||||||
is_done = Column(Boolean, default=False)
|
is_done = Column(Boolean, default=False)
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
user_id = Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
owner = relationship("User", back_populates="tasks")
|
||||||
|
|||||||
@ -11,12 +11,17 @@ from .schemas import (
|
|||||||
HealthResponse,
|
HealthResponse,
|
||||||
)
|
)
|
||||||
from .services import TaskService
|
from .services import TaskService
|
||||||
|
from .auth.dependencies import get_current_user
|
||||||
|
from .auth.models import User
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def get_task_service(db: Session = Depends(get_db)) -> TaskService:
|
def get_task_service(
|
||||||
return TaskService(db)
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> TaskService:
|
||||||
|
return TaskService(db, current_user.id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health", response_model=HealthResponse)
|
@router.get("/health", response_model=HealthResponse)
|
||||||
|
|||||||
@ -6,13 +6,17 @@ from .schemas import TaskCreate, TaskUpdate
|
|||||||
|
|
||||||
|
|
||||||
class TaskService:
|
class TaskService:
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session, user_id: str):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.user_id = user_id
|
||||||
|
|
||||||
def get_tasks_by_date(
|
def get_tasks_by_date(
|
||||||
self, date: str, status: Optional[str] = None
|
self, date: str, status: Optional[str] = None
|
||||||
) -> list[Task]:
|
) -> list[Task]:
|
||||||
query = self.db.query(Task).filter(Task.date == date)
|
query = self.db.query(Task).filter(
|
||||||
|
Task.date == date,
|
||||||
|
Task.user_id == self.user_id
|
||||||
|
)
|
||||||
|
|
||||||
if status == "active":
|
if status == "active":
|
||||||
query = query.filter(Task.is_done == False)
|
query = query.filter(Task.is_done == False)
|
||||||
@ -22,7 +26,10 @@ class TaskService:
|
|||||||
return query.order_by(Task.created_at.desc()).all()
|
return query.order_by(Task.created_at.desc()).all()
|
||||||
|
|
||||||
def get_task_by_id(self, task_id: str) -> Optional[Task]:
|
def get_task_by_id(self, task_id: str) -> Optional[Task]:
|
||||||
return self.db.query(Task).filter(Task.id == task_id).first()
|
return self.db.query(Task).filter(
|
||||||
|
Task.id == task_id,
|
||||||
|
Task.user_id == self.user_id
|
||||||
|
).first()
|
||||||
|
|
||||||
def create_task(self, task_data: TaskCreate) -> Task:
|
def create_task(self, task_data: TaskCreate) -> Task:
|
||||||
task = Task(
|
task = Task(
|
||||||
@ -31,6 +38,7 @@ class TaskService:
|
|||||||
date=task_data.date,
|
date=task_data.date,
|
||||||
time=task_data.time,
|
time=task_data.time,
|
||||||
priority=task_data.priority.value,
|
priority=task_data.priority.value,
|
||||||
|
user_id=self.user_id,
|
||||||
)
|
)
|
||||||
self.db.add(task)
|
self.db.add(task)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
fastapi==0.115.0
|
fastapi==0.115.0
|
||||||
uvicorn[standard]==0.32.0
|
uvicorn[standard]==0.32.0
|
||||||
sqlalchemy==2.0.36
|
sqlalchemy==2.0.36
|
||||||
pydantic==2.10.0
|
pydantic[email]==2.10.0
|
||||||
pydantic-settings==2.6.0
|
pydantic-settings==2.6.0
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
pytest==8.3.0
|
pytest==8.3.0
|
||||||
httpx==0.28.0
|
httpx==0.28.0
|
||||||
|
python-jose[cryptography]==3.3.0
|
||||||
|
bcrypt==4.2.0
|
||||||
|
python-multipart==0.0.6
|
||||||
|
psycopg[binary]==3.2.3
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user