修改只有一次验证机会的bug

This commit is contained in:
2026-03-26 02:04:41 +08:00
parent b18901a2d5
commit ca796a5fd2
5 changed files with 94 additions and 64 deletions
@@ -29,7 +29,12 @@ class MemoryRepository(VerificationRepository):
with self._lock: with self._lock:
self._store[key] = (json.dumps(payload), expire_at) self._store[key] = (json.dumps(payload), expire_at)
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: def compare_and_consume(
self,
email: str,
purpose: VerificationPurpose,
code_hash: str,
) -> Optional[bool]:
key = self._key(email, purpose) key = self._key(email, purpose)
with self._lock: with self._lock:
@@ -44,13 +49,17 @@ class MemoryRepository(VerificationRepository):
del self._store[key] del self._store[key]
return None return None
del self._store[key] try:
payload = json.loads(value)
stored_hash = payload.get("code_hash")
except Exception:
return None
try: if stored_hash == code_hash:
payload = json.loads(value) del self._store[key]
return payload.get("code_hash") return True
except Exception: else:
return None return False
def incr(self, key: str, ttl: int) -> int: def incr(self, key: str, ttl: int) -> int:
now = time.time() now = time.time()
@@ -2,83 +2,90 @@ from functools import lru_cache
import os import os
import logging import logging
import json import json
import datetime
import redis import redis
from typing import Optional, TYPE_CHECKING from typing import Optional
from app.models.models import VerificationPurpose from app.models.models import VerificationPurpose
from app.core.verification.email.verificationRepository import VerificationRepository from app.core.verification.email.verificationRepository import VerificationRepository
from app.utils.redis_client import get_redis_client from app.utils.redis_client import get_redis_client
from app.core.security import hash_verification_code
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AUTH_CODE_REDIS_PREFIX = os.getenv("AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code").strip() AUTH_CODE_REDIS_PREFIX = os.getenv(
"AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code"
).strip()
class RedisRepository(VerificationRepository): class RedisRepository(VerificationRepository):
_consume_lua = """ local val = redis.call("GET", KEYS[1]) if val then redis.call("DEL", KEYS[1]) end return val """
_compare_and_consume_lua = """
local val = redis.call("GET", KEYS[1])
if not val then
return nil
end
local data = cjson.decode(val)
if data["code_hash"] == ARGV[1] then
redis.call("DEL", KEYS[1])
return 1
else
return 0
end
"""
def __init__(self, client: redis.Redis): def __init__(self, client: redis.Redis):
self.client = client self.client = client
self._consume_script = self.client.register_script(self._consume_lua) self._compare_script = self.client.register_script(
self._compare_and_consume_lua
)
def _key(self, email, purpose): def _key(self, email: str, purpose: VerificationPurpose) -> str:
return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:code" return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:code"
def set_code(self, email: str, purpose: VerificationPurpose, code_hash: str, ttl: int) -> None: def set_code(
"""store the code into the redis self,
email: str,
Args: purpose: VerificationPurpose,
email (str): email of user code_hash: str,
purpose (VerificationPurpose): purpose of the code, such as "login", "register" ttl: int,
code_hash: the hash of the code ) -> None:
ttl: duration of the code
"""
key = self._key(email, purpose) key = self._key(email, purpose)
payload = json.dumps({ payload = json.dumps({
"code_hash": code_hash, "code_hash": code_hash,
"exp": datetime.datetime.now().timestamp()
}) })
self.client.set(key, payload, ex=ttl) self.client.set(key, payload, ex=ttl)
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: def compare_and_consume(
"""consume the code of email self,
email: str,
Args: purpose: VerificationPurpose,
email (str): email of user code_hash: str,
purpose (VerificationPurpose): purpose of the code, such as "login", "register" ) -> Optional[bool]:
Returns:
_type_: if email has a code which has not been consumed, return the code, else return None
"""
key = self._key(email, purpose) key = self._key(email, purpose)
data = self._consume_script(keys=[key])
if not data: result = self._compare_script(
return None keys=[key],
args=[code_hash],
)
if result is None:
return None
if result == 1:
return True
return False
try:
payload = json.loads(data) # type: ignore
return payload.get("code_hash")
except Exception:
return None
def incr(self, key: str, ttl: int) -> int: def incr(self, key: str, ttl: int) -> int:
super().incr(key, ttl)
value = self.client.incr(key) value = self.client.incr(key)
if value == 1: if value == 1:
self.client.expire(key, ttl) self.client.expire(key, ttl)
return value # type: ignore return int(value) # type: ignore
@lru_cache @lru_cache
def get_redis_repo(): def get_redis_repo():
@@ -1,4 +1,3 @@
# app/verification/backends/hybrid.py
from functools import lru_cache from functools import lru_cache
import logging import logging
@@ -28,14 +27,23 @@ class HybridRepository(VerificationRepository):
self.memory.set_code(email, purpose, code_hash, ttl) self.memory.set_code(email, purpose, code_hash, ttl)
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: def compare_and_consume(
self,
email: str,
purpose: VerificationPurpose,
code_hash: str,
) -> Optional[bool]:
if self.redis: if self.redis:
try: try:
return self.redis.consume_code(email, purpose) return self.redis.compare_and_consume(
email, purpose, code_hash
)
except Exception as e: except Exception as e:
logger.warning("Redis consume_code failed, fallback to memory: %s", e) logger.warning(
"Redis compare_and_consume failed, fallback to memory: %s", e
)
return self.memory.consume_code(email, purpose) return self.memory.compare_and_consume(email, purpose, code_hash)
def incr(self, key: str, ttl: int) -> int: def incr(self, key: str, ttl: int) -> int:
if self.redis: if self.redis:
@@ -20,7 +20,12 @@ class VerificationRepository(ABC):
@abstractmethod @abstractmethod
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]: def compare_and_consume(
self,
email: str,
purpose: VerificationPurpose,
code_hash: str,
) -> Optional[bool]:
"""consume the code atomically """consume the code atomically
Args: Args:
@@ -28,7 +33,7 @@ class VerificationRepository(ABC):
purpose (VerificationPurpose): the purpose of the code, such as, "login", "register" purpose (VerificationPurpose): the purpose of the code, such as, "login", "register"
Returns: Returns:
Optional[str]: if success return the code, else return None Optional[str]: if success return the true
""" """
pass pass
@@ -62,15 +62,16 @@ class EmailVerificationService:
def verify_code(self,email: str, code: str, purpose: VerificationPurpose): def verify_code(self,email: str, code: str, purpose: VerificationPurpose):
email = email.lower() email = email.lower()
code_hash = hash_verification_code(code)
stored_hash: Optional[str] = self.repo.consume_code(email, purpose) stored = self.repo.compare_and_consume(email, purpose, code_hash)
if not stored_hash: if stored == False:
raise CodeExpiredError("Code expired or not found")
if stored_hash != hash_verification_code(code):
raise CodeInvalidError("Invalid code") raise CodeInvalidError("Invalid code")
if not stored:
raise CodeExpiredError("Code expired or not found")
return True return True
@lru_cache @lru_cache