diff --git a/backend/app/core/verification/email/RespositoryImpl/MemoryRepository.py b/backend/app/core/verification/email/RespositoryImpl/MemoryRepository.py index 211a9fa..926579a 100644 --- a/backend/app/core/verification/email/RespositoryImpl/MemoryRepository.py +++ b/backend/app/core/verification/email/RespositoryImpl/MemoryRepository.py @@ -29,7 +29,12 @@ class MemoryRepository(VerificationRepository): with self._lock: self._store[key] = (json.dumps(payload), expire_at) - def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: + def compare_and_consume( + self, + email: str, + purpose: VerificationPurpose, + code_hash: str, + ) -> Optional[bool]: key = self._key(email, purpose) with self._lock: @@ -44,13 +49,17 @@ class MemoryRepository(VerificationRepository): del self._store[key] return None - del self._store[key] + try: + payload = json.loads(value) + stored_hash = payload.get("code_hash") + except Exception: + return None - try: - payload = json.loads(value) - return payload.get("code_hash") - except Exception: - return None + if stored_hash == code_hash: + del self._store[key] + return True + else: + return False def incr(self, key: str, ttl: int) -> int: now = time.time() diff --git a/backend/app/core/verification/email/RespositoryImpl/RedisRepository.py b/backend/app/core/verification/email/RespositoryImpl/RedisRepository.py index c2e15a8..679f8bc 100644 --- a/backend/app/core/verification/email/RespositoryImpl/RedisRepository.py +++ b/backend/app/core/verification/email/RespositoryImpl/RedisRepository.py @@ -2,83 +2,90 @@ from functools import lru_cache import os import logging import json -import datetime import redis -from typing import Optional, TYPE_CHECKING +from typing import Optional from app.models.models import VerificationPurpose from app.core.verification.email.verificationRepository import VerificationRepository from app.utils.redis_client import get_redis_client -from app.core.security import hash_verification_code logger = logging.getLogger(__name__) -AUTH_CODE_REDIS_PREFIX = os.getenv("AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code").strip() +AUTH_CODE_REDIS_PREFIX = os.getenv( + "AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code" +).strip() class RedisRepository(VerificationRepository): - _consume_lua = """ local val = redis.call("GET", KEYS[1]) if val then redis.call("DEL", KEYS[1]) end return val """ + _compare_and_consume_lua = """ + local val = redis.call("GET", KEYS[1]) + if not val then + return nil + end + + local data = cjson.decode(val) + + if data["code_hash"] == ARGV[1] then + redis.call("DEL", KEYS[1]) + return 1 + else + return 0 + end + """ def __init__(self, client: redis.Redis): self.client = client - self._consume_script = self.client.register_script(self._consume_lua) - - + self._compare_script = self.client.register_script( + self._compare_and_consume_lua + ) - def _key(self, email, purpose): + def _key(self, email: str, purpose: VerificationPurpose) -> str: return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:code" - - def set_code(self, email: str, purpose: VerificationPurpose, code_hash: str, ttl: int) -> None: - """store the code into the redis - - Args: - email (str): email of user - purpose (VerificationPurpose): purpose of the code, such as "login", "register" - code_hash: the hash of the code - ttl: duration of the code - - """ + + def set_code( + self, + email: str, + purpose: VerificationPurpose, + code_hash: str, + ttl: int, + ) -> None: key = self._key(email, purpose) payload = json.dumps({ "code_hash": code_hash, - "exp": datetime.datetime.now().timestamp() }) self.client.set(key, payload, ex=ttl) - def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: - """consume the code of email - - Args: - email (str): email of user - purpose (VerificationPurpose): purpose of the code, such as "login", "register" - - Returns: - _type_: if email has a code which has not been consumed, return the code, else return None - """ - + def compare_and_consume( + self, + email: str, + purpose: VerificationPurpose, + code_hash: str, + ) -> Optional[bool]: key = self._key(email, purpose) - data = self._consume_script(keys=[key]) - if not data: - return None + result = self._compare_script( + keys=[key], + args=[code_hash], + ) + + if result is None: + return None + + if result == 1: + return True + + return False - try: - payload = json.loads(data) # type: ignore - return payload.get("code_hash") - except Exception: - return None - def incr(self, key: str, ttl: int) -> int: - super().incr(key, ttl) value = self.client.incr(key) - + if value == 1: self.client.expire(key, ttl) - - return value # type: ignore + + return int(value) # type: ignore @lru_cache def get_redis_repo(): diff --git a/backend/app/core/verification/email/RespositoryImpl/hybirdRepository.py b/backend/app/core/verification/email/RespositoryImpl/hybirdRepository.py index 5d2e67d..3b0eeae 100644 --- a/backend/app/core/verification/email/RespositoryImpl/hybirdRepository.py +++ b/backend/app/core/verification/email/RespositoryImpl/hybirdRepository.py @@ -1,4 +1,3 @@ -# app/verification/backends/hybrid.py from functools import lru_cache import logging @@ -28,14 +27,23 @@ class HybridRepository(VerificationRepository): self.memory.set_code(email, purpose, code_hash, ttl) - def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: + def compare_and_consume( + self, + email: str, + purpose: VerificationPurpose, + code_hash: str, + ) -> Optional[bool]: if self.redis: try: - return self.redis.consume_code(email, purpose) + return self.redis.compare_and_consume( + email, purpose, code_hash + ) except Exception as e: - logger.warning("Redis consume_code failed, fallback to memory: %s", e) + logger.warning( + "Redis compare_and_consume failed, fallback to memory: %s", e + ) - return self.memory.consume_code(email, purpose) + return self.memory.compare_and_consume(email, purpose, code_hash) def incr(self, key: str, ttl: int) -> int: if self.redis: diff --git a/backend/app/core/verification/email/verificationRepository.py b/backend/app/core/verification/email/verificationRepository.py index 9a74524..d1bc994 100644 --- a/backend/app/core/verification/email/verificationRepository.py +++ b/backend/app/core/verification/email/verificationRepository.py @@ -20,7 +20,12 @@ class VerificationRepository(ABC): @abstractmethod - def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: + def compare_and_consume( + self, + email: str, + purpose: VerificationPurpose, + code_hash: str, + ) -> Optional[bool]: """consume the code atomically Args: @@ -28,7 +33,7 @@ class VerificationRepository(ABC): purpose (VerificationPurpose): the purpose of the code, such as, "login", "register" Returns: - Optional[str]: if success return the code, else return None + Optional[str]: if success return the true """ pass diff --git a/backend/app/core/verification/email/verificationService.py b/backend/app/core/verification/email/verificationService.py index 230911e..eb5636c 100644 --- a/backend/app/core/verification/email/verificationService.py +++ b/backend/app/core/verification/email/verificationService.py @@ -62,15 +62,16 @@ class EmailVerificationService: def verify_code(self,email: str, code: str, purpose: VerificationPurpose): email = email.lower() + code_hash = hash_verification_code(code) - stored_hash: Optional[str] = self.repo.consume_code(email, purpose) + stored = self.repo.compare_and_consume(email, purpose, code_hash) - if not stored_hash: - raise CodeExpiredError("Code expired or not found") - - if stored_hash != hash_verification_code(code): + if stored == False: raise CodeInvalidError("Invalid code") + if not stored: + raise CodeExpiredError("Code expired or not found") + return True @lru_cache